]> git.donarmstrong.com Git - ape.git/blobdiff - src/nj.c
faster nj()!!!
[ape.git] / src / nj.c
index af53abbd91ce9669d3038b704be47d7cba8d4ff5..355c68f34086a082d250d481604689afa1185358 100644 (file)
--- a/src/nj.c
+++ b/src/nj.c
@@ -1,4 +1,4 @@
-/* nj.c       2009-07-09 */
+/* nj.c       2009-07-17 */
 
 /* Copyright 2006-2009 Emmanuel Paradis
 
@@ -58,29 +58,9 @@ j 4  2  6  9
        return(sum);
 }
 
-#define GET_I_AND_J                                               \
-/* Find the 'R' indices of the two corresponding OTUs */          \
-/* The indices of the first element of the pair in the            \
-   distance matrix are n-1 times 1, n-2 times 2, n-3 times 3,     \
-   ..., once n-1. Given this, the algorithm below is quite        \
-   straightforward.*/                                             \
-    i = 0;                                                        \
-    for (OTU1 = 1; OTU1 < n; OTU1++) {                            \
-        i += n - OTU1;                                            \
-       if (i >= smallest + 1) break;                             \
-    }                                                             \
-    /* Finding the second OTU is easier! */                       \
-    OTU2 = smallest + 1 + OTU1 - n*(OTU1 - 1) + OTU1*(OTU1 - 1)/2
-
-#define SET_CLADE                           \
-/* give the node and tip numbers to edge */ \
-    edge2[k] = otu_label[OTU1 - 1];         \
-    edge2[k + 1] = otu_label[OTU2 - 1];     \
-    edge1[k] = edge1[k + 1] = cur_nod
-
 void nj(double *D, int *N, int *edge1, int *edge2, double *edge_length)
 {
-       double SUMD, *S, Sdist, Ndist, *new_dist, A, B, *DI, d_i, x, y;
+       double *S, Sdist, Ndist, *new_dist, A, B, smallest_S, *DI, d_i, x, y;
        int n, i, j, k, ij, smallest, OTU1, OTU2, cur_nod, o_l, *otu_label;
 
        S = &Sdist;
@@ -91,7 +71,7 @@ void nj(double *D, int *N, int *edge1, int *edge2, double *edge_length)
        n = *N;
        cur_nod = 2*n - 2;
 
-       S = (double*)R_alloc(n*(n - 1)/2, sizeof(double));
+       S = (double*)R_alloc(n, sizeof(double));
        new_dist = (double*)R_alloc(n*(n - 1)/2, sizeof(double));
        otu_label = (int*)R_alloc(n, sizeof(int));
        DI = (double*)R_alloc(n - 2, sizeof(double));
@@ -101,27 +81,28 @@ void nj(double *D, int *N, int *edge1, int *edge2, double *edge_length)
 
        while (n > 3) {
 
-               SUMD = 0;
-               for (i = 0; i < n*(n - 1)/2; i++) SUMD += D[i];
+               for (i = 0; i < n; i++)
+                       S[i] = sum_dist_to_i(n, D, i + 1);
 
                ij = 0;
-               for (i = 1; i < n; i++) {
-                       for (j = i + 1; j <= n; j++) {
-                               A = sum_dist_to_i(n, D, i) - D[ij];
-                               B = sum_dist_to_i(n, D, j) - D[ij];
-                               S[ij] = (A + B)/(2*n - 4) + 0.5*D[ij]
-                                       + (SUMD - A - B - D[ij])/(n - 2);
+               smallest_S = 1e50;
+               B = n - 2;
+               for (i = 0; i < n - 1; i++) {
+                       for (j = i + 1; j < n; j++) {
+                               A = D[ij] - (S[i] + S[j])/B;
+                               if (A < smallest_S) {
+                                       OTU1 = i + 1;
+                                       OTU2 = j + 1;
+                                       smallest_S = A;
+                                       smallest = ij;
+                               }
                                ij++;
                        }
                }
 
-               /* find the 'C' index of the smallest value of S */
-               smallest = 0;
-               for (i = 1; i < n*(n - 1)/2; i++)
-                       if (S[smallest] > S[i]) smallest = i;
-
-               GET_I_AND_J;
-               SET_CLADE;
+               edge2[k] = otu_label[OTU1 - 1];
+               edge2[k + 1] = otu_label[OTU2 - 1];
+               edge1[k] = edge1[k + 1] = cur_nod;
 
                /* get the distances between all OTUs but the 2 selected ones
                   and the latter:
@@ -151,9 +132,11 @@ void nj(double *D, int *N, int *edge1, int *edge2, double *edge_length)
                        OTU2 = i;
                }
                if (OTU1 != 1)
-                       for (i = OTU1 - 1; i > 0; i--) otu_label[i] = otu_label[i - 1];
+                       for (i = OTU1 - 1; i > 0; i--)
+                               otu_label[i] = otu_label[i - 1];
                if (OTU2 != n)
-                       for (i = OTU2; i <= n; i++) otu_label[i - 1] = otu_label[i];
+                       for (i = OTU2; i <= n; i++)
+                               otu_label[i - 1] = otu_label[i];
                otu_label[0] = cur_nod;
 
                for (i = 1; i < n; i++) {