]> git.donarmstrong.com Git - ape.git/blob - R/chronos.R
new chronos files and a bunch of various improvements
[ape.git] / R / chronos.R
1 ## chronos.R (2013-01-03)
2
3 ##   Molecular Dating With Penalized and Maximum Likelihood
4
5 ## Copyright 2013 Emmanuel Paradis
6
7 ## This file is part of the R-package `ape'.
8 ## See the file ../COPYING for licensing issues.
9
10 .chronos.ctrl <-
11     list(tol = 1e-8, iter.max = 1e4, eval.max = 1e4, nb.rate.cat = 10,
12          dual.iter.max = 20)
13
14 makeChronosCalib <-
15     function(phy, node = "root", age.min = 1, age.max = age.min,
16              interactive = FALSE, soft.bounds = FALSE)
17 {
18     n <- Ntip(phy)
19     if (interactive) {
20         plot(phy)
21         cat("Click close to a node and enter the ages (right-click to exit)\n\n")
22         node <- integer()
23         age.min <- age.max <- numeric()
24         repeat {
25             ans <- identify(phy, quiet = TRUE)
26             if (is.null(ans)) break
27             NODE <- ans$nodes
28             nodelabels(node = NODE, col = "white", bg = "blue")
29             cat("constraints for node ", NODE, sep = "")
30             cat("\n  youngest age: ")
31             AGE.MIN <- as.numeric(readLines(n = 1))
32             cat("  oldest age (ENTER if not applicable): ")
33             AGE.MAX <- as.numeric(readLines(n = 1))
34             node <- c(node, NODE)
35             age.min <- c(age.min, AGE.MIN)
36             age.max <- c(age.max, AGE.MAX)
37         }
38         s <- is.na(age.max)
39         if (any(s)) age.max[s] <- age.min[s]
40     } else {
41         if (identical(node, "root")) node <- n + 1L
42     }
43
44     if (any(node <= n))
45         stop("node numbers should be greater than the number of tips")
46
47     diff.age <- which(age.max < age.min)
48     if (length(diff.age)) {
49         msg <- "'old age' less than 'young age' for node"
50         if (length(diff.age) > 1) msg <- paste(msg, "s", sep = "")
51         stop(paste(msg, paste(node[diff.age], collapse = ", ")))
52     }
53
54     data.frame(node, age.min, age.max, soft.bounds = soft.bounds)
55 }
56
57 chronos.control <- function(...)
58 {
59     dots <- list(...)
60     x <- .chronos.ctrl
61     if (length(dots)) {
62         chk.nms <- names(dots) %in% names(x)
63         if (any(!chk.nms)) {
64             warning("some control parameter names do not match: they were ignored")
65             dots <- dots[chk.nms]
66         }
67         x[names(dots)] <- dots
68     }
69     x
70 }
71
72 chronos <-
73     function(phy, lambda = 1, model = "correlated", quiet = FALSE,
74              calibration = makeChronosCalib(phy),
75              control = chronos.control())
76 {
77     model <- match.arg(tolower(model), c("correlated", "relaxed", "discrete"))
78     n <- Ntip(phy)
79     ROOT <- n + 1L
80     m <- phy$Nnode
81     el <- phy$edge.length
82     if (any(el < 0)) stop("some branch lengths are negative")
83     e1 <- phy$edge[, 1L]
84     e2 <- phy$edge[, 2L]
85     N <- length(e1)
86     TIPS <- 1:n
87     EDGES <- 1:N
88
89     tol <- control$tol
90
91     node <- calibration$node
92     age.min <- calibration$age.min
93     age.max <- calibration$age.max
94
95     if (model == "correlated") {
96 ### `basal' contains the indices of the basal edges
97 ### (ie, linked to the root):
98         basal <- which(e1 == ROOT)
99         Nbasal <- length(basal)
100
101 ### 'ind1' contains the index of all nonbasal edges, and 'ind2' the
102 ### index of the edges where these edges come from (ie, they contain
103 ### pairs of contiguous edges), eg:
104
105 ###         ___b___    ind1  ind2
106 ###        |           |   ||   |
107 ### ___a___|           | b || a |
108 ###        |           | c || a |
109 ###        |___c___    |   ||   |
110
111         ind1 <- EDGES[-basal]
112         ind2 <- match(e1[EDGES[-basal]], e2)
113     }
114
115     age <- numeric(n + m)
116
117 ### This bit sets 'ini.time' and should result in no negative branch lengths
118
119     if (!quiet) cat("\nSetting initial dates...\n")
120     seq.nod <- .Call("seq_root2tip", phy$edge, n, phy$Nnode, PACKAGE = "ape")
121
122     ii <- 1L
123     repeat {
124         ini.time <- age
125         ini.time[ROOT:(n + m)] <- NA
126
127         ini.time[node] <-
128             if (is.null(age.max)) age.min
129             else runif(length(node), age.min, age.max) # (age.min + age.max) / 2
130
131         ## if no age given for the root, find one approximately:
132         if (is.na(ini.time[ROOT]))
133             ini.time[ROOT] <- if (is.null(age.max)) 3 * max(age.min) else 3 * max(age.max)
134
135         ISnotNA.ALL <- unlist(lapply(seq.nod, function(x) sum(!is.na(ini.time[x]))))
136         o <- order(ISnotNA.ALL, decreasing = TRUE)
137
138         for (y in seq.nod[o]) {
139             ISNA <- is.na(ini.time[y])
140             if (any(ISNA)) {
141                 i <- 2L # we know the 1st value is not NA, so we start at the 2nd one
142                 while (i <= length(y)) {
143                     if (ISNA[i]) { # we stop at the next NA
144                         j <- i + 1L
145                         while (ISNA[j]) j <- j + 1L # look for the next non-NA
146                         nb.val <- j - i
147                         by <- (ini.time[y[i - 1L]] - ini.time[y[j]]) / (nb.val + 1)
148                         ini.time[y[i:(j - 1L)]] <- ini.time[y[i - 1L]] - by * seq_len(nb.val)
149                         i <- j + 1L
150                     } else i <- i + 1L
151                 }
152             }
153         }
154         if (all(ini.time[e1] - ini.time[e2] >= 0)) break
155         ii <- ii + 1L
156         if (ii > 1000)
157             stop("cannot find reasonable starting dates after 1000 tries:
158 maybe you need to adjust the calibration dates")
159     }
160 ### 'ini.time' set
161
162     #ini.time[ROOT:(n+m)] <- branching.times(chr.dis)
163     ## ini.time[ROOT:(n+m)] <- ini.time[ROOT:(n+m)] + rnorm(m, 0, 5)
164     #print(ini.time)
165
166
167 ### Setting 'ini.rate'
168     ini.rate <- el/(ini.time[e1] - ini.time[e2])
169
170     if (model == "discrete") {
171         Nb.rates <- control$nb.rate.cat
172         minmax <- range(ini.rate)
173         if (Nb.rates == 1) {
174             ini.rate <- sum(minmax)/2
175         } else {
176             inc <- diff(minmax)/Nb.rates
177             ini.rate <- seq(minmax[1] + inc/2, minmax[2] - inc/2, inc)
178             ini.freq <- rep(1/Nb.rates, Nb.rates - 1)
179             lower.freq <- rep(0, Nb.rates - 1)
180             upper.freq <- rep(1, Nb.rates - 1)
181         }
182     } else Nb.rates <- N
183 ## 'ini.rate' set
184
185 ### Setting bounds for the node ages
186
187     ## `unknown.ages' will contain the index of the nodes of unknown age:
188     unknown.ages <- 1:m + n
189
190     ## initialize vectors for all nodes:
191     lower.age <- rep(tol, m)
192     upper.age <- rep(1/tol, m)
193
194     lower.age[node - n] <- age.min
195     upper.age[node - n] <- age.max
196
197     ## find nodes known within an interval:
198     ii <- which(age.min != age.max)
199     ## drop them from 'node' since they will be estimated:
200     if (length(ii)) {
201         node <- node[-ii]
202         if (length(node))
203             age[node] <- age.min[-ii] # update 'age'
204     } else age[node] <- age.min
205
206     ## finally adjust the 3 vectors:
207     if (length(node)) {
208         unknown.ages <- unknown.ages[n - node] # 'n - node' is simplification for '-(node - n)'
209         lower.age <- lower.age[n - node]
210         upper.age <- upper.age[n - node]
211     }
212 ### Bounds for the node ages set
213
214     ## 'known.ages' contains the index of all nodes
215     ## (internal and terminal) of known age:
216     known.ages <- c(TIPS, node)
217
218     ## the bounds for the rates:
219     lower.rate <- rep(tol, Nb.rates)
220     upper.rate <- rep(100 - tol, Nb.rates) # needs to be adjusted to higher values?
221
222 ### Gradient
223     degree_node <- tabulate(phy$edge)
224     eta_i <- degree_node[e1]
225     eta_i[e2 <= n] <- 1L
226     ## eta_i[i] is the number of contiguous branches for branch 'i'
227
228     ## use of a list of indices is slightly faster than an incidence matrix
229     ## and takes much less memory (60 Kb vs. 8 Mb for n = 500)
230     X <- vector("list", N)
231     for (i in EDGES) {
232         j <- integer()
233         if (e1[i] != ROOT) j <- c(j, which(e2 == e1[i]))
234         if (e2[i] >= n) j <- c(j, which(e1 == e2[i]))
235         X[[i]] <- j
236     }
237     ## X is a list whose i-th element gives the indices of the branches
238     ## that are contiguous to branch 'i'
239
240     ## D_ki and A_ki are defined in the SI of the paper
241     D_ki <- match(unknown.ages, e2)
242     A_ki <- lapply(unknown.ages, function(x) which(x == e1))
243
244     gradient.poisson <- function(rate, node.time) {
245         age[unknown.ages] <- node.time
246         real.edge.length <- age[e1] - age[e2]
247         #if (any(real.edge.length < 0))
248         #    return(numeric(N + length(unknown.ages)))
249         ## gradient for the rates:
250         gr <- el/rate - real.edge.length
251
252         ## gradient for the dates:
253         tmp <- el/real.edge.length - rate
254         gr.dates <- sapply(A_ki, function(x) sum(tmp[x])) - tmp[D_ki]
255
256         c(gr, gr.dates)
257     }
258
259     ## gradient of the penalized lik (must be multiplied by -1 before calling nlminb)
260     gradient <-
261         switch(model,
262                "correlated" =
263                function(rate, node.time) {
264                    gr <- gradient.poisson(rate, node.time)
265                    #if (all(gr == 0)) return(gr)
266
267                    ## contribution of the penalty for the rates:
268                    gr[RATE] <- gr[RATE] - lambda * 2 * (eta_i * rate - sapply(X, function(x) sum(rate[x])))
269                    ## the contribution of the root variance term:
270                    if (Nbasal == 2) { # the simpler formulae if there's a basal dichotomy
271                        i <- basal[1]
272                        j <- basal[2]
273                        gr[i] <- gr[i] - lambda * (rate[i] - rate[j])
274                        gr[j] <- gr[j] - lambda * (rate[j] - rate[i])
275                    } else { # the general case
276                        for (i in 1:Nbasal)
277                            j <- basal[i]
278                            gr[j] <- gr[j] -
279                                lambda*2*(rate[j]*(1 - 1/Nbasal) - sum(rate[basal[-i]])/Nbasal)/(Nbasal - 1)
280                    }
281                    gr
282                },
283                "relaxed" =
284                function(rate, node.time) {
285                    gr <- gradient.poisson(rate, node.time)
286                    #if (all(gr == 0)) return(gr)
287
288                    ## contribution of the penalty for the rates:
289                    mean.rate <- mean(rate)
290                    ## rank(rate)/Nb.rates is the same than ecdf(rate)(rate) but faster
291                    gr[RATE] <- gr[RATE] + lambda*2*dgamma(rate, mean.rate)*(rank(rate)/Nb.rates - pgamma(rate, mean.rate))
292                    gr
293                },
294                "discrete" = NULL)
295
296     log.lik.poisson <- function(rate, node.time) {
297         age[unknown.ages] <- node.time
298         real.edge.length <- age[e1] - age[e2]
299         if (isTRUE(any(real.edge.length < 0))) return(-1e100)
300         B <- rate * real.edge.length
301         sum(el * log(B) - B - lfactorial(el))
302     }
303
304 ### penalized log-likelihood
305     penal.loglik <-
306         switch(model,
307                "correlated" =
308                function(rate, node.time) {
309                    loglik <- log.lik.poisson(rate, node.time)
310                    if (!is.finite(loglik)) return(-1e100)
311                    loglik - lambda * (sum((rate[ind1] - rate[ind2])^2)
312                                       + var(rate[basal]))
313                },
314                "relaxed" =
315                function(rate, node.time) {
316                    loglik <- log.lik.poisson(rate, node.time)
317                    if (!is.finite(loglik)) return(-1e100)
318                    mu <- mean(rate)
319                    ## loglik - lambda * sum((1:N/N - pbeta(sort(rate), mu/(1 + mu), 1))^2) # avec loi beta
320                    ## loglik - lambda * sum((1:N/N - pcauchy(sort(rate)))^2) # avec loi Cauchy
321                    loglik - lambda * sum((1:N/N - pgamma(sort(rate), mean(rate)))^2) # avec loi Gamma
322                },
323                "discrete" =
324                if (Nb.rates == 1)
325                    function(rate, node.time) log.lik.poisson(rate, node.time)
326                else function(rate, node.time, freq) {
327                    if (isTRUE(sum(freq) > 1)) return(-1e100)
328                    rate.freq <- sum(c(freq, 1 - sum(freq)) * rate)
329                    log.lik.poisson(rate.freq, node.time)
330                })
331
332     opt.ctrl <- list(eval.max = control$eval.max, iter.max = control$iter.max)
333
334     ## the following capitalized vectors give the indices of
335     ## the parameters once they are concatenated in 'p'
336     RATE <- 1:Nb.rates
337     AGE <- Nb.rates + 1:length(unknown.ages)
338
339     if (model == "discrete") {
340         if (Nb.rates == 1) {
341             start.para <- c(ini.rate, ini.time[unknown.ages])
342             f <- function(p) -penal.loglik(p[RATE], p[AGE])
343             g <- NULL
344             LOW <- c(lower.rate, lower.age)
345             UP <- c(upper.rate, upper.age)
346         } else {
347             FREQ <- length(RATE) + length(AGE) + 1:(Nb.rates - 1)
348             start.para <- c(ini.rate, ini.time[unknown.ages], ini.freq)
349             f <- function(p) -penal.loglik(p[RATE], p[AGE], p[FREQ])
350             g <- NULL
351             LOW <- c(lower.rate, lower.age, lower.freq)
352             UP <- c(upper.rate, upper.age, upper.freq)
353         }
354     } else {
355         start.para <- c(ini.rate, ini.time[unknown.ages])
356         f <- function(p) -penal.loglik(p[RATE], p[AGE])
357         g <- function(p) -gradient(p[RATE], p[AGE])
358         LOW <- c(lower.rate, lower.age)
359         UP <- c(upper.rate, upper.age)
360     }
361
362     k <- length(LOW) # number of free parameters
363
364     if (!quiet) cat("Fitting in progress... get a first set of estimates\n")
365
366     out <- nlminb(start.para, f, g,
367                   control = opt.ctrl, lower = LOW, upper = UP)
368
369     if (model == "discrete") {
370         if (Nb.rates == 1) {
371             f.rates <- function(p) -penal.loglik(p, current.ages)
372             f.ages <- function(p) -penal.loglik(current.rates, p)
373         } else {
374             f.rates <- function(p) -penal.loglik(p, current.ages, current.freqs)
375             f.ages <- function(p) -penal.loglik(current.rates, p, current.freqs)
376             f.freqs <- function(p) -penal.loglik(current.rates, current.ages, p)
377             g.freqs <- NULL
378         }
379         g.rates <- NULL
380         g.ages <- NULL
381     } else {
382         f.rates <- function(p) -penal.loglik(p, current.ages)
383         g.rates <- function(p) -gradient(p, current.ages)[RATE]
384         f.ages <- function(p) -penal.loglik(current.rates, p)
385         g.ages <- function(p) -gradient(current.rates, p)[AGE]
386     }
387
388     current.ploglik <- -out$objective
389     current.rates <- out$par[RATE]
390     current.ages <- out$par[AGE]
391     if (model == "discrete" && Nb.rates > 1) current.freqs <- out$par[FREQ]
392
393     dual.iter.max <- control$dual.iter.max
394     i <- 0L
395
396     if (!quiet) cat("         Penalised log-lik =", current.ploglik, "\n")
397
398     repeat {
399         if (dual.iter.max < 1) break
400         if (!quiet) cat("Optimising rates...")
401         out.rates <- nlminb(current.rates, f.rates, g.rates,# h.rates,
402                             control = list(eval.max = 1000, iter.max = 1000,
403                                            step.min = 1e-8, step.max = .1),
404                             lower = lower.rate, upper = upper.rate)
405         new.rates <- out.rates$par
406         if (-out.rates$objective > current.ploglik)
407             current.rates <- new.rates
408
409         if (model == "discrete" && Nb.rates > 1) {
410             if (!quiet) cat(" frequencies...")
411             out.freqs <- nlminb(current.freqs, f.freqs,
412                                 control = list(eval.max = 1000, iter.max = 1000,
413                                                step.min = .001, step.max = .5),
414                                 lower = lower.freq, upper = upper.freq)
415             new.freqs <- out.freqs$par
416         }
417
418         if (!quiet) cat(" dates...")
419         out.ages <- nlminb(current.ages, f.ages, g.ages,# h.ages,
420                            control = list(eval.max = 1000, iter.max = 1000,
421                                           step.min = .001, step.max = 100),
422                            lower = lower.age, upper = upper.age)
423         new.ploglik <- -out.ages$objective
424
425         if (!quiet) cat("", current.ploglik, "\n")
426
427         if (new.ploglik - current.ploglik > 1e-6 && i <= dual.iter.max) {
428             current.ploglik <- new.ploglik
429             current.rates <- new.rates
430             current.ages <- out.ages$par
431             if (model == "discrete" && Nb.rates > 1) current.freqs <- new.freqs
432             out <- out.ages
433             i <- i + 1L
434         } else break
435     }
436
437     if (!quiet) cat("\nDone.\n")
438
439 #    browser()
440
441     if (model == "discrete") {
442         rate.freq <-
443             if (Nb.rates == 1) current.rates
444             else mean(c(current.freqs, 1 - sum(current.freqs)) * current.rates)
445         logLik <- log.lik.poisson(rate.freq, current.ages)
446         PHIIC <- list(logLik = logLik, k = k, PHIIC = - 2 * logLik + 2 * k)
447     } else {
448         logLik <- log.lik.poisson(current.rates, current.ages)
449         PHI <- switch(model,
450                       "correlated" = (current.rates[ind1] - current.rates[ind2])^2 + var(current.rates[basal]),
451                       "relaxed" = (1:N/N - pgamma(sort(current.rates), mean(current.rates)))^2) # avec loi Gamma
452         PHIIC <- list(logLik = logLik, k = k, lambda = lambda,
453                       PHIIC = - 2 * logLik + 2 * k + lambda * svd(PHI)$d)
454     }
455
456     attr(phy, "call") <- match.call()
457     attr(phy, "ploglik") <- -out$objective
458     attr(phy, "rates") <- current.rates #out$par[EDGES]
459     if (model == "discrete" && Nb.rates > 1)
460         attr(phy, "frequencies") <- current.freqs
461     attr(phy, "message") <- out$message
462     attr(phy, "PHIIC") <- PHIIC
463     age[unknown.ages] <- current.ages #out$par[-EDGES]
464     phy$edge.length <- age[e1] - age[e2]
465     class(phy) <- c("chronos", class(phy))
466     phy
467 }
468
469 print.chronos <- function(x, ...)
470 {
471     cat("\n    Chronogram\n\n")
472     cat("Call: ")
473     print(attr(x, "call"))
474     cat("\n")
475     NextMethod("print")
476 }