]> git.donarmstrong.com Git - ape.git/blobdiff - R/ace.R
improved ace()
[ape.git] / R / ace.R
diff --git a/R/ace.R b/R/ace.R
index b45a88e2009cc5ee252c2e1e1246c6db0162423f..91994bed1ffeb9a35fd2dfdc9a9bc20aa2d11384 100644 (file)
--- a/R/ace.R
+++ b/R/ace.R
@@ -1,4 +1,4 @@
-## ace.R (2009-05-10)
+## ace.R (2009-06-10)
 
 ##   Ancestral Character Estimation
 
@@ -149,10 +149,13 @@ as the number of categories in `x'")
         rate[rate == 0] <- np + 1 # to avoid 0's since we will use this an numeric indexing
 
         liks <- matrix(0, nb.tip + nb.node, nl)
-        for (i in 1:nb.tip) liks[i, x[i]] <- 1
+        TIPS <- 1:nb.tip
+        liks[cbind(TIPS, x)] <- 1
         phy <- reorder(phy, "pruningwise")
 
         Q <- matrix(0, nl, nl)
+        ## from Rich FitzJohn:
+        comp <- numeric(nb.tip + nb.node) # Storage...
         dev <- function(p, output.liks = FALSE) {
             Q[] <- c(p, 0)[rate]
             diag(Q) <- -rowSums(Q)
@@ -161,14 +164,14 @@ as the number of categories in `x'")
                 anc <- phy$edge[i, 1]
                 des1 <- phy$edge[i, 2]
                 des2 <- phy$edge[j, 2]
-                tmp <- eigen(Q * phy$edge.length[i], symmetric = FALSE)
-                P1 <- tmp$vectors %*% diag(exp(tmp$values)) %*% solve(tmp$vectors)
-                tmp <- eigen(Q * phy$edge.length[j], symmetric = FALSE)
-                P2 <- tmp$vectors %*% diag(exp(tmp$values)) %*% solve(tmp$vectors)
-                liks[anc, ] <- P1 %*% liks[des1, ] * P2 %*% liks[des2, ]
+                v.l <- matexpo(Q * phy$edge.length[i]) %*% liks[des1, ]
+                v.r <- matexpo(Q * phy$edge.length[j]) %*% liks[des2, ]
+                v <- v.l * v.r
+                comp[anc] <- sum(v)
+                liks[anc, ] <- v/comp[anc]
             }
-            if (output.liks) return(liks[-(1:nb.tip), ])
-            - 2 * log(sum(liks[nb.tip + 1, ]))
+            if (output.liks) return(liks[-TIPS, ])
+            -2 * sum(log(comp[-TIPS]))
         }
         out <- nlminb(rep(ip, length.out = np), function(p) dev(p),
                       lower = rep(0, np), upper = rep(Inf, np))
@@ -184,10 +187,8 @@ as the number of categories in `x'")
         else obj$se <- sqrt(diag(solve(h)))
         obj$index.matrix <- index.matrix
         if (CI) {
-            lik.anc <- dev(obj$rates, TRUE)
-            lik.anc <- lik.anc / rowSums(lik.anc)
-            colnames(lik.anc) <- lvls
-            obj$lik.anc <- lik.anc
+            obj$lik.anc <- dev(obj$rates, TRUE)
+            colnames(obj$lik.anc) <- lvls
         }
     }
     obj$call <- match.call()