]> git.donarmstrong.com Git - mothur.git/blob - qFinderDMM.cpp
added Jensen-Shannon calc. working on get.communitytype command. fixed bug in get...
[mothur.git] / qFinderDMM.cpp
1 //
2 //  qFinderDMM.cpp
3 //  pds_dmm
4 //
5 //  Created by Patrick Schloss on 11/8/12.
6 //  Copyright (c) 2012 University of Michigan. All rights reserved.
7 //
8
9 #include "qFinderDMM.h"
10
11
12
13 /**************************************************************************************************/
14
15 qFinderDMM::qFinderDMM(vector<vector<int> > cm, int p) : CommunityTypeFinder() {
16     try {
17         //cout << "here" << endl;
18         numPartitions = p;
19         countMatrix = cm;
20         numSamples = (int)countMatrix.size();
21         numOTUs = (int)countMatrix[0].size();
22         
23         //cout << "before kmeans" <<endl;
24         findkMeans();
25             //cout << "done kMeans" << endl;
26         
27         optimizeLambda();
28         
29         
30             //cout << "done optimizeLambda" << endl;
31         
32         double change = 1.0000;
33         currNLL = 0.0000;
34         
35         int iter = 0;
36         
37         while(change > 1.0e-6 && iter < 100){
38             
39                     //cout << "Calc_Z: ";
40             calculatePiK();
41             
42             optimizeLambda();
43             
44                     //printf("Iter:%d\t",iter);
45             
46             for(int i=0;i<numPartitions;i++){
47                 weights[i] = 0.0000;
48                 for(int j=0;j<numSamples;j++){
49                     weights[i] += zMatrix[i][j];
50                 }
51                            // printf("w_%d=%.3f\t",i,weights[i]);
52                 
53             }
54             
55             double nLL = getNegativeLogLikelihood();
56             
57             change = abs(nLL - currNLL);
58             
59             currNLL = nLL;
60             
61                    // printf("NLL=%.5f\tDelta=%.4e\n",currNLL, change);
62             
63             iter++;
64         }
65         
66         error.resize(numPartitions);
67         
68         logDeterminant = 0.0000;
69         
70         LinearAlgebra l;
71         
72         for(currentPartition=0;currentPartition<numPartitions;currentPartition++){
73             
74             error[currentPartition].assign(numOTUs, 0.0000);
75             
76             if(currentPartition > 0){
77                 logDeterminant += (2.0 * log(numSamples) - log(weights[currentPartition]));
78             }
79             vector<vector<double> > hessian = getHessian();
80             vector<vector<double> > invHessian = l.getInverse(hessian);
81             
82             for(int i=0;i<numOTUs;i++){
83                 logDeterminant += log(abs(hessian[i][i]));
84                 error[currentPartition][i] = invHessian[i][i];
85             }
86             
87         }
88         
89         int numParameters = numPartitions * numOTUs + numPartitions - 1;
90         laplace = currNLL + 0.5 * logDeterminant - 0.5 * numParameters * log(2.0 * 3.14159);
91         bic = currNLL + 0.5 * log(numSamples) * numParameters;
92         aic = currNLL + numParameters;
93     }
94         catch(exception& e) {
95                 m->errorOut(e, "qFinderDMM", "qFinderDMM");
96                 exit(1);
97         }
98 }
99 /**************************************************************************************************/
100 void qFinderDMM::printFitData(ofstream& out){
101     try {
102         out << setprecision (2) << numPartitions << '\t'  << getNLL() << '\t' << getLogDet() << '\t' <<  getBIC() << '\t' << getAIC() << '\t' << laplace << endl;
103         return;
104     }
105     catch(exception& e){
106         m->errorOut(e, "CommunityTypeFinder", "printFitData");
107         exit(1);
108     }
109 }
110 /**************************************************************************************************/
111 void qFinderDMM::printFitData(ostream& out, double minLaplace){
112     try {
113         if(laplace < minLaplace){
114             out << setprecision (2) << numPartitions << '\t'  << getNLL() << '\t' << getLogDet() << '\t' <<  getBIC() << '\t' << getAIC() << '\t' << laplace << "***" << endl;
115         }else {
116             out << setprecision (2) << numPartitions << '\t'  << getNLL() << '\t' << getLogDet() << '\t' <<  getBIC() << '\t' << getAIC() << '\t' << laplace << endl;
117         }
118         
119         m->mothurOutJustToLog(toString(numPartitions) + '\t' + toString(getNLL()) + '\t' + toString(getLogDet()) + '\t');
120         m->mothurOutJustToLog(toString(getBIC()) + '\t' + toString(getAIC()) + '\t' + toString(laplace));
121
122         return;
123     }
124     catch(exception& e){
125         m->errorOut(e, "CommunityTypeFinder", "printFitData");
126         exit(1);
127     }
128 }
129
130 /**************************************************************************************************/
131
// The functions below, used by the BFGS2 solver, are adapted from the GNU Scientific Library (GSL) multimin source code.
133
/* Return the z in [zl, zh] that minimises the quadratic interpolating (0, f0) and (1, f1)
 * with slope fp0 at z = 0:
 *
 *     q(z) = f0 + fp0 * z + (f1 - f0 - fp0) * z^2
 *
 * Candidates are the two end points, plus the stationary point -fp0 / q'' when the curvature
 * q'' is positive and that point lies strictly inside the interval. The candidate with the
 * smallest q wins; on a tie the earlier candidate (zl, then zh) is kept.
 */

static double
interp_quad (double f0, double fp0, double f1, double zl, double zh)
{
    const double k2 = f1 - f0 - fp0;          /* coefficient of z^2 */
    const double curvature = 2 * k2;          /* q''(z) */

    double bestZ = zl;
    double bestF = f0 + zl * (fp0 + zl * k2);

    const double fHigh = f0 + zh * (fp0 + zh * k2);
    if (fHigh < bestF) {
        bestZ = zh;
        bestF = fHigh;
    }

    /* Without positive curvature the stationary point is not a minimum. */
    if (!(curvature > 0)) {
        return bestZ;
    }

    const double zStar = -fp0 / curvature;
    if (!(zStar > zl && zStar < zh)) {
        return bestZ;
    }

    const double fStar = f0 + zStar * (fp0 + zStar * k2);
    return (fStar < bestF) ? zStar : bestZ;
}
161
162 /**************************************************************************************************/
163
/* The helpers below (cubic, check_extremum, gsl_poly_solve_quadratic) support interp_cubic,
 * which finds a minimum in z=[zl,zh] of the cubic interpolating (0,f0) and (1,f1),
 * with derivatives fp0 at z=0 and fp1 at z=1.
 *
 * The interpolating polynomial is:
 *
 * c(z) = f0 + fp0 * z + eta * z^2 + xi * z^3
 *
 * where eta=3*(f1-f0)-2*fp0-fp1, xi=fp0+fp1-2*(f1-f0).
 */
173
174 double cubic (double c0, double c1, double c2, double c3, double z){
175     return c0 + z * (c1 + z * (c2 + z * c3));
176 }
177
178 /**************************************************************************************************/
179
180 void check_extremum (double c0, double c1, double c2, double c3, double z,
181                      double *zmin, double *fmin){
182     /* could make an early return by testing curvature >0 for minimum */
183     
184     double y = cubic (c0, c1, c2, c3, z);
185     
186     if (y < *fmin)
187     {
188         *zmin = z;  /* accepted new point*/
189         *fmin = y;
190     }
191 }
192
193 /**************************************************************************************************/
194
// Real roots of a*x^2 + b*x + c = 0 (copy of GSL's gsl_poly_solve_quadratic).
//
// Returns the number of real roots found (0, 1 or 2):
//  - a == 0, b != 0: linear case; returns 1 with the root in *x0.
//  - a == 0, b == 0: returns 0.
//  - Positive discriminant: returns 2 with *x0 <= *x1. The roots are computed in the
//    cancellation-free form temp/a, c/temp.
//  - Zero discriminant: returns 2 with the double root stored in both *x0 and *x1.
//  - Negative discriminant: returns 0.
// Outputs not covered by the return count are left untouched.
//
// Declared static: a non-static definition with this name would clash with the symbol
// exported by libgsl if GSL is ever linked in. It also clashes with any other file that
// carries the same copied helper.
static int gsl_poly_solve_quadratic (double a, double b, double c,
                                     double *x0, double *x1)
{
    double disc = b * b - 4 * a * c;

    if (a == 0) /* Handle linear case */
    {
        if (b == 0)
        {
            return 0;
        }
        *x0 = -c / b;
        return 1;
    }

    if (disc > 0)
    {
        if (b == 0)
        {
            double r = fabs (0.5 * sqrt (disc) / a);
            *x0 = -r;
            *x1 =  r;
        }
        else
        {
            double sgnb = (b > 0 ? 1 : -1);
            double temp = -0.5 * (b + sgnb * sqrt (disc));
            double r1 = temp / a ;
            double r2 = c / temp ;

            if (r1 < r2)
            {
                *x0 = r1 ;
                *x1 = r2 ;
            }
            else
            {
                *x0 = r2 ;
                *x1 = r1 ;
            }
        }
        return 2;
    }
    else if (disc == 0)
    {
        *x0 = -0.5 * b / a ;
        *x1 = -0.5 * b / a ;
        return 2 ;
    }

    return 0;
}
254
255 /**************************************************************************************************/
256
257 double interp_cubic (double f0, double fp0, double f1, double fp1, double zl, double zh){
258     double eta = 3 * (f1 - f0) - 2 * fp0 - fp1;
259     double xi = fp0 + fp1 - 2 * (f1 - f0);
260     double c0 = f0, c1 = fp0, c2 = eta, c3 = xi;
261     double zmin, fmin;
262     double z0, z1;
263     
264     zmin = zl; fmin = cubic(c0, c1, c2, c3, zl);
265     check_extremum (c0, c1, c2, c3, zh, &zmin, &fmin);
266     
267     {
268         int n = gsl_poly_solve_quadratic (3 * c3, 2 * c2, c1, &z0, &z1);
269         
270         if (n == 2)  /* found 2 roots */
271         {
272             if (z0 > zl && z0 < zh)
273                 check_extremum (c0, c1, c2, c3, z0, &zmin, &fmin);
274             if (z1 > zl && z1 < zh)
275                 check_extremum (c0, c1, c2, c3, z1, &zmin, &fmin);
276         }
277         else if (n == 1)  /* found 1 root */
278         {
279             if (z0 > zl && z0 < zh)
280                 check_extremum (c0, c1, c2, c3, z0, &zmin, &fmin);
281         }
282     }
283     
284     return zmin;
285 }
286
287 /**************************************************************************************************/
288
289 double interpolate (double a, double fa, double fpa,
290                     double b, double fb, double fpb, double xmin, double xmax){
291     /* Map [a,b] to [0,1] */
292     double z, alpha, zmin, zmax;
293     
294     zmin = (xmin - a) / (b - a);
295     zmax = (xmax - a) / (b - a);
296     
297     if (zmin > zmax)
298     {
299         double tmp = zmin;
300         zmin = zmax;
301         zmax = tmp;
302     };
303     
304     if(!isnan(fpb) ){
305         z = interp_cubic (fa, fpa * (b - a), fb, fpb * (b - a), zmin, zmax);
306     }
307     else{
308         z = interp_quad(fa, fpa * (b - a), fb, zmin, zmax);
309     }
310     
311     
312     alpha = a + z * (b - a);
313     
314     return alpha;
315 }
316
317 /**************************************************************************************************/
318
319 int qFinderDMM::lineMinimizeFletcher(vector<double>& x, vector<double>& p, double f0, double df0, double alpha1, double& alphaNew, double& fAlpha, vector<double>& xalpha, vector<double>& gradient ){
320     try {
321         
322         double rho = 0.01;
323         double sigma = 0.10;
324         double tau1 = 9.00;
325         double tau2 = 0.05;
326         double tau3 = 0.50;
327         
328         double alpha = alpha1;
329         double alpha_prev = 0.0000;
330         
331         xalpha.resize(numOTUs, 0.0000);
332         
333         double falpha_prev = f0;
334         double dfalpha_prev = df0;
335         
336         double a = 0.0000;          double b = alpha;
337         double fa = f0;             double fb = 0.0000;
338         double dfa = df0;           double dfb = 0.0/0.0;
339         
340         int iter = 0;
341         int maxIters = 100;
342         while(iter++ < maxIters){
343             if (m->control_pressed) { break; }
344             
345             for(int i=0;i<numOTUs;i++){
346                 xalpha[i] = x[i] + alpha * p[i];
347             }
348             
349             fAlpha = negativeLogEvidenceLambdaPi(xalpha);
350             
351             if(fAlpha > f0 + alpha * rho * df0 || fAlpha >= falpha_prev){
352                 a = alpha_prev;         b = alpha;
353                 fa = falpha_prev;       fb = fAlpha;
354                 dfa = dfalpha_prev;     dfb = 0.0/0.0;
355                 break;
356             }
357             
358             negativeLogDerivEvidenceLambdaPi(xalpha, gradient);
359             double dfalpha = 0.0000;
360             for(int i=0;i<numOTUs;i++){ dfalpha += gradient[i] * p[i]; }
361             
362             if(abs(dfalpha) <= -sigma * df0){
363                 alphaNew = alpha;
364                 return 1;
365             }
366             
367             if(dfalpha >= 0){
368                 a = alpha;                  b = alpha_prev;
369                 fa = fAlpha;                fb = falpha_prev;
370                 dfa = dfalpha;              dfb = dfalpha_prev;
371                 break;
372             }
373             
374             double delta = alpha - alpha_prev;
375             
376             double lower = alpha + delta;
377             double upper = alpha + tau1 * delta;
378             
379             double alphaNext = interpolate(alpha_prev, falpha_prev, dfalpha_prev, alpha, fAlpha, dfalpha, lower, upper);
380             
381             alpha_prev = alpha;
382             falpha_prev = fAlpha;
383             dfalpha_prev = dfalpha;
384             alpha = alphaNext;
385         }
386         
387         iter = 0;
388         while(iter++ < maxIters){
389             if (m->control_pressed) { break; }
390             
391             double delta = b - a;
392             
393             double lower = a + tau2 * delta;
394             double upper = b - tau3 * delta;
395             
396             alpha = interpolate(a, fa, dfa, b, fb, dfb, lower, upper);
397             
398             for(int i=0;i<numOTUs;i++){
399                 xalpha[i] = x[i] + alpha * p[i];
400             }
401             
402             fAlpha = negativeLogEvidenceLambdaPi(xalpha);
403             
404             if((a - alpha) * dfa <= EPSILON){
405                 return 0;
406             }
407             
408             if(fAlpha > f0 + rho * alpha * df0 || fAlpha >= fa){
409                 b = alpha;
410                 fb = fAlpha;
411                 dfb = 0.0/0.0;
412             }
413             else{
414                 double dfalpha = 0.0000;
415                 
416                 negativeLogDerivEvidenceLambdaPi(xalpha, gradient);
417                 dfalpha = 0.0000;
418                 for(int i=0;i<numOTUs;i++){ dfalpha += gradient[i] * p[i]; }
419                 
420                 if(abs(dfalpha) <= -sigma * df0){
421                     alphaNew = alpha;
422                     return 1;
423                 }
424                 
425                 if(((b-a >= 0 && dfalpha >= 0) || ((b-a) <= 0.000 && dfalpha <= 0))){
426                     b = a;      fb = fa;        dfb = dfa;
427                     a = alpha;  fa = fAlpha;    dfa = dfalpha;
428                 }
429                 else{
430                     a = alpha;
431                     fa = fAlpha;
432                     dfa = dfalpha;
433                 }
434             }
435             
436             
437         }
438         
439         return 1;
440     }
441         catch(exception& e) {
442                 m->errorOut(e, "qFinderDMM", "lineMinimizeFletcher");
443                 exit(1);
444         }
445 }
446
447 /**************************************************************************************************/
448
449 int qFinderDMM::bfgs2_Solver(vector<double>& x){
450     try{
451 //        cout << "bfgs2_Solver" << endl;
452         int bfgsIter = 0;
453         double step = 1.0e-6;
454         double delta_f = 0.0000;//f-f0;
455
456         vector<double> gradient;
457         double f = negativeLogEvidenceLambdaPi(x);
458         
459 //        cout << "after negLE" << endl;
460         
461         negativeLogDerivEvidenceLambdaPi(x, gradient);
462
463 //        cout << "after negLDE" << endl;
464
465         vector<double> x0 = x;
466         vector<double> g0 = gradient;
467
468         double g0norm = 0;
469         for(int i=0;i<numOTUs;i++){
470             g0norm += g0[i] * g0[i];
471         }
472         g0norm = sqrt(g0norm);
473
474         vector<double> p = gradient;
475         double pNorm = 0;
476         for(int i=0;i<numOTUs;i++){
477             p[i] *= -1 / g0norm;
478             pNorm += p[i] * p[i];
479         }
480         pNorm = sqrt(pNorm);
481         double df0 = -g0norm;
482
483         int maxIter = 5000;
484         
485 //        cout << "before while" << endl;
486         
487         while(g0norm > 0.001 && bfgsIter++ < maxIter){
488             if (m->control_pressed) {  return 0; }
489
490             double f0 = f;
491             vector<double> dx(numOTUs, 0.0000);
492             
493             double alphaOld, alphaNew;
494
495             if(pNorm == 0 || g0norm == 0 || df0 == 0){
496                 dx.assign(numOTUs, 0.0000);
497                 break;
498             }
499             if(delta_f < 0){
500                 double delta = max(-delta_f, 10 * EPSILON * abs(f0));
501                 alphaOld = min(1.0, 2.0 * delta / (-df0));
502             }
503             else{
504                 alphaOld = step;
505             }
506             
507             int success = lineMinimizeFletcher(x0, p, f0, df0, alphaOld, alphaNew, f, x, gradient);
508             
509             if(!success){
510                 x = x0;
511                 break;   
512             }
513             
514             delta_f = f - f0;
515             
516             vector<double> dx0(numOTUs);
517             vector<double> dg0(numOTUs);
518             
519             for(int i=0;i<numOTUs;i++){
520                 dx0[i] = x[i] - x0[i];
521                 dg0[i] = gradient[i] - g0[i];
522             }
523             
524             double dxg = 0;
525             double dgg = 0;
526             double dxdg = 0;
527             double dgnorm = 0;
528             
529             for(int i=0;i<numOTUs;i++){
530                 dxg += dx0[i] * gradient[i];
531                 dgg += dg0[i] * gradient[i];
532                 dxdg += dx0[i] * dg0[i];
533                 dgnorm += dg0[i] * dg0[i];
534             }
535             dgnorm = sqrt(dgnorm);
536             
537             double A, B;
538             
539             if(dxdg != 0){
540                 B = dxg / dxdg;
541                 A = -(1.0 + dgnorm*dgnorm /dxdg) * B + dgg / dxdg;            
542             }
543             else{
544                 B = 0;
545                 A = 0;
546             }
547             
548             for(int i=0;i<numOTUs;i++){     p[i] = gradient[i] - A * dx0[i] - B * dg0[i];   }
549             
550             x0 = x;
551             g0 = gradient;
552             
553
554             double pg = 0;
555             pNorm = 0.0000;
556             g0norm = 0.0000;
557             
558             for(int i=0;i<numOTUs;i++){
559                 pg += p[i] * gradient[i];
560                 pNorm += p[i] * p[i];
561                 g0norm += g0[i] * g0[i];
562             }
563             pNorm = sqrt(pNorm);
564             g0norm = sqrt(g0norm);
565             
566             double dir = (pg >= 0.0) ? -1.0 : +1.0;
567
568             for(int i=0;i<numOTUs;i++){ p[i] *= dir / pNorm;    }
569             
570             pNorm = 0.0000;
571             df0 = 0.0000;
572             for(int i=0;i<numOTUs;i++){
573                 pNorm += p[i] * p[i];       
574                 df0 += p[i] * g0[i];
575             }
576             
577             pNorm = sqrt(pNorm);
578
579         }
580 //        cout << "bfgsIter:\t" << bfgsIter << endl;
581
582         return bfgsIter;
583     }
584     catch(exception& e){
585         m->errorOut(e, "qFinderDMM", "bfgs2_Solver");
586         exit(1);
587     }
588 }
589
590
591 /**************************************************************************************************/
592
593 double qFinderDMM::negativeLogEvidenceLambdaPi(vector<double>& x){
594     try{
595         vector<double> sumAlphaX(numSamples, 0.0000);
596         
597         double logEAlpha = 0.0000;
598         double sumLambda = 0.0000;
599         double sumAlpha = 0.0000;
600         double logE = 0.0000;
601         double nu = 0.10000;
602         double eta = 0.10000;
603         
604         double weight = 0.00000;
605         for(int i=0;i<numSamples;i++){
606             weight += zMatrix[currentPartition][i];
607         }
608         
609         for(int i=0;i<numOTUs;i++){
610             if (m->control_pressed) {  return 0; }
611             double lambda = x[i];
612             double alpha = exp(x[i]);
613             logEAlpha += lgamma(alpha);
614             sumLambda += lambda;
615             sumAlpha += alpha;
616             
617             for(int j=0;j<numSamples;j++){
618                 double X = countMatrix[j][i];
619                 double alphaX = alpha + X;
620                 sumAlphaX[j] += alphaX;
621                 
622                 logE -= zMatrix[currentPartition][j] * lgamma(alphaX);
623             }
624         }
625         
626         logEAlpha -= lgamma(sumAlpha);
627
628         for(int i=0;i<numSamples;i++){
629             logE += zMatrix[currentPartition][i] * lgamma(sumAlphaX[i]);
630         }
631
632         return logE + weight * logEAlpha + nu * sumAlpha - eta * sumLambda;
633     }
634     catch(exception& e){
635         m->errorOut(e, "qFinderDMM", "negativeLogEvidenceLambdaPi");
636         exit(1);
637     }
638 }
639
640 /**************************************************************************************************/
641
642 void qFinderDMM::negativeLogDerivEvidenceLambdaPi(vector<double>& x, vector<double>& df){
643     try{
644 //        cout << "\tstart negativeLogDerivEvidenceLambdaPi" << endl;
645         
646         vector<double> storeVector(numSamples, 0.0000);
647         vector<double> derivative(numOTUs, 0.0000);
648         vector<double> alpha(numOTUs, 0.0000);
649         
650         double store = 0.0000;
651         double nu = 0.1000;
652         double eta = 0.1000;
653         
654         double weight = 0.0000;
655         for(int i=0;i<numSamples;i++){
656             weight += zMatrix[currentPartition][i];
657         }
658
659         
660         for(int i=0;i<numOTUs;i++){
661             if (m->control_pressed) {  return; }
662 //            cout << "start i loop" << endl;
663 //            
664 //            cout << i << '\t' << alpha[i] << '\t' << x[i] << '\t' << exp(x[i]) << '\t' << store << endl;
665             
666             alpha[i] = exp(x[i]);
667             store += alpha[i];
668             
669 //            cout << "before derivative" << endl;
670             
671             derivative[i] = weight * psi(alpha[i]);
672
673 //            cout << "after derivative" << endl;
674
675 //            cout << i << '\t' << alpha[i] << '\t' << psi(alpha[i]) << '\t' << derivative[i] << endl;
676
677             for(int j=0;j<numSamples;j++){
678                 double X = countMatrix[j][i];
679                 double alphaX = X + alpha[i];
680                 
681                 derivative[i] -= zMatrix[currentPartition][j] * psi(alphaX);
682                 storeVector[j] += alphaX;
683             }
684 //            cout << "end i loop" << endl;
685         }
686
687         double sumStore = 0.0000;
688         for(int i=0;i<numSamples;i++){
689             sumStore += zMatrix[currentPartition][i] * psi(storeVector[i]);
690         }
691         
692         store = weight * psi(store);
693         
694         df.resize(numOTUs, 0.0000);
695         
696         for(int i=0;i<numOTUs;i++){
697             df[i] = alpha[i] * (nu + derivative[i] - store + sumStore) - eta;
698 //            cout << i << '\t' << df[i] << endl;
699         }
700 //        cout << df.size() << endl;
701 //        cout << "\tend negativeLogDerivEvidenceLambdaPi" << endl;
702     }
703     catch(exception& e){
704          m->errorOut(e, "qFinderDMM", "negativeLogDerivEvidenceLambdaPi");
705         exit(1);
706     }
707 }
708
709 /**************************************************************************************************/
710
711 double qFinderDMM::getNegativeLogEvidence(vector<double>& lambda, int group){
712     try {
713         double sumAlpha = 0.0000;
714         double sumAlphaX = 0.0000;
715         double sumLnGamAlpha = 0.0000;
716         double logEvidence = 0.0000;
717         
718         for(int i=0;i<numOTUs;i++){
719             if (m->control_pressed) {  return 0; }
720             double alpha = exp(lambda[i]);
721             double X = countMatrix[group][i];
722             double alphaX = alpha + X;
723             
724             sumLnGamAlpha += lgamma(alpha);
725             sumAlpha += alpha;
726             sumAlphaX += alphaX;
727             
728             logEvidence -= lgamma(alphaX);
729         }
730         
731         sumLnGamAlpha -= lgamma(sumAlpha);
732         logEvidence += lgamma(sumAlphaX);
733         
734         return logEvidence + sumLnGamAlpha;
735     }
736     catch(exception& e){
737         m->errorOut(e, "qFinderDMM", "getNegativeLogEvidence");
738         exit(1);
739     }
740 }
741
742 /**************************************************************************************************/
743
744 void qFinderDMM::optimizeLambda(){    
745     try {
746         for(currentPartition=0;currentPartition<numPartitions;currentPartition++){
747             if (m->control_pressed) {  return; }
748             bfgs2_Solver(lambdaMatrix[currentPartition]);
749         }
750     }
751     catch(exception& e){
752         m->errorOut(e, "qFinderDMM", "optimizeLambda");
753         exit(1);
754     }
755 }
756 /**************************************************************************************************/
757
758 void qFinderDMM::calculatePiK(){
759     try {
760         vector<double> store(numPartitions);
761         
762         for(int i=0;i<numSamples;i++){
763             if (m->control_pressed) {  return; }
764             double sum = 0.0000;
765             double minNegLogEvidence =numeric_limits<double>::max();
766             
767             for(int j=0;j<numPartitions;j++){
768                 double negLogEvidenceJ = getNegativeLogEvidence(lambdaMatrix[j], i);
769                 
770                 if(negLogEvidenceJ < minNegLogEvidence){
771                     minNegLogEvidence = negLogEvidenceJ;
772                 }
773                 store[j] = negLogEvidenceJ;
774             }
775             
776             for(int j=0;j<numPartitions;j++){
777                 if (m->control_pressed) {  return; }
778                 zMatrix[j][i] = weights[j] * exp(-(store[j] - minNegLogEvidence));
779                 sum += zMatrix[j][i];
780             }
781             
782             for(int j=0;j<numPartitions;j++){
783                 zMatrix[j][i] /= sum;
784             }
785             
786         }
787     }
788     catch(exception& e){
789         m->errorOut(e, "qFinderDMM", "calculatePiK");
790         exit(1);
791     }
792     
793 }
794
795 /**************************************************************************************************/
796
797 double qFinderDMM::getNegativeLogLikelihood(){
798     try {
799         double eta = 0.10000;
800         double nu = 0.10000;
801         
802         vector<double> pi(numPartitions, 0.0000);
803         vector<double> logBAlpha(numPartitions, 0.0000);
804         
805         double doubleSum = 0.0000;
806         
807         for(int i=0;i<numPartitions;i++){
808             if (m->control_pressed) {  return 0; }
809             double sumAlphaK = 0.0000;
810             
811             pi[i] = weights[i] / (double)numSamples;
812             
813             for(int j=0;j<numOTUs;j++){
814                 double alpha = exp(lambdaMatrix[i][j]);
815                 sumAlphaK += alpha;
816                 
817                 logBAlpha[i] += lgamma(alpha);
818             }
819             logBAlpha[i] -= lgamma(sumAlphaK);
820         }
821         
822         for(int i=0;i<numSamples;i++){
823             if (m->control_pressed) {  return 0; }
824             
825             double probability = 0.0000;
826             double factor = 0.0000;
827             double sum = 0.0000;
828             vector<double> logStore(numPartitions, 0.0000);
829             double offset = -numeric_limits<double>::max();
830             
831             for(int j=0;j<numOTUs;j++){
832                 sum += countMatrix[i][j];
833                 factor += lgamma(countMatrix[i][j] + 1.0000);
834             }
835             factor -= lgamma(sum + 1.0);
836             
837             for(int k=0;k<numPartitions;k++){
838                 
839                 double sumAlphaKX = 0.0000;
840                 double logBAlphaX = 0.0000;
841                 
842                 for(int j=0;j<numOTUs;j++){
843                     double alphaX = exp(lambdaMatrix[k][j]) + (double)countMatrix[i][j];
844                     
845                     sumAlphaKX += alphaX;
846                     logBAlphaX += lgamma(alphaX);
847                 }
848                 
849                 logBAlphaX -= lgamma(sumAlphaKX);
850                 
851                 logStore[k] = logBAlphaX - logBAlpha[k] - factor;
852                 if(logStore[k] > offset){
853                     offset = logStore[k];
854                 }
855                 
856             }
857             
858             for(int k=0;k<numPartitions;k++){
859                 probability += pi[k] * exp(-offset + logStore[k]);
860             }
861             doubleSum += log(probability) + offset;
862             
863         }
864         
865         double L5 = - numOTUs * numPartitions * lgamma(eta);
866         double L6 = eta * numPartitions * numOTUs * log(nu);
867         
868         double alphaSum, lambdaSum;
869         alphaSum = lambdaSum = 0.0000;
870         
871         for(int i=0;i<numPartitions;i++){
872             for(int j=0;j<numOTUs;j++){
873                 if (m->control_pressed) {  return 0; }
874                 alphaSum += exp(lambdaMatrix[i][j]);
875                 lambdaSum += lambdaMatrix[i][j];
876             }
877         }
878         alphaSum *= -nu;
879         lambdaSum *= eta;
880         
881         return (-doubleSum - L5 - L6 - alphaSum - lambdaSum);
882     }
883     catch(exception& e){
884         m->errorOut(e, "qFinderDMM", "getNegativeLogLikelihood");
885         exit(1);
886     }
887
888
889 }
890 /**************************************************************************************************/