]> git.donarmstrong.com Git - mothur.git/blob - qFinderDMM.cpp
fixes while testing 1.33.0
[mothur.git] / qFinderDMM.cpp
1 //
2 //  qFinderDMM.cpp
3 //  pds_dmm
4 //
5 //  Created by Patrick Schloss on 11/8/12.
6 //  Copyright (c) 2012 University of Michigan. All rights reserved.
7 //
8
9 #include "qFinderDMM.h"
10
11
12
13 /**************************************************************************************************/
14
15 qFinderDMM::qFinderDMM(vector<vector<int> > cm, int p) : CommunityTypeFinder() {
16     try {
17         //cout << "here" << endl;
18         numPartitions = p;
19         countMatrix = cm;
20         numSamples = (int)countMatrix.size();
21         numOTUs = (int)countMatrix[0].size();
22         
23         //cout << "before kmeans" <<endl;
24         findkMeans();
25             //cout << "done kMeans" << endl;
26         
27         optimizeLambda();
28         
29         
30             //cout << "done optimizeLambda" << endl;
31         
32         double change = 1.0000;
33         currNLL = 0.0000;
34         
35         int iter = 0;
36         
37         while(change > 1.0e-6 && iter < 100){
38             
39                     //cout << "Calc_Z: ";
40             calculatePiK();
41             
42             optimizeLambda();
43             
44                     //printf("Iter:%d\t",iter);
45             
46             for(int i=0;i<numPartitions;i++){
47                 weights[i] = 0.0000;
48                 for(int j=0;j<numSamples;j++){
49                     weights[i] += zMatrix[i][j];
50                 }
51                            // printf("w_%d=%.3f\t",i,weights[i]);
52                 
53             }
54             
55             double nLL = getNegativeLogLikelihood();
56             
57             change = abs(nLL - currNLL);
58             
59             currNLL = nLL;
60             
61                    // printf("NLL=%.5f\tDelta=%.4e\n",currNLL, change);
62             
63             iter++;
64         }
65         
66         error.resize(numPartitions);
67         
68         logDeterminant = 0.0000;
69         
70         LinearAlgebra l;
71         
72         for(currentPartition=0;currentPartition<numPartitions;currentPartition++){
73             
74             error[currentPartition].assign(numOTUs, 0.0000);
75             
76             if(currentPartition > 0){
77                 logDeterminant += (2.0 * log(numSamples) - log(weights[currentPartition]));
78             }
79             vector<vector<double> > hessian = getHessian();
80             vector<vector<double> > invHessian = l.getInverse(hessian);
81             
82             for(int i=0;i<numOTUs;i++){
83                 logDeterminant += log(abs(hessian[i][i]));
84                 error[currentPartition][i] = invHessian[i][i];
85             }
86         }
87         
88         int numParameters = numPartitions * numOTUs + numPartitions - 1;
89         laplace = currNLL + 0.5 * logDeterminant - 0.5 * numParameters * log(2.0 * 3.14159);
90         bic = currNLL + 0.5 * log(numSamples) * numParameters;
91         aic = currNLL + numParameters;
92     }
93         catch(exception& e) {
94                 m->errorOut(e, "qFinderDMM", "qFinderDMM");
95                 exit(1);
96         }
97 }
98 /**************************************************************************************************/
99 void qFinderDMM::printFitData(ofstream& out){
100     try {
101         out << setprecision (2) << numPartitions << '\t'  << getNLL() << '\t' << getLogDet() << '\t' <<  getBIC() << '\t' << getAIC() << '\t' << laplace << endl;
102         return;
103     }
104     catch(exception& e){
105         m->errorOut(e, "CommunityTypeFinder", "printFitData");
106         exit(1);
107     }
108 }
109 /**************************************************************************************************/
110 void qFinderDMM::printFitData(ostream& out, double minLaplace){
111     try {
112         if(laplace < minLaplace){
113             out << setprecision (2) << numPartitions << '\t'  << getNLL() << '\t' << getLogDet() << '\t' <<  getBIC() << '\t' << getAIC() << '\t' << laplace << "***" << endl;
114         }else {
115             out << setprecision (2) << numPartitions << '\t'  << getNLL() << '\t' << getLogDet() << '\t' <<  getBIC() << '\t' << getAIC() << '\t' << laplace << endl;
116         }
117         
118         m->mothurOutJustToLog(toString(numPartitions) + '\t' + toString(getNLL()) + '\t' + toString(getLogDet()) + '\t');
119         m->mothurOutJustToLog(toString(getBIC()) + '\t' + toString(getAIC()) + '\t' + toString(laplace));
120
121         return;
122     }
123     catch(exception& e){
124         m->errorOut(e, "CommunityTypeFinder", "printFitData");
125         exit(1);
126     }
127 }
128
129 /**************************************************************************************************/
130
// The helper functions below, used by the BFGS2 solver, were adapted from the GNU Scientific Library (GSL) source code.
132
/* Minimize, over z in [zl, zh], the quadratic that passes through (0, f0) and
 * (1, f1) and has slope fp0 at z = 0:
 *
 *     q(z) = f0 + fp0 * z + (f1 - f0 - fp0) * z^2
 *
 * Candidates are the endpoints zl and zh and, when the curvature is positive,
 * the stationary point if it lies strictly inside (zl, zh).  Returns the
 * candidate with the smallest q(z); ties keep the earlier candidate
 * (zl, then zh, then the stationary point).
 */

static double
interp_quad (double f0, double fp0, double f1, double zl, double zh)
{
    const double quadCoef = f1 - f0 - fp0;

    double best = zl;
    double bestValue = f0 + zl * (fp0 + zl * quadCoef);

    const double highValue = f0 + zh * (fp0 + zh * quadCoef);
    if (highValue < bestValue) {
        best = zh;
        bestValue = highValue;
    }

    const double curvature = 2 * quadCoef;
    if (!(curvature > 0)) {
        return best;   /* no interior minimum without positive curvature */
    }

    const double stationary = -fp0 / curvature;
    if (stationary > zl && stationary < zh) {
        if (f0 + stationary * (fp0 + stationary * quadCoef) < bestValue) {
            best = stationary;
        }
    }
    return best;
}
160
161 /**************************************************************************************************/
162
/* Find a minimum over z in [zl, zh] of the interpolating cubic through
 * (0,f0) and (1,f1) with derivatives fp0 at z=0 and fp1 at z=1
 * (this describes interp_cubic; cubic, check_extremum and
 * gsl_poly_solve_quadratic below are its helpers).
 *
 * The interpolating polynomial is:
 *
 * c(z) = f0 + fp0 * z + eta * z^2 + xi * z^3
 *
 * where eta=3*(f1-f0)-2*fp0-fp1, xi=fp0+fp1-2*(f1-f0).
 */
172
/* Evaluate c0 + c1*z + c2*z^2 + c3*z^3 by Horner's rule. */
double cubic (double c0, double c1, double c2, double c3, double z){
    double acc = c3;
    acc = acc * z + c2;
    acc = acc * z + c1;
    acc = acc * z + c0;
    return acc;
}
176
177 /**************************************************************************************************/
178
179 void check_extremum (double c0, double c1, double c2, double c3, double z,
180                      double *zmin, double *fmin){
181     /* could make an early return by testing curvature >0 for minimum */
182     
183     double y = cubic (c0, c1, c2, c3, z);
184     
185     if (y < *fmin)
186     {
187         *zmin = z;  /* accepted new point*/
188         *fmin = y;
189     }
190 }
191
192 /**************************************************************************************************/
193
/* Real roots of a*x^2 + b*x + c = 0 (adapted from GSL).
 *
 * Returns how many roots were written:
 *   0 - no real roots, or a == b == 0;
 *   1 - linear case (a == 0), root in *x0;
 *   2 - roots in *x0 <= *x1 (a repeated root is written to both).
 * For b != 0 the roots are formed as temp/a and c/temp with
 * temp = -(b + sign(b)*sqrt(disc))/2, which avoids cancellation.
 */
int gsl_poly_solve_quadratic (double a, double b, double c,
                              double *x0, double *x1)
{
    if (a == 0) {
        /* degenerate equation b*x + c = 0 */
        if (b == 0) {
            return 0;
        }
        *x0 = -c / b;
        return 1;
    }

    const double disc = b * b - 4 * a * c;

    if (disc == 0) {
        const double root = -0.5 * b / a;
        *x0 = root;
        *x1 = root;
        return 2;
    }
    if (!(disc > 0)) {
        return 0;   /* complex roots (or NaN discriminant) */
    }

    if (b == 0) {
        /* symmetric pair +/- r */
        const double r = fabs (0.5 * sqrt (disc) / a);
        *x0 = -r;
        *x1 = r;
        return 2;
    }

    const double signB = (b > 0) ? 1.0 : -1.0;
    const double temp = -0.5 * (b + signB * sqrt (disc));
    const double rootA = temp / a;
    const double rootB = c / temp;
    const bool aFirst = rootA < rootB;
    *x0 = aFirst ? rootA : rootB;
    *x1 = aFirst ? rootB : rootA;
    return 2;
}
253
254 /**************************************************************************************************/
255
256 double interp_cubic (double f0, double fp0, double f1, double fp1, double zl, double zh){
257     double eta = 3 * (f1 - f0) - 2 * fp0 - fp1;
258     double xi = fp0 + fp1 - 2 * (f1 - f0);
259     double c0 = f0, c1 = fp0, c2 = eta, c3 = xi;
260     double zmin, fmin;
261     double z0, z1;
262     
263     zmin = zl; fmin = cubic(c0, c1, c2, c3, zl);
264     check_extremum (c0, c1, c2, c3, zh, &zmin, &fmin);
265     
266     {
267         int n = gsl_poly_solve_quadratic (3 * c3, 2 * c2, c1, &z0, &z1);
268         
269         if (n == 2)  /* found 2 roots */
270         {
271             if (z0 > zl && z0 < zh)
272                 check_extremum (c0, c1, c2, c3, z0, &zmin, &fmin);
273             if (z1 > zl && z1 < zh)
274                 check_extremum (c0, c1, c2, c3, z1, &zmin, &fmin);
275         }
276         else if (n == 1)  /* found 1 root */
277         {
278             if (z0 > zl && z0 < zh)
279                 check_extremum (c0, c1, c2, c3, z0, &zmin, &fmin);
280         }
281     }
282     
283     return zmin;
284 }
285
286 /**************************************************************************************************/
287
288 double interpolate (double a, double fa, double fpa,
289                     double b, double fb, double fpb, double xmin, double xmax){
290     /* Map [a,b] to [0,1] */
291     double z, alpha, zmin, zmax;
292     
293     zmin = (xmin - a) / (b - a);
294     zmax = (xmax - a) / (b - a);
295     
296     if (zmin > zmax)
297     {
298         double tmp = zmin;
299         zmin = zmax;
300         zmax = tmp;
301     };
302     
303     if(!isnan(fpb) ){
304         z = interp_cubic (fa, fpa * (b - a), fb, fpb * (b - a), zmin, zmax);
305     }
306     else{
307         z = interp_quad(fa, fpa * (b - a), fb, zmin, zmax);
308     }
309     
310     
311     alpha = a + z * (b - a);
312     
313     return alpha;
314 }
315
316 /**************************************************************************************************/
317
318 int qFinderDMM::lineMinimizeFletcher(vector<double>& x, vector<double>& p, double f0, double df0, double alpha1, double& alphaNew, double& fAlpha, vector<double>& xalpha, vector<double>& gradient ){
319     try {
320         
321         double rho = 0.01;
322         double sigma = 0.10;
323         double tau1 = 9.00;
324         double tau2 = 0.05;
325         double tau3 = 0.50;
326         
327         double alpha = alpha1;
328         double alpha_prev = 0.0000;
329         
330         xalpha.resize(numOTUs, 0.0000);
331         
332         double falpha_prev = f0;
333         double dfalpha_prev = df0;
334         
335         double a = 0.0000;          double b = alpha;
336         double fa = f0;             double fb = 0.0000;
337         double dfa = df0;           double dfb = 0.0/0.0;
338         
339         int iter = 0;
340         int maxIters = 100;
341         while(iter++ < maxIters){
342             if (m->control_pressed) { break; }
343             
344             for(int i=0;i<numOTUs;i++){
345                 xalpha[i] = x[i] + alpha * p[i];
346             }
347             
348             fAlpha = negativeLogEvidenceLambdaPi(xalpha);
349             
350             if(fAlpha > f0 + alpha * rho * df0 || fAlpha >= falpha_prev){
351                 a = alpha_prev;         b = alpha;
352                 fa = falpha_prev;       fb = fAlpha;
353                 dfa = dfalpha_prev;     dfb = 0.0/0.0;
354                 break;
355             }
356             
357             negativeLogDerivEvidenceLambdaPi(xalpha, gradient);
358             double dfalpha = 0.0000;
359             for(int i=0;i<numOTUs;i++){ dfalpha += gradient[i] * p[i]; }
360             
361             if(abs(dfalpha) <= -sigma * df0){
362                 alphaNew = alpha;
363                 return 1;
364             }
365             
366             if(dfalpha >= 0){
367                 a = alpha;                  b = alpha_prev;
368                 fa = fAlpha;                fb = falpha_prev;
369                 dfa = dfalpha;              dfb = dfalpha_prev;
370                 break;
371             }
372             
373             double delta = alpha - alpha_prev;
374             
375             double lower = alpha + delta;
376             double upper = alpha + tau1 * delta;
377             
378             double alphaNext = interpolate(alpha_prev, falpha_prev, dfalpha_prev, alpha, fAlpha, dfalpha, lower, upper);
379             
380             alpha_prev = alpha;
381             falpha_prev = fAlpha;
382             dfalpha_prev = dfalpha;
383             alpha = alphaNext;
384         }
385         
386         iter = 0;
387         while(iter++ < maxIters){
388             if (m->control_pressed) { break; }
389             
390             double delta = b - a;
391             
392             double lower = a + tau2 * delta;
393             double upper = b - tau3 * delta;
394             
395             alpha = interpolate(a, fa, dfa, b, fb, dfb, lower, upper);
396             
397             for(int i=0;i<numOTUs;i++){
398                 xalpha[i] = x[i] + alpha * p[i];
399             }
400             
401             fAlpha = negativeLogEvidenceLambdaPi(xalpha);
402             
403             if((a - alpha) * dfa <= EPSILON){
404                 return 0;
405             }
406             
407             if(fAlpha > f0 + rho * alpha * df0 || fAlpha >= fa){
408                 b = alpha;
409                 fb = fAlpha;
410                 dfb = 0.0/0.0;
411             }
412             else{
413                 double dfalpha = 0.0000;
414                 
415                 negativeLogDerivEvidenceLambdaPi(xalpha, gradient);
416                 dfalpha = 0.0000;
417                 for(int i=0;i<numOTUs;i++){ dfalpha += gradient[i] * p[i]; }
418                 
419                 if(abs(dfalpha) <= -sigma * df0){
420                     alphaNew = alpha;
421                     return 1;
422                 }
423                 
424                 if(((b-a >= 0 && dfalpha >= 0) || ((b-a) <= 0.000 && dfalpha <= 0))){
425                     b = a;      fb = fa;        dfb = dfa;
426                     a = alpha;  fa = fAlpha;    dfa = dfalpha;
427                 }
428                 else{
429                     a = alpha;
430                     fa = fAlpha;
431                     dfa = dfalpha;
432                 }
433             }
434             
435             
436         }
437         
438         return 1;
439     }
440         catch(exception& e) {
441                 m->errorOut(e, "qFinderDMM", "lineMinimizeFletcher");
442                 exit(1);
443         }
444 }
445
446 /**************************************************************************************************/
447
448 int qFinderDMM::bfgs2_Solver(vector<double>& x){
449     try{
450 //        cout << "bfgs2_Solver" << endl;
451         int bfgsIter = 0;
452         double step = 1.0e-6;
453         double delta_f = 0.0000;//f-f0;
454
455         vector<double> gradient;
456         double f = negativeLogEvidenceLambdaPi(x);
457         
458 //        cout << "after negLE" << endl;
459         
460         negativeLogDerivEvidenceLambdaPi(x, gradient);
461
462 //        cout << "after negLDE" << endl;
463
464         vector<double> x0 = x;
465         vector<double> g0 = gradient;
466
467         double g0norm = 0;
468         for(int i=0;i<numOTUs;i++){
469             g0norm += g0[i] * g0[i];
470         }
471         g0norm = sqrt(g0norm);
472
473         vector<double> p = gradient;
474         double pNorm = 0;
475         for(int i=0;i<numOTUs;i++){
476             p[i] *= -1 / g0norm;
477             pNorm += p[i] * p[i];
478         }
479         pNorm = sqrt(pNorm);
480         double df0 = -g0norm;
481
482         int maxIter = 5000;
483         
484 //        cout << "before while" << endl;
485         
486         while(g0norm > 0.001 && bfgsIter++ < maxIter){
487             if (m->control_pressed) {  return 0; }
488
489             double f0 = f;
490             vector<double> dx(numOTUs, 0.0000);
491             
492             double alphaOld, alphaNew;
493
494             if(pNorm == 0 || g0norm == 0 || df0 == 0){
495                 dx.assign(numOTUs, 0.0000);
496                 break;
497             }
498             if(delta_f < 0){
499                 double delta = max(-delta_f, 10 * EPSILON * abs(f0));
500                 alphaOld = min(1.0, 2.0 * delta / (-df0));
501             }
502             else{
503                 alphaOld = step;
504             }
505             
506             int success = lineMinimizeFletcher(x0, p, f0, df0, alphaOld, alphaNew, f, x, gradient);
507             
508             if(!success){
509                 x = x0;
510                 break;   
511             }
512             
513             delta_f = f - f0;
514             
515             vector<double> dx0(numOTUs);
516             vector<double> dg0(numOTUs);
517             
518             for(int i=0;i<numOTUs;i++){
519                 dx0[i] = x[i] - x0[i];
520                 dg0[i] = gradient[i] - g0[i];
521             }
522             
523             double dxg = 0;
524             double dgg = 0;
525             double dxdg = 0;
526             double dgnorm = 0;
527             
528             for(int i=0;i<numOTUs;i++){
529                 dxg += dx0[i] * gradient[i];
530                 dgg += dg0[i] * gradient[i];
531                 dxdg += dx0[i] * dg0[i];
532                 dgnorm += dg0[i] * dg0[i];
533             }
534             dgnorm = sqrt(dgnorm);
535             
536             double A, B;
537             
538             if(dxdg != 0){
539                 B = dxg / dxdg;
540                 A = -(1.0 + dgnorm*dgnorm /dxdg) * B + dgg / dxdg;            
541             }
542             else{
543                 B = 0;
544                 A = 0;
545             }
546             
547             for(int i=0;i<numOTUs;i++){     p[i] = gradient[i] - A * dx0[i] - B * dg0[i];   }
548             
549             x0 = x;
550             g0 = gradient;
551             
552
553             double pg = 0;
554             pNorm = 0.0000;
555             g0norm = 0.0000;
556             
557             for(int i=0;i<numOTUs;i++){
558                 pg += p[i] * gradient[i];
559                 pNorm += p[i] * p[i];
560                 g0norm += g0[i] * g0[i];
561             }
562             pNorm = sqrt(pNorm);
563             g0norm = sqrt(g0norm);
564             
565             double dir = (pg >= 0.0) ? -1.0 : +1.0;
566
567             for(int i=0;i<numOTUs;i++){ p[i] *= dir / pNorm;    }
568             
569             pNorm = 0.0000;
570             df0 = 0.0000;
571             for(int i=0;i<numOTUs;i++){
572                 pNorm += p[i] * p[i];       
573                 df0 += p[i] * g0[i];
574             }
575             
576             pNorm = sqrt(pNorm);
577
578         }
579 //        cout << "bfgsIter:\t" << bfgsIter << endl;
580
581         return bfgsIter;
582     }
583     catch(exception& e){
584         m->errorOut(e, "qFinderDMM", "bfgs2_Solver");
585         exit(1);
586     }
587 }
588
589
590 /**************************************************************************************************/
591
592 double qFinderDMM::negativeLogEvidenceLambdaPi(vector<double>& x){
593     try{
594         vector<double> sumAlphaX(numSamples, 0.0000);
595         
596         double logEAlpha = 0.0000;
597         double sumLambda = 0.0000;
598         double sumAlpha = 0.0000;
599         double logE = 0.0000;
600         double nu = 0.10000;
601         double eta = 0.10000;
602         
603         double weight = 0.00000;
604         for(int i=0;i<numSamples;i++){
605             weight += zMatrix[currentPartition][i];
606         }
607         
608         for(int i=0;i<numOTUs;i++){
609             if (m->control_pressed) {  return 0; }
610             double lambda = x[i];
611             double alpha = exp(x[i]);
612             logEAlpha += lgamma(alpha);
613             sumLambda += lambda;
614             sumAlpha += alpha;
615             
616             for(int j=0;j<numSamples;j++){
617                 double X = countMatrix[j][i];
618                 double alphaX = alpha + X;
619                 sumAlphaX[j] += alphaX;
620                 
621                 logE -= zMatrix[currentPartition][j] * lgamma(alphaX);
622             }
623         }
624         
625         logEAlpha -= lgamma(sumAlpha);
626
627         for(int i=0;i<numSamples;i++){
628             logE += zMatrix[currentPartition][i] * lgamma(sumAlphaX[i]);
629         }
630
631         return logE + weight * logEAlpha + nu * sumAlpha - eta * sumLambda;
632     }
633     catch(exception& e){
634         m->errorOut(e, "qFinderDMM", "negativeLogEvidenceLambdaPi");
635         exit(1);
636     }
637 }
638
639 /**************************************************************************************************/
640
641 void qFinderDMM::negativeLogDerivEvidenceLambdaPi(vector<double>& x, vector<double>& df){
642     try{
643 //        cout << "\tstart negativeLogDerivEvidenceLambdaPi" << endl;
644         
645         vector<double> storeVector(numSamples, 0.0000);
646         vector<double> derivative(numOTUs, 0.0000);
647         vector<double> alpha(numOTUs, 0.0000);
648         
649         double store = 0.0000;
650         double nu = 0.1000;
651         double eta = 0.1000;
652         
653         double weight = 0.0000;
654         for(int i=0;i<numSamples;i++){
655             weight += zMatrix[currentPartition][i];
656         }
657
658         
659         for(int i=0;i<numOTUs;i++){
660             if (m->control_pressed) {  return; }
661 //            cout << "start i loop" << endl;
662 //            
663 //            cout << i << '\t' << alpha[i] << '\t' << x[i] << '\t' << exp(x[i]) << '\t' << store << endl;
664             
665             alpha[i] = exp(x[i]);
666             store += alpha[i];
667             
668 //            cout << "before derivative" << endl;
669             
670             derivative[i] = weight * psi(alpha[i]);
671
672 //            cout << "after derivative" << endl;
673
674 //            cout << i << '\t' << alpha[i] << '\t' << psi(alpha[i]) << '\t' << derivative[i] << endl;
675
676             for(int j=0;j<numSamples;j++){
677                 double X = countMatrix[j][i];
678                 double alphaX = X + alpha[i];
679                 
680                 derivative[i] -= zMatrix[currentPartition][j] * psi(alphaX);
681                 storeVector[j] += alphaX;
682             }
683 //            cout << "end i loop" << endl;
684         }
685
686         double sumStore = 0.0000;
687         for(int i=0;i<numSamples;i++){
688             sumStore += zMatrix[currentPartition][i] * psi(storeVector[i]);
689         }
690         
691         store = weight * psi(store);
692         
693         df.resize(numOTUs, 0.0000);
694         
695         for(int i=0;i<numOTUs;i++){
696             df[i] = alpha[i] * (nu + derivative[i] - store + sumStore) - eta;
697 //            cout << i << '\t' << df[i] << endl;
698         }
699 //        cout << df.size() << endl;
700 //        cout << "\tend negativeLogDerivEvidenceLambdaPi" << endl;
701     }
702     catch(exception& e){
703          m->errorOut(e, "qFinderDMM", "negativeLogDerivEvidenceLambdaPi");
704         exit(1);
705     }
706 }
707
708 /**************************************************************************************************/
709
710 double qFinderDMM::getNegativeLogEvidence(vector<double>& lambda, int group){
711     try {
712         double sumAlpha = 0.0000;
713         double sumAlphaX = 0.0000;
714         double sumLnGamAlpha = 0.0000;
715         double logEvidence = 0.0000;
716         
717         for(int i=0;i<numOTUs;i++){
718             if (m->control_pressed) {  return 0; }
719             double alpha = exp(lambda[i]);
720             double X = countMatrix[group][i];
721             double alphaX = alpha + X;
722             
723             sumLnGamAlpha += lgamma(alpha);
724             sumAlpha += alpha;
725             sumAlphaX += alphaX;
726             
727             logEvidence -= lgamma(alphaX);
728         }
729         
730         sumLnGamAlpha -= lgamma(sumAlpha);
731         logEvidence += lgamma(sumAlphaX);
732         
733         return logEvidence + sumLnGamAlpha;
734     }
735     catch(exception& e){
736         m->errorOut(e, "qFinderDMM", "getNegativeLogEvidence");
737         exit(1);
738     }
739 }
740
741 /**************************************************************************************************/
742
743 void qFinderDMM::optimizeLambda(){    
744     try {
745         for(currentPartition=0;currentPartition<numPartitions;currentPartition++){
746             if (m->control_pressed) {  return; }
747             bfgs2_Solver(lambdaMatrix[currentPartition]);
748         }
749     }
750     catch(exception& e){
751         m->errorOut(e, "qFinderDMM", "optimizeLambda");
752         exit(1);
753     }
754 }
755 /**************************************************************************************************/
756
757 void qFinderDMM::calculatePiK(){
758     try {
759         vector<double> store(numPartitions);
760         
761         for(int i=0;i<numSamples;i++){
762             if (m->control_pressed) {  return; }
763             double sum = 0.0000;
764             double minNegLogEvidence =numeric_limits<double>::max();
765             
766             for(int j=0;j<numPartitions;j++){
767                 double negLogEvidenceJ = getNegativeLogEvidence(lambdaMatrix[j], i);
768                 
769                 if(negLogEvidenceJ < minNegLogEvidence){
770                     minNegLogEvidence = negLogEvidenceJ;
771                 }
772                 store[j] = negLogEvidenceJ;
773             }
774             
775             for(int j=0;j<numPartitions;j++){
776                 if (m->control_pressed) {  return; }
777                 zMatrix[j][i] = weights[j] * exp(-(store[j] - minNegLogEvidence));
778                 sum += zMatrix[j][i];
779             }
780             
781             for(int j=0;j<numPartitions;j++){
782                 zMatrix[j][i] /= sum;
783             }
784             
785         }
786     }
787     catch(exception& e){
788         m->errorOut(e, "qFinderDMM", "calculatePiK");
789         exit(1);
790     }
791     
792 }
793
794 /**************************************************************************************************/
795
796 double qFinderDMM::getNegativeLogLikelihood(){
797     try {
798         double eta = 0.10000;
799         double nu = 0.10000;
800         
801         vector<double> pi(numPartitions, 0.0000);
802         vector<double> logBAlpha(numPartitions, 0.0000);
803         
804         double doubleSum = 0.0000;
805         
806         for(int i=0;i<numPartitions;i++){
807             if (m->control_pressed) {  return 0; }
808             double sumAlphaK = 0.0000;
809             
810             pi[i] = weights[i] / (double)numSamples;
811             
812             for(int j=0;j<numOTUs;j++){
813                 double alpha = exp(lambdaMatrix[i][j]);
814                 sumAlphaK += alpha;
815                 
816                 logBAlpha[i] += lgamma(alpha);
817             }
818             logBAlpha[i] -= lgamma(sumAlphaK);
819         }
820         
821         for(int i=0;i<numSamples;i++){
822             if (m->control_pressed) {  return 0; }
823             
824             double probability = 0.0000;
825             double factor = 0.0000;
826             double sum = 0.0000;
827             vector<double> logStore(numPartitions, 0.0000);
828             double offset = -numeric_limits<double>::max();
829             
830             for(int j=0;j<numOTUs;j++){
831                 sum += countMatrix[i][j];
832                 factor += lgamma(countMatrix[i][j] + 1.0000);
833             }
834             factor -= lgamma(sum + 1.0);
835             
836             for(int k=0;k<numPartitions;k++){
837                 
838                 double sumAlphaKX = 0.0000;
839                 double logBAlphaX = 0.0000;
840                 
841                 for(int j=0;j<numOTUs;j++){
842                     double alphaX = exp(lambdaMatrix[k][j]) + (double)countMatrix[i][j];
843                     
844                     sumAlphaKX += alphaX;
845                     logBAlphaX += lgamma(alphaX);
846                 }
847                 
848                 logBAlphaX -= lgamma(sumAlphaKX);
849                 
850                 logStore[k] = logBAlphaX - logBAlpha[k] - factor;
851                 if(logStore[k] > offset){
852                     offset = logStore[k];
853                 }
854                 
855             }
856             
857             for(int k=0;k<numPartitions;k++){
858                 probability += pi[k] * exp(-offset + logStore[k]);
859             }
860             doubleSum += log(probability) + offset;
861             
862         }
863         
864         double L5 = - numOTUs * numPartitions * lgamma(eta);
865         double L6 = eta * numPartitions * numOTUs * log(nu);
866         
867         double alphaSum, lambdaSum;
868         alphaSum = lambdaSum = 0.0000;
869         
870         for(int i=0;i<numPartitions;i++){
871             for(int j=0;j<numOTUs;j++){
872                 if (m->control_pressed) {  return 0; }
873                 alphaSum += exp(lambdaMatrix[i][j]);
874                 lambdaSum += lambdaMatrix[i][j];
875             }
876         }
877         alphaSum *= -nu;
878         lambdaSum *= eta;
879         
880         return (-doubleSum - L5 - L6 - alphaSum - lambdaSum);
881     }
882     catch(exception& e){
883         m->errorOut(e, "qFinderDMM", "getNegativeLogLikelihood");
884         exit(1);
885     }
886
887
888 }
889 /**************************************************************************************************/