]> git.donarmstrong.com Git - mothur.git/blob - qFinderDMM.cpp
some changes on qFinderDMM
[mothur.git] / qFinderDMM.cpp
1 //
2 //  qFinderDMM.cpp
3 //  pds_dmm
4 //
5 //  Created by Patrick Schloss on 11/8/12.
6 //  Copyright (c) 2012 University of Michigan. All rights reserved.
7 //
8
9 #include "qFinderDMM.h"
10 #include "linearalgebra.h"
11
// Machine epsilon for double: the gap between 1.0 and the next larger double.
// Used below as a round-off tolerance by the line search, the BFGS step-size
// heuristic, and the series cut-off in psi1().
#define EPSILON (numeric_limits<double>::epsilon())
13
14 /**************************************************************************************************/
15
16 qFinderDMM::qFinderDMM(vector<vector<int> > cm, int p): countMatrix(cm), numPartitions(p){
17     try {
18         m = MothurOut::getInstance();
19         numSamples = (int)countMatrix.size();
20         numOTUs = (int)countMatrix[0].size();
21         
22         
23         kMeans();
24         //    cout << "done kMeans" << endl;
25         
26         optimizeLambda();
27         
28         
29         //    cout << "done optimizeLambda" << endl;
30         
31         double change = 1.0000;
32         currNLL = 0.0000;
33         
34         int iter = 0;
35         
36         while(change > 1.0e-6 && iter < 100){
37             
38             //        cout << "Calc_Z: ";
39             calculatePiK();
40             
41             optimizeLambda();
42             
43             //        printf("Iter:%d\t",iter);
44             
45             for(int i=0;i<numPartitions;i++){
46                 weights[i] = 0.0000;
47                 for(int j=0;j<numSamples;j++){
48                     weights[i] += zMatrix[i][j];
49                 }
50                 //            printf("w_%d=%.3f\t",i,weights[i]);
51                 
52             }
53             
54             double nLL = getNegativeLogLikelihood();
55             
56             change = abs(nLL - currNLL);
57             
58             currNLL = nLL;
59             
60             //        printf("NLL=%.5f\tDelta=%.4e\n",currNLL, change);
61             
62             iter++;
63         }
64         
65         error.resize(numPartitions);
66         
67         logDeterminant = 0.0000;
68         
69         LinearAlgebra l;
70         
71         for(currentPartition=0;currentPartition<numPartitions;currentPartition++){
72             
73             error[currentPartition].assign(numOTUs, 0.0000);
74             
75             if(currentPartition > 0){
76                 logDeterminant += (2.0 * log(numSamples) - log(weights[currentPartition]));
77             }
78             vector<vector<double> > hessian = getHessian();
79             vector<vector<double> > invHessian = l.getInverse(hessian);
80             
81             for(int i=0;i<numOTUs;i++){
82                 logDeterminant += log(abs(hessian[i][i]));
83                 error[currentPartition][i] = invHessian[i][i];
84             }
85         }
86         
87         int numParameters = numPartitions * numOTUs + numPartitions - 1;
88         laplace = currNLL + 0.5 * logDeterminant - 0.5 * numParameters * log(2.0 * 3.14159);
89         bic = currNLL + 0.5 * log(numSamples) * numParameters;
90         aic = currNLL + numParameters;
91     }
92         catch(exception& e) {
93                 m->errorOut(e, "qFinderDMM", "qFinderDMM");
94                 exit(1);
95         }
96 }
97
98 /**************************************************************************************************/
99
100 void qFinderDMM::printZMatrix(string fileName, vector<string> sampleName){
101     try {
102         ofstream printMatrix;
103         m->openOutputFile(fileName, printMatrix); //(fileName.c_str());
104         printMatrix.setf(ios::fixed, ios::floatfield);
105         printMatrix.setf(ios::showpoint);
106         
107         for(int i=0;i<numPartitions;i++){   printMatrix << "\tPartition_" << i+1;   }   printMatrix << endl;
108         
109         for(int i=0;i<numSamples;i++){
110             printMatrix << sampleName[i];
111             for(int j=0;j<numPartitions;j++){
112                 printMatrix << setprecision(4) << '\t' << zMatrix[j][i];
113             }
114             printMatrix << endl;
115         }
116         printMatrix.close();
117     }
118         catch(exception& e) {
119                 m->errorOut(e, "qFinderDMM", "printZMatrix");
120                 exit(1);
121         }
122 }
123
124 /**************************************************************************************************/
125
126 void qFinderDMM::printRelAbund(string fileName, vector<string> otuNames){
127     try {
128         ofstream printRA;
129         m->openOutputFile(fileName, printRA); //(fileName.c_str());
130         printRA.setf(ios::fixed, ios::floatfield);
131         printRA.setf(ios::showpoint);
132         
133         vector<double> totals(numPartitions, 0.0000);
134         for(int i=0;i<numPartitions;i++){
135             for(int j=0;j<numOTUs;j++){
136                 totals[i] += exp(lambdaMatrix[i][j]);
137             }
138         }
139         
140         printRA << "Taxon";
141         for(int i=0;i<numPartitions;i++){
142             printRA << "\tPartition_" << i+1 << '_' << setprecision(4) << totals[i];
143             printRA << "\tPartition_" << i+1 <<"_LCI" << "\tPartition_" << i+1 << "_UCI";
144         }
145         printRA << endl;
146         
147         for(int i=0;i<numOTUs;i++){
148             
149             if (m->control_pressed) { break; }
150             
151             printRA << otuNames[i];
152             for(int j=0;j<numPartitions;j++){
153                 
154                 if(error[j][i] >= 0.0000){
155                     double std = sqrt(error[j][i]);
156                     printRA << '\t' << 100 * exp(lambdaMatrix[j][i]) / totals[j];
157                     printRA << '\t' << 100 * exp(lambdaMatrix[j][i] - 2.0 * std) / totals[j];
158                     printRA << '\t' << 100 * exp(lambdaMatrix[j][i] + 2.0 * std) / totals[j];
159                 }
160                 else{
161                     printRA << '\t' << 100 * exp(lambdaMatrix[j][i]) / totals[j];
162                     printRA << '\t' << "NA";
163                     printRA << '\t' << "NA";
164                 }
165             }
166             printRA << endl;
167         }
168         
169         printRA.close();
170     }
171         catch(exception& e) {
172                 m->errorOut(e, "qFinderDMM", "printRelAbund");
173                 exit(1);
174         }
175 }
176
177 /**************************************************************************************************/
178
// The following helpers for the bfgs2 solver's line search were adapted from the GNU Scientific Library (GSL) source code.
180
/* Minimise, over z in [zl, zh], the quadratic that passes through (0, f0) and
 * (1, f1) and has slope fp0 at z = 0:
 *
 *     q(z) = f0 + fp0 * z + (f1 - f0 - fp0) * z^2
 *
 * Selection:
 *   - The endpoint with the smaller value is the first candidate (zl wins ties).
 *   - If the curvature 2 * (f1 - f0 - fp0) is positive, the stationary point
 *     -fp0 / curvature replaces it when that point lies strictly inside (zl, zh)
 *     and gives a strictly smaller value.
 *
 * Returns the chosen z. (Adapted from GSL.)
 */
static double
interp_quad (double f0, double fp0, double f1, double zl, double zh)
{
    const double quadCoeff = f1 - f0 - fp0;
    const double curvature = 2 * quadCoeff;

    double bestZ = zl;
    double bestF = f0 + zl * (fp0 + zl * quadCoeff);

    const double fHigh = f0 + zh * (fp0 + zh * quadCoeff);
    if (fHigh < bestF) { bestZ = zh; bestF = fHigh; }

    /* without positive curvature there is no interior minimum */
    if (!(curvature > 0)) { return bestZ; }

    const double zStar = -fp0 / curvature;
    if (zStar > zl && zStar < zh) {
        const double fStar = f0 + zStar * (fp0 + zStar * quadCoeff);
        if (fStar < bestF) { bestZ = zStar; }
    }

    return bestZ;
}
208
209 /**************************************************************************************************/
210
/* The line search interpolates with the cubic through (0, f0) and (1, f1) whose
 * derivatives are fp0 at z = 0 and fp1 at z = 1:
 *
 *     c(z) = f0 + fp0 * z + eta * z^2 + xi * z^3,
 *     eta = 3*(f1 - f0) - 2*fp0 - fp1,    xi = fp0 + fp1 - 2*(f1 - f0).
 *
 * cubic() evaluates the general polynomial c0 + c1*z + c2*z^2 + c3*z^3 using
 * Horner's rule.
 */
double cubic (double c0, double c1, double c2, double c3, double z){
    double acc = c3;
    acc = c2 + z * acc;
    acc = c1 + z * acc;
    return c0 + z * acc;
}
224
225 /**************************************************************************************************/
226
227 void check_extremum (double c0, double c1, double c2, double c3, double z,
228                      double *zmin, double *fmin){
229     /* could make an early return by testing curvature >0 for minimum */
230     
231     double y = cubic (c0, c1, c2, c3, z);
232     
233     if (y < *fmin)  
234     {
235         *zmin = z;  /* accepted new point*/
236         *fmin = y;
237     }
238 }
239
240 /**************************************************************************************************/
241
/* Real roots of a*x^2 + b*x + c = 0 (adapted from GSL's gsl_poly_solve_quadratic).
 *
 * Returns how many roots were written:
 *   0 - no real roots (negative or NaN discriminant), or a == b == 0;
 *   1 - a == 0: the linear root -c/b is stored in *x0 (*x1 untouched);
 *   2 - roots in ascending order in *x0, *x1 (a double root is reported twice).
 *
 * For b != 0 the roots use the cancellation-avoiding form
 * q = -0.5 * (b + sign(b) * sqrt(disc)), roots q/a and c/q.
 */
int gsl_poly_solve_quadratic (double a, double b, double c,
                              double *x0, double *x1)
{
    if (a == 0)   /* degenerate: linear equation */
    {
        if (b == 0) { return 0; }
        *x0 = -c / b;
        return 1;
    }

    const double disc = b * b - 4 * a * c;

    if (disc == 0)
    {
        *x0 = -0.5 * b / a;
        *x1 = -0.5 * b / a;
        return 2;
    }

    if (!(disc > 0)) { return 0; }

    if (b == 0)
    {
        const double r = fabs (0.5 * sqrt (disc) / a);
        *x0 = -r;
        *x1 = r;
        return 2;
    }

    const double sgnb = (b > 0 ? 1 : -1);
    const double q = -0.5 * (b + sgnb * sqrt (disc));
    const double rootA = q / a;
    const double rootB = c / q;
    const bool aFirst = rootA < rootB;

    *x0 = aFirst ? rootA : rootB;
    *x1 = aFirst ? rootB : rootA;
    return 2;
}
301
302 /**************************************************************************************************/
303
304 double interp_cubic (double f0, double fp0, double f1, double fp1, double zl, double zh){
305     double eta = 3 * (f1 - f0) - 2 * fp0 - fp1;
306     double xi = fp0 + fp1 - 2 * (f1 - f0);
307     double c0 = f0, c1 = fp0, c2 = eta, c3 = xi;
308     double zmin, fmin;
309     double z0, z1;
310     
311     zmin = zl; fmin = cubic(c0, c1, c2, c3, zl);
312     check_extremum (c0, c1, c2, c3, zh, &zmin, &fmin);
313     
314     {
315         int n = gsl_poly_solve_quadratic (3 * c3, 2 * c2, c1, &z0, &z1);
316         
317         if (n == 2)  /* found 2 roots */
318         {
319             if (z0 > zl && z0 < zh) 
320                 check_extremum (c0, c1, c2, c3, z0, &zmin, &fmin);
321             if (z1 > zl && z1 < zh) 
322                 check_extremum (c0, c1, c2, c3, z1, &zmin, &fmin);
323         }
324         else if (n == 1)  /* found 1 root */
325         {
326             if (z0 > zl && z0 < zh) 
327                 check_extremum (c0, c1, c2, c3, z0, &zmin, &fmin);
328         }
329     }
330     
331     return zmin;
332 }
333
334 /**************************************************************************************************/
335
336 double interpolate (double a, double fa, double fpa,
337                     double b, double fb, double fpb, double xmin, double xmax){
338     /* Map [a,b] to [0,1] */
339     double z, alpha, zmin, zmax;
340     
341     zmin = (xmin - a) / (b - a);
342     zmax = (xmax - a) / (b - a);
343     
344     if (zmin > zmax)
345     {
346         double tmp = zmin;
347         zmin = zmax;
348         zmax = tmp;
349     };
350     
351     if(!isnan(fpb) ){
352         z = interp_cubic (fa, fpa * (b - a), fb, fpb * (b - a), zmin, zmax);
353     }
354     else{
355         z = interp_quad(fa, fpa * (b - a), fb, zmin, zmax);
356     }
357
358     
359     alpha = a + z * (b - a);
360     
361     return alpha;
362 }
363
364 /**************************************************************************************************/
365
366 int qFinderDMM::lineMinimizeFletcher(vector<double>& x, vector<double>& p, double f0, double df0, double alpha1, double& alphaNew, double& fAlpha, vector<double>& xalpha, vector<double>& gradient ){
367     try {
368         
369         double rho = 0.01;
370         double sigma = 0.10;
371         double tau1 = 9.00;
372         double tau2 = 0.05;
373         double tau3 = 0.50;
374         
375         double alpha = alpha1;
376         double alpha_prev = 0.0000;
377         
378         xalpha.resize(numOTUs, 0.0000);
379         
380         double falpha_prev = f0;
381         double dfalpha_prev = df0;
382         
383         double a = 0.0000;          double b = alpha;
384         double fa = f0;             double fb = 0.0000;
385         double dfa = df0;           double dfb = 0.0/0.0;
386         
387         int iter = 0;
388         int maxIters = 100;
389         while(iter++ < maxIters){
390             if (m->control_pressed) { break; }
391             
392             for(int i=0;i<numOTUs;i++){
393                 xalpha[i] = x[i] + alpha * p[i];
394             }
395             
396             fAlpha = negativeLogEvidenceLambdaPi(xalpha);
397             
398             if(fAlpha > f0 + alpha * rho * df0 || fAlpha >= falpha_prev){
399                 a = alpha_prev;         b = alpha;
400                 fa = falpha_prev;       fb = fAlpha;
401                 dfa = dfalpha_prev;     dfb = 0.0/0.0;
402                 break;
403             }
404             
405             negativeLogDerivEvidenceLambdaPi(xalpha, gradient);
406             double dfalpha = 0.0000;
407             for(int i=0;i<numOTUs;i++){ dfalpha += gradient[i] * p[i]; }
408             
409             if(abs(dfalpha) <= -sigma * df0){
410                 alphaNew = alpha;
411                 return 1;
412             }
413             
414             if(dfalpha >= 0){
415                 a = alpha;                  b = alpha_prev;
416                 fa = fAlpha;                fb = falpha_prev;
417                 dfa = dfalpha;              dfb = dfalpha_prev;
418                 break;
419             }
420             
421             double delta = alpha - alpha_prev;
422             
423             double lower = alpha + delta;
424             double upper = alpha + tau1 * delta;
425             
426             double alphaNext = interpolate(alpha_prev, falpha_prev, dfalpha_prev, alpha, fAlpha, dfalpha, lower, upper);
427             
428             alpha_prev = alpha;
429             falpha_prev = fAlpha;
430             dfalpha_prev = dfalpha;
431             alpha = alphaNext;
432         }
433         
434         iter = 0;
435         while(iter++ < maxIters){
436             if (m->control_pressed) { break; }
437             
438             double delta = b - a;
439             
440             double lower = a + tau2 * delta;
441             double upper = b - tau3 * delta;
442             
443             alpha = interpolate(a, fa, dfa, b, fb, dfb, lower, upper);
444             
445             for(int i=0;i<numOTUs;i++){
446                 xalpha[i] = x[i] + alpha * p[i];
447             }
448             
449             fAlpha = negativeLogEvidenceLambdaPi(xalpha);
450             
451             if((a - alpha) * dfa <= EPSILON){
452                 return 0;
453             }
454             
455             if(fAlpha > f0 + rho * alpha * df0 || fAlpha >= fa){
456                 b = alpha;
457                 fb = fAlpha;
458                 dfb = 0.0/0.0;
459             }
460             else{
461                 double dfalpha = 0.0000;
462                 
463                 negativeLogDerivEvidenceLambdaPi(xalpha, gradient);
464                 dfalpha = 0.0000;
465                 for(int i=0;i<numOTUs;i++){ dfalpha += gradient[i] * p[i]; }
466                 
467                 if(abs(dfalpha) <= -sigma * df0){
468                     alphaNew = alpha;
469                     return 1;
470                 }
471                 
472                 if(((b-a >= 0 && dfalpha >= 0) || ((b-a) <= 0.000 && dfalpha <= 0))){
473                     b = a;      fb = fa;        dfb = dfa;
474                     a = alpha;  fa = fAlpha;    dfa = dfalpha;
475                 }
476                 else{
477                     a = alpha;
478                     fa = fAlpha;
479                     dfa = dfalpha;
480                 }
481             }
482             
483             
484         }
485         
486         return 1;
487     }
488         catch(exception& e) {
489                 m->errorOut(e, "qFinderDMM", "lineMinimizeFletcher");
490                 exit(1);
491         }
492 }
493
494 /**************************************************************************************************/
495
496 int qFinderDMM::bfgs2_Solver(vector<double>& x){
497     try{
498 //        cout << "bfgs2_Solver" << endl;
499         int bfgsIter = 0;
500         double step = 1.0e-6;
501         double delta_f = 0.0000;//f-f0;
502
503         vector<double> gradient;
504         double f = negativeLogEvidenceLambdaPi(x);
505         
506 //        cout << "after negLE" << endl;
507         
508         negativeLogDerivEvidenceLambdaPi(x, gradient);
509
510 //        cout << "after negLDE" << endl;
511
512         vector<double> x0 = x;
513         vector<double> g0 = gradient;
514
515         double g0norm = 0;
516         for(int i=0;i<numOTUs;i++){
517             g0norm += g0[i] * g0[i];
518         }
519         g0norm = sqrt(g0norm);
520
521         vector<double> p = gradient;
522         double pNorm = 0;
523         for(int i=0;i<numOTUs;i++){
524             p[i] *= -1 / g0norm;
525             pNorm += p[i] * p[i];
526         }
527         pNorm = sqrt(pNorm);
528         double df0 = -g0norm;
529
530         int maxIter = 5000;
531         
532 //        cout << "before while" << endl;
533         
534         while(g0norm > 0.001 && bfgsIter++ < maxIter){
535             if (m->control_pressed) {  return 0; }
536
537             double f0 = f;
538             vector<double> dx(numOTUs, 0.0000);
539             
540             double alphaOld, alphaNew;
541
542             if(pNorm == 0 || g0norm == 0 || df0 == 0){
543                 dx.assign(numOTUs, 0.0000);
544                 break;
545             }
546             if(delta_f < 0){
547                 double delta = max(-delta_f, 10 * EPSILON * abs(f0));
548                 alphaOld = min(1.0, 2.0 * delta / (-df0));
549             }
550             else{
551                 alphaOld = step;
552             }
553             
554             int success = lineMinimizeFletcher(x0, p, f0, df0, alphaOld, alphaNew, f, x, gradient);
555             
556             if(!success){
557                 x = x0;
558                 break;   
559             }
560             
561             delta_f = f - f0;
562             
563             vector<double> dx0(numOTUs);
564             vector<double> dg0(numOTUs);
565             
566             for(int i=0;i<numOTUs;i++){
567                 dx0[i] = x[i] - x0[i];
568                 dg0[i] = gradient[i] - g0[i];
569             }
570             
571             double dxg = 0;
572             double dgg = 0;
573             double dxdg = 0;
574             double dgnorm = 0;
575             
576             for(int i=0;i<numOTUs;i++){
577                 dxg += dx0[i] * gradient[i];
578                 dgg += dg0[i] * gradient[i];
579                 dxdg += dx0[i] * dg0[i];
580                 dgnorm += dg0[i] * dg0[i];
581             }
582             dgnorm = sqrt(dgnorm);
583             
584             double A, B;
585             
586             if(dxdg != 0){
587                 B = dxg / dxdg;
588                 A = -(1.0 + dgnorm*dgnorm /dxdg) * B + dgg / dxdg;            
589             }
590             else{
591                 B = 0;
592                 A = 0;
593             }
594             
595             for(int i=0;i<numOTUs;i++){     p[i] = gradient[i] - A * dx0[i] - B * dg0[i];   }
596             
597             x0 = x;
598             g0 = gradient;
599             
600
601             double pg = 0;
602             pNorm = 0.0000;
603             g0norm = 0.0000;
604             
605             for(int i=0;i<numOTUs;i++){
606                 pg += p[i] * gradient[i];
607                 pNorm += p[i] * p[i];
608                 g0norm += g0[i] * g0[i];
609             }
610             pNorm = sqrt(pNorm);
611             g0norm = sqrt(g0norm);
612             
613             double dir = (pg >= 0.0) ? -1.0 : +1.0;
614
615             for(int i=0;i<numOTUs;i++){ p[i] *= dir / pNorm;    }
616             
617             pNorm = 0.0000;
618             df0 = 0.0000;
619             for(int i=0;i<numOTUs;i++){
620                 pNorm += p[i] * p[i];       
621                 df0 += p[i] * g0[i];
622             }
623             
624             pNorm = sqrt(pNorm);
625
626         }
627 //        cout << "bfgsIter:\t" << bfgsIter << endl;
628
629         return bfgsIter;
630     }
631     catch(exception& e){
632         m->errorOut(e, "qFinderDMM", "bfgs2_Solver");
633         exit(1);
634     }
635 }
636
637 /**************************************************************************************************/
638
// TODO: consider moving the psi/psi1 (digamma/trigamma) calculations into their own math class.
// The psi calculations below were adapted from the GNU Scientific Library (GSL).
641
/* Chebyshev coefficients (from GSL) for psi(1 + v) with v in [0, 1].
 * psi() evaluates this series with cheb_eval() at t = 2*v - 1:
 *   - for 1 <= x < 2, with v = x - 1;
 *   - for x < 1, with v = x, and then subtracts 1/x.
 */
static const double psi_cs[23] = {
    -.038057080835217922,   /*  0 */
     .491415393029387130,   /*  1 */
    -.056815747821244730,   /*  2 */
     .008357821225914313,   /*  3 */
    -.001333232857994342,   /*  4 */
     .000220313287069308,   /*  5 */
    -.000037040238178456,   /*  6 */
     .000006283793654854,   /*  7 */
    -.000001071263908506,   /*  8 */
     .000000183128394654,   /*  9 */
    -.000000031353509361,   /* 10 */
     .000000005372808776,   /* 11 */
    -.000000000921168141,   /* 12 */
     .000000000157981265,   /* 13 */
    -.000000000027098646,   /* 14 */
     .000000000004648722,   /* 15 */
    -.000000000000797527,   /* 16 */
     .000000000000136827,   /* 17 */
    -.000000000000023475,   /* 18 */
     .000000000000004027,   /* 19 */
    -.000000000000000691,   /* 20 */
     .000000000000000118,   /* 21 */
    -.000000000000000020    /* 22 */
};
667
/* Chebyshev coefficients (from GSL) for the asymptotic part of psi(x), x >= 2:
 *
 *     psi(x) = log(x) - 0.5/x + series(t),   t = 8/x^2 - 1
 *
 * Declared const, like psi_cs: the table is only read (cheb_eval() takes it as
 * const double[]), and const lets the compiler reject accidental writes.
 */
static const double apsi_cs[16] = {
    -.0204749044678185,
    -.0101801271534859,
    .0000559718725387,
    -.0000012917176570,
    .0000000572858606,
    -.0000000038213539,
    .0000000003397434,
    -.0000000000374838,
    .0000000000048990,
    -.0000000000007344,
    .0000000000001233,
    -.0000000000000228,
    .0000000000000045,
    -.0000000000000009,
    .0000000000000002,
    -.0000000000000000
};
686
687 /**************************************************************************************************/
688
689 double qFinderDMM::cheb_eval(const double seriesData[], int order, double xx){
690     try {
691         double d = 0.0000;
692         double dd = 0.0000;
693         
694         double x2 = xx * 2.0000;
695         
696         for(int j=order;j>=1;j--){
697             if (m->control_pressed) {  return 0; }
698             double temp = d;
699             d = x2 * d - dd + seriesData[j];
700             dd = temp;
701         }
702         
703         d = xx * d - dd + 0.5 * seriesData[0];
704         return d;
705     }
706     catch(exception& e){
707         m->errorOut(e, "qFinderDMM", "cheb_eval");
708         exit(1);
709     }
710 }
711
712 /**************************************************************************************************/
713
714 double qFinderDMM::psi(double xx){
715     try {
716         double psiX = 0.0000;
717         
718         if(xx < 1.0000){
719             
720             double t1 = 1.0 / xx;
721             psiX = cheb_eval(psi_cs, 22, 2.0*xx-1.0);
722             psiX = -t1 + psiX;
723             
724         }
725         else if(xx < 2.0000){
726             
727             const double v = xx - 1.0;
728             psiX = cheb_eval(psi_cs, 22, 2.0*v-1.0);
729             
730         }
731         else{
732             const double t = 8.0/(xx*xx)-1.0;
733             psiX = cheb_eval(apsi_cs, 15, t);
734             psiX += log(xx) - 0.5/xx;
735         }
736         
737         return psiX;
738     }
739     catch(exception& e){
740         m->errorOut(e, "qFinderDMM", "psi");
741         exit(1);
742     }
743 }
744
745 /**************************************************************************************************/
746
/* Coefficients B_{2j} / (2j)! (B = Bernoulli numbers) for the Euler-Maclaurin
 * summation used by psi1(); taken from GSL's hzeta().
 * Declared const: the table is only read, so const lets the compiler reject
 * accidental writes.
 */
static const double hzeta_c[15] = {
    1.00000000000000000000000000000,
    0.083333333333333333333333333333,
    -0.00138888888888888888888888888889,
    0.000033068783068783068783068783069,
    -8.2671957671957671957671957672e-07,
    2.0876756987868098979210090321e-08,
    -5.2841901386874931848476822022e-10,
    1.3382536530684678832826980975e-11,
    -3.3896802963225828668301953912e-13,
    8.5860620562778445641359054504e-15,
    -2.1748686985580618730415164239e-16,
    5.5090028283602295152026526089e-18,
    -1.3954464685812523340707686264e-19,
    3.5347070396294674716932299778e-21,
    -8.9535174270375468504026113181e-23
};
767
768 /**************************************************************************************************/
769
770 double qFinderDMM::psi1(double xx){
771     try {
772         
773         /* Euler-Maclaurin summation formula
774          * [Moshier, p. 400, with several typo corrections]
775          */
776         
777         double s = 2.0000;
778         const int jmax = 12;
779         const int kmax = 10;
780         int j, k;
781         const double pmax  = pow(kmax + xx, -s);
782         double scp = s;
783         double pcp = pmax / (kmax + xx);
784         double value = pmax*((kmax+xx)/(s-1.0) + 0.5);
785         
786         for(k=0; k<kmax; k++) {
787             if (m->control_pressed) {  return 0; }
788             value += pow(k + xx, -s);
789         }
790         
791         for(j=0; j<=jmax; j++) {
792             if (m->control_pressed) {  return 0; }
793             double delta = hzeta_c[j+1] * scp * pcp;
794             value += delta;
795             
796             if(fabs(delta/value) < 0.5*EPSILON) break;
797             
798             scp *= (s+2*j+1)*(s+2*j+2);
799             pcp /= (kmax + xx)*(kmax + xx);
800         }
801         
802         return value;
803     }
804     catch(exception& e){
805         m->errorOut(e, "qFinderDMM", "psi1");
806         exit(1);
807     }
808 }
809
810 /**************************************************************************************************/
811
812 double qFinderDMM::negativeLogEvidenceLambdaPi(vector<double>& x){
813     try{
814         vector<double> sumAlphaX(numSamples, 0.0000);
815         
816         double logEAlpha = 0.0000;
817         double sumLambda = 0.0000;
818         double sumAlpha = 0.0000;
819         double logE = 0.0000;
820         double nu = 0.10000;
821         double eta = 0.10000;
822         
823         double weight = 0.00000;
824         for(int i=0;i<numSamples;i++){
825             weight += zMatrix[currentPartition][i];
826         }
827         
828         for(int i=0;i<numOTUs;i++){
829             if (m->control_pressed) {  return 0; }
830             double lambda = x[i];
831             double alpha = exp(x[i]);
832             logEAlpha += lgamma(alpha);
833             sumLambda += lambda;
834             sumAlpha += alpha;
835             
836             for(int j=0;j<numSamples;j++){
837                 double X = countMatrix[j][i];
838                 double alphaX = alpha + X;
839                 sumAlphaX[j] += alphaX;
840                 
841                 logE -= zMatrix[currentPartition][j] * lgamma(alphaX);
842             }
843         }
844         
845         logEAlpha -= lgamma(sumAlpha);
846
847         for(int i=0;i<numSamples;i++){
848             logE += zMatrix[currentPartition][i] * lgamma(sumAlphaX[i]);
849         }
850
851         return logE + weight * logEAlpha + nu * sumAlpha - eta * sumLambda;
852     }
853     catch(exception& e){
854         m->errorOut(e, "qFinderDMM", "negativeLogEvidenceLambdaPi");
855         exit(1);
856     }
857 }
858
859 /**************************************************************************************************/
860
861 void qFinderDMM::negativeLogDerivEvidenceLambdaPi(vector<double>& x, vector<double>& df){
862     try{
863 //        cout << "\tstart negativeLogDerivEvidenceLambdaPi" << endl;
864         
865         vector<double> storeVector(numSamples, 0.0000);
866         vector<double> derivative(numOTUs, 0.0000);
867         vector<double> alpha(numOTUs, 0.0000);
868         
869         double store = 0.0000;
870         double nu = 0.1000;
871         double eta = 0.1000;
872         
873         double weight = 0.0000;
874         for(int i=0;i<numSamples;i++){
875             weight += zMatrix[currentPartition][i];
876         }
877
878         
879         for(int i=0;i<numOTUs;i++){
880             if (m->control_pressed) {  return; }
881 //            cout << "start i loop" << endl;
882 //            
883 //            cout << i << '\t' << alpha[i] << '\t' << x[i] << '\t' << exp(x[i]) << '\t' << store << endl;
884             
885             alpha[i] = exp(x[i]);
886             store += alpha[i];
887             
888 //            cout << "before derivative" << endl;
889             
890             derivative[i] = weight * psi(alpha[i]);
891
892 //            cout << "after derivative" << endl;
893
894 //            cout << i << '\t' << alpha[i] << '\t' << psi(alpha[i]) << '\t' << derivative[i] << endl;
895
896             for(int j=0;j<numSamples;j++){
897                 double X = countMatrix[j][i];
898                 double alphaX = X + alpha[i];
899                 
900                 derivative[i] -= zMatrix[currentPartition][j] * psi(alphaX);
901                 storeVector[j] += alphaX;
902             }
903 //            cout << "end i loop" << endl;
904         }
905
906         double sumStore = 0.0000;
907         for(int i=0;i<numSamples;i++){
908             sumStore += zMatrix[currentPartition][i] * psi(storeVector[i]);
909         }
910         
911         store = weight * psi(store);
912         
913         df.resize(numOTUs, 0.0000);
914         
915         for(int i=0;i<numOTUs;i++){
916             df[i] = alpha[i] * (nu + derivative[i] - store + sumStore) - eta;
917 //            cout << i << '\t' << df[i] << endl;
918         }
919 //        cout << df.size() << endl;
920 //        cout << "\tend negativeLogDerivEvidenceLambdaPi" << endl;
921     }
922     catch(exception& e){
923          m->errorOut(e, "qFinderDMM", "negativeLogDerivEvidenceLambdaPi");
924         exit(1);
925     }
926 }
927
928 /**************************************************************************************************/
929
930 double qFinderDMM::getNegativeLogEvidence(vector<double>& lambda, int group){
931     try {
932         double sumAlpha = 0.0000;
933         double sumAlphaX = 0.0000;
934         double sumLnGamAlpha = 0.0000;
935         double logEvidence = 0.0000;
936         
937         for(int i=0;i<numOTUs;i++){
938             if (m->control_pressed) {  return 0; }
939             double alpha = exp(lambda[i]);
940             double X = countMatrix[group][i];
941             double alphaX = alpha + X;
942             
943             sumLnGamAlpha += lgamma(alpha);
944             sumAlpha += alpha;
945             sumAlphaX += alphaX;
946             
947             logEvidence -= lgamma(alphaX);
948         }
949         
950         sumLnGamAlpha -= lgamma(sumAlpha);
951         logEvidence += lgamma(sumAlphaX);
952         
953         return logEvidence + sumLnGamAlpha;
954     }
955     catch(exception& e){
956         m->errorOut(e, "qFinderDMM", "getNegativeLogEvidence");
957         exit(1);
958     }
959 }
960
961 /**************************************************************************************************/
962
963 void qFinderDMM::kMeans(){
964     try {
965         
966         vector<vector<double> > relativeAbundance(numSamples);
967         vector<vector<double> > alphaMatrix;
968         
969         alphaMatrix.resize(numPartitions);
970         lambdaMatrix.resize(numPartitions);
971         for(int i=0;i<numPartitions;i++){
972             alphaMatrix[i].assign(numOTUs, 0);
973             lambdaMatrix[i].assign(numOTUs, 0);
974         }
975         
976         //get relative abundance
977         for(int i=0;i<numSamples;i++){
978             if (m->control_pressed) {  return; }
979             int groupTotal = 0;
980             
981             relativeAbundance[i].assign(numOTUs, 0.0);
982             
983             for(int j=0;j<numOTUs;j++){
984                 groupTotal += countMatrix[i][j];
985             }
986             for(int j=0;j<numOTUs;j++){
987                 relativeAbundance[i][j] = countMatrix[i][j] / (double)groupTotal;
988             }
989         }
990         
991         //randomly assign samples into partitions
992         zMatrix.resize(numPartitions);
993         for(int i=0;i<numPartitions;i++){
994             zMatrix[i].assign(numSamples, 0);
995         }
996         
997         for(int i=0;i<numSamples;i++){
998             zMatrix[rand()%numPartitions][i] = 1;
999         }
1000         
1001         double maxChange = 1;
1002         int maxIters = 1000;
1003         int iteration = 0;
1004         
1005         weights.assign(numPartitions, 0);
1006         
1007         while(maxChange > 1e-6 && iteration < maxIters){
1008             
1009             if (m->control_pressed) {  return; }
1010             //calcualte average relative abundance
1011             maxChange = 0.0000;
1012             for(int i=0;i<numPartitions;i++){
1013                 
1014                 double normChange = 0.0;
1015                 
1016                 weights[i] = 0;
1017                 
1018                 for(int j=0;j<numSamples;j++){
1019                     weights[i] += (double)zMatrix[i][j];
1020                 }
1021                 
1022                 vector<double> averageRelativeAbundance(numOTUs, 0);
1023                 for(int j=0;j<numOTUs;j++){
1024                     for(int k=0;k<numSamples;k++){
1025                         averageRelativeAbundance[j] += zMatrix[i][k] * relativeAbundance[k][j];
1026                     }
1027                 }
1028                 
1029                 for(int j=0;j<numOTUs;j++){
1030                     averageRelativeAbundance[j] /= weights[i];
1031                     double difference = averageRelativeAbundance[j] - alphaMatrix[i][j];
1032                     normChange += difference * difference;
1033                     alphaMatrix[i][j] = averageRelativeAbundance[j];
1034                 }
1035                 
1036                 normChange = sqrt(normChange);
1037                 
1038                 if(normChange > maxChange){ maxChange = normChange; }
1039             }
1040             
1041             
1042             //calcualte distance between each sample in partition adn the average relative abundance
1043             for(int i=0;i<numSamples;i++){
1044                 if (m->control_pressed) {  return; }
1045                 
1046                 double normalizationFactor = 0;
1047                 vector<double> totalDistToPartition(numPartitions, 0);
1048                 
1049                 for(int j=0;j<numPartitions;j++){
1050                     for(int k=0;k<numOTUs;k++){
1051                         double difference = alphaMatrix[j][k] - relativeAbundance[i][k];
1052                         totalDistToPartition[j] += difference * difference;
1053                     }
1054                     totalDistToPartition[j] = sqrt(totalDistToPartition[j]);
1055                     normalizationFactor += exp(-50.0 * totalDistToPartition[j]);
1056                 }
1057                 
1058                 
1059                 for(int j=0;j<numPartitions;j++){
1060                     zMatrix[j][i] = exp(-50.0 * totalDistToPartition[j]) / normalizationFactor;
1061                 }
1062                 
1063             }
1064             
1065             iteration++;
1066             //        cout << "K means: " << iteration << '\t' << maxChange << endl;
1067             
1068         }
1069         
1070         //    cout << "Iter:-1";
1071         for(int i=0;i<numPartitions;i++){
1072             weights[i] = 0.0000;
1073             
1074             for(int j=0;j<numSamples;j++){
1075                 weights[i] += zMatrix[i][j];
1076             }
1077             //        printf("\tw_%d=%.3f", i, weights[i]);
1078         }
1079         //    cout << endl;
1080         
1081         
1082         for(int i=0;i<numOTUs;i++){
1083             if (m->control_pressed) {  return; }
1084             for(int j=0;j<numPartitions;j++){
1085                 if(alphaMatrix[j][i] > 0){
1086                     lambdaMatrix[j][i] = log(alphaMatrix[j][i]);
1087                 }
1088                 else{
1089                     lambdaMatrix[j][i] = -10.0;
1090                 }
1091             }
1092         }
1093     }
1094     catch(exception& e){
1095         m->errorOut(e, "qFinderDMM", "kMeans");
1096         exit(1);
1097     }
1098 }
1099
1100 /**************************************************************************************************/
1101
1102 void qFinderDMM::optimizeLambda(){    
1103     try {
1104         for(currentPartition=0;currentPartition<numPartitions;currentPartition++){
1105             if (m->control_pressed) {  return; }
1106             bfgs2_Solver(lambdaMatrix[currentPartition]);
1107         }
1108     }
1109     catch(exception& e){
1110         m->errorOut(e, "qFinderDMM", "optimizeLambda");
1111         exit(1);
1112     }
1113 }
1114
1115 /**************************************************************************************************/
1116
1117 void qFinderDMM::calculatePiK(){
1118     try {
1119         vector<double> store(numPartitions);
1120         
1121         for(int i=0;i<numSamples;i++){
1122             if (m->control_pressed) {  return; }
1123             double sum = 0.0000;
1124             double minNegLogEvidence =numeric_limits<double>::max();
1125             
1126             for(int j=0;j<numPartitions;j++){
1127                 double negLogEvidenceJ = getNegativeLogEvidence(lambdaMatrix[j], i);
1128                 
1129                 if(negLogEvidenceJ < minNegLogEvidence){
1130                     minNegLogEvidence = negLogEvidenceJ;
1131                 }
1132                 store[j] = negLogEvidenceJ;
1133             }
1134             
1135             for(int j=0;j<numPartitions;j++){
1136                 if (m->control_pressed) {  return; }
1137                 zMatrix[j][i] = weights[j] * exp(-(store[j] - minNegLogEvidence));
1138                 sum += zMatrix[j][i];
1139             }
1140             
1141             for(int j=0;j<numPartitions;j++){
1142                 zMatrix[j][i] /= sum;
1143             }
1144             
1145         }
1146     }
1147     catch(exception& e){
1148         m->errorOut(e, "qFinderDMM", "calculatePiK");
1149         exit(1);
1150     }
1151     
1152 }
1153
1154 /**************************************************************************************************/
1155
1156 double qFinderDMM::getNegativeLogLikelihood(){
1157     try {
1158         double eta = 0.10000;
1159         double nu = 0.10000;
1160         
1161         vector<double> pi(numPartitions, 0.0000);
1162         vector<double> logBAlpha(numPartitions, 0.0000);
1163         
1164         double doubleSum = 0.0000;
1165         
1166         for(int i=0;i<numPartitions;i++){
1167             if (m->control_pressed) {  return 0; }
1168             double sumAlphaK = 0.0000;
1169             
1170             pi[i] = weights[i] / (double)numSamples;
1171             
1172             for(int j=0;j<numOTUs;j++){
1173                 double alpha = exp(lambdaMatrix[i][j]);
1174                 sumAlphaK += alpha;
1175                 
1176                 logBAlpha[i] += lgamma(alpha);
1177             }
1178             logBAlpha[i] -= lgamma(sumAlphaK);
1179         }
1180         
1181         for(int i=0;i<numSamples;i++){
1182             if (m->control_pressed) {  return 0; }
1183             
1184             double probability = 0.0000;
1185             double factor = 0.0000;
1186             double sum = 0.0000;
1187             vector<double> logStore(numPartitions, 0.0000);
1188             double offset = -numeric_limits<double>::max();
1189             
1190             for(int j=0;j<numOTUs;j++){
1191                 sum += countMatrix[i][j];
1192                 factor += lgamma(countMatrix[i][j] + 1.0000);
1193             }
1194             factor -= lgamma(sum + 1.0);
1195             
1196             for(int k=0;k<numPartitions;k++){
1197                 
1198                 double sumAlphaKX = 0.0000;
1199                 double logBAlphaX = 0.0000;
1200                 
1201                 for(int j=0;j<numOTUs;j++){
1202                     double alphaX = exp(lambdaMatrix[k][j]) + (double)countMatrix[i][j];
1203                     
1204                     sumAlphaKX += alphaX;
1205                     logBAlphaX += lgamma(alphaX);
1206                 }
1207                 
1208                 logBAlphaX -= lgamma(sumAlphaKX);
1209                 
1210                 logStore[k] = logBAlphaX - logBAlpha[k] - factor;
1211                 if(logStore[k] > offset){
1212                     offset = logStore[k];
1213                 }
1214                 
1215             }
1216             
1217             for(int k=0;k<numPartitions;k++){
1218                 probability += pi[k] * exp(-offset + logStore[k]);
1219             }
1220             doubleSum += log(probability) + offset;
1221             
1222         }
1223         
1224         double L5 = - numOTUs * numPartitions * lgamma(eta);
1225         double L6 = eta * numPartitions * numOTUs * log(nu);
1226         
1227         double alphaSum, lambdaSum;
1228         alphaSum = lambdaSum = 0.0000;
1229         
1230         for(int i=0;i<numPartitions;i++){
1231             for(int j=0;j<numOTUs;j++){
1232                 if (m->control_pressed) {  return 0; }
1233                 alphaSum += exp(lambdaMatrix[i][j]);
1234                 lambdaSum += lambdaMatrix[i][j];
1235             }
1236         }
1237         alphaSum *= -nu;
1238         lambdaSum *= eta;
1239         
1240         return (-doubleSum - L5 - L6 - alphaSum - lambdaSum);
1241     }
1242     catch(exception& e){
1243         m->errorOut(e, "qFinderDMM", "getNegativeLogLikelihood");
1244         exit(1);
1245     }
1246
1247
1248 }
1249
1250 /**************************************************************************************************/
1251
1252 vector<vector<double> > qFinderDMM::getHessian(){
1253     try {
1254         vector<double> alpha(numOTUs, 0.0000);
1255         double alphaSum = 0.0000;
1256         
1257         vector<double> pi = zMatrix[currentPartition];
1258         vector<double> psi_ajk(numOTUs, 0.0000);
1259         vector<double> psi_cjk(numOTUs, 0.0000);
1260         vector<double> psi1_ajk(numOTUs, 0.0000);
1261         vector<double> psi1_cjk(numOTUs, 0.0000);
1262         
1263         for(int j=0;j<numOTUs;j++){
1264             
1265             if (m->control_pressed) {  break; }
1266             
1267             alpha[j] = exp(lambdaMatrix[currentPartition][j]);
1268             alphaSum += alpha[j];
1269             
1270             for(int i=0;i<numSamples;i++){
1271                 double X = (double) countMatrix[i][j];
1272                 
1273                 psi_ajk[j] += pi[i] * psi(alpha[j]);
1274                 psi1_ajk[j] += pi[i] * psi1(alpha[j]);
1275                 
1276                 psi_cjk[j] += pi[i] * psi(alpha[j] + X);
1277                 psi1_cjk[j] += pi[i] * psi1(alpha[j] + X);
1278             }
1279         }
1280         
1281         
1282         double psi_Ck = 0.0000;
1283         double psi1_Ck = 0.0000;
1284         
1285         double weight = 0.0000;
1286         
1287         for(int i=0;i<numSamples;i++){
1288             if (m->control_pressed) {  break; }
1289             weight += pi[i];
1290             double sum = 0.0000;
1291             for(int j=0;j<numOTUs;j++){     sum += alpha[j] + countMatrix[i][j];    }
1292             
1293             psi_Ck += pi[i] * psi(sum);
1294             psi1_Ck += pi[i] * psi1(sum);
1295         }
1296         
1297         double psi_Ak = weight * psi(alphaSum);
1298         double psi1_Ak = weight * psi1(alphaSum);
1299         
1300         vector<vector<double> > hessian(numOTUs);
1301         for(int i=0;i<numOTUs;i++){ hessian[i].assign(numOTUs, 0.0000); }
1302         
1303         for(int i=0;i<numOTUs;i++){
1304             if (m->control_pressed) {  break; }
1305             double term1 = -alpha[i] * (- psi_ajk[i] + psi_Ak + psi_cjk[i] - psi_Ck);
1306             double term2 = -alpha[i] * alpha[i] * (-psi1_ajk[i] + psi1_Ak + psi1_cjk[i] - psi1_Ck);
1307             double term3 = 0.1 * alpha[i];
1308             
1309             hessian[i][i] = term1 + term2 + term3;
1310             
1311             for(int j=0;j<i;j++){   
1312                 hessian[i][j] = - alpha[i] * alpha[j] * (psi1_Ak - psi1_Ck);
1313                 hessian[j][i] = hessian[i][j];
1314             }
1315         }
1316         
1317         return hessian;
1318     }
1319     catch(exception& e){
1320         m->errorOut(e, "qFinderDMM", "getHessian");
1321         exit(1);
1322     }
1323 }
1324
1325 /**************************************************************************************************/