5 // Created by Patrick Schloss on 11/8/12.
6 // Copyright (c) 2012 University of Michigan. All rights reserved.
9 #include "qFinderDMM.h"
13 /**************************************************************************************************/
// qFinderDMM constructor: fits a mixture model with `p` partitions to the
// sample-by-OTU count matrix `cm` by alternating updates until the negative
// log-likelihood settles, then computes per-partition error estimates and the
// model-selection scores laplace, bic and aic.
// NOTE(review): several statements of this constructor (e.g. storing cm/p into
// countMatrix/numPartitions, the initial k-means and lambda fits, the
// declarations of iter/currNLL and the per-iteration updates) are not visible
// in this view; the comments below describe only the lines shown.
15 qFinderDMM::qFinderDMM(vector<vector<int> > cm, int p) : CommunityTypeFinder() {
17 //cout << "here" << endl;
// Rows of countMatrix are samples and columns are OTUs. countMatrix[0] is read
// without an emptiness check, so an empty matrix is not handled here.
20 numSamples = (int)countMatrix.size();
21 numOTUs = (int)countMatrix[0].size();
23 //cout << "before kmeans" <<endl;
25 //cout << "done kMeans" << endl;
30 //cout << "done optimizeLambda" << endl;
// Start above the 1e-6 tolerance so the tolerance test passes on the first check.
32 double change = 1.0000;
// Iterate until the negative log-likelihood changes by no more than 1e-6
// between iterations, or until `iter` reaches 100.
37 while(change > 1.0e-6 && iter < 100){
44 //printf("Iter:%d\t",iter);
// weights[i] accumulates the membership mass sum_j zMatrix[i][j] of partition i.
// NOTE(review): the reset of weights[i] before this sum is not visible here;
// confirm it is zeroed every iteration.
46 for(int i=0;i<numPartitions;i++){
48 for(int j=0;j<numSamples;j++){
49 weights[i] += zMatrix[i][j];
51 // printf("w_%d=%.3f\t",i,weights[i]);
55 double nLL = getNegativeLogLikelihood();
// NOTE(review): this needs the std::abs(double) overload; C's abs(int) would
// truncate the difference and could stop the loop early.
57 change = abs(nLL - currNLL);
61 // printf("NLL=%.5f\tDelta=%.4e\n",currNLL, change);
// After the loop: one error vector per partition, plus a log-determinant sum
// used by the Laplace score below.
66 error.resize(numPartitions);
68 logDeterminant = 0.0000;
// currentPartition is a member; getHessian() presumably evaluates the Hessian
// for the partition it names (TODO confirm).
72 for(currentPartition=0;currentPartition<numPartitions;currentPartition++){
74 error[currentPartition].assign(numOTUs, 0.0000);
// Every partition after the first adds 2*log(numSamples) - log(weights[k]).
76 if(currentPartition > 0){
77 logDeterminant += (2.0 * log(numSamples) - log(weights[currentPartition]));
// The inverse comes from the member `l`, whose implementation is outside this view.
79 vector<vector<double> > hessian = getHessian();
80 vector<vector<double> > invHessian = l.getInverse(hessian);
// Only the Hessian diagonal is used: log|H_ii| goes into the log-determinant sum
// and (H^-1)_ii is stored as the per-OTU error estimate.
// NOTE(review): log(abs(...)) also relies on std::abs(double).
82 for(int i=0;i<numOTUs;i++){
83 logDeterminant += log(abs(hessian[i][i]));
84 error[currentPartition][i] = invHessian[i][i];
// Parameter count: numPartitions*numOTUs entries of lambdaMatrix, plus
// numPartitions - 1 (presumably the mixing proportions, which sum to one).
89 int numParameters = numPartitions * numOTUs + numPartitions - 1;
// Same form as the negative log of a Laplace approximation to the evidence:
// NLL + 0.5*log|H| - 0.5*k*log(2*pi). 3.14159 is a truncated value of pi.
90 laplace = currNLL + 0.5 * logDeterminant - 0.5 * numParameters * log(2.0 * 3.14159);
// bic and aic are on the "half" scale, i.e. half of the textbook
// 2*NLL + k*ln(n) and 2*NLL + 2k; lower is better for all three scores.
91 bic = currNLL + 0.5 * log(numSamples) * numParameters;
92 aic = currNLL + numParameters;
95 m->errorOut(e, "qFinderDMM", "qFinderDMM");
99 /**************************************************************************************************/
// Writes one tab-separated fit-summary row to `out`, in this order:
// numPartitions, NLL, log-determinant, BIC, AIC, Laplace score.
// setprecision(2) without std::fixed gives 2 significant digits, and the
// setting stays on the stream after this call.
// NOTE(review): the error label reports "CommunityTypeFinder" rather than "qFinderDMM".
100 void qFinderDMM::printFitData(ofstream& out){
102 out << setprecision (2) << numPartitions << '\t' << getNLL() << '\t' << getLogDet() << '\t' << getBIC() << '\t' << getAIC() << '\t' << laplace << endl;
106 m->errorOut(e, "CommunityTypeFinder", "printFitData");
110 /**************************************************************************************************/
// Writes the same tab-separated fit-summary row as printFitData(ofstream&) to
// any ostream. The row is suffixed with "***" when this model's Laplace score is
// below `minLaplace`; presumably this flags the best model, since lower is better.
// The same values (without the marker) are also sent to the log, formatted by
// toString rather than setprecision(2).
// NOTE(review): the error label reports "CommunityTypeFinder" rather than "qFinderDMM".
111 void qFinderDMM::printFitData(ostream& out, double minLaplace){
113 if(laplace < minLaplace){
114 out << setprecision (2) << numPartitions << '\t' << getNLL() << '\t' << getLogDet() << '\t' << getBIC() << '\t' << getAIC() << '\t' << laplace << "***" << endl;
// Unmarked row (the `else` line is not visible in this view).
116 out << setprecision (2) << numPartitions << '\t' << getNLL() << '\t' << getLogDet() << '\t' << getBIC() << '\t' << getAIC() << '\t' << laplace << endl;
119 m->mothurOutJustToLog(toString(numPartitions) + '\t' + toString(getNLL()) + '\t' + toString(getLogDet()) + '\t');
120 m->mothurOutJustToLog(toString(getBIC()) + '\t' + toString(getAIC()) + '\t' + toString(laplace));
125 m->errorOut(e, "CommunityTypeFinder", "printFitData");
130 /**************************************************************************************************/
132 // The helper functions below for the BFGS2 solver (line-search interpolation and quadratic root finding) were adapted from the GNU Scientific Library (GSL) source code.
134 /* Find the minimum over [zl, zh] of the quadratic interpolating
135 * (0,f0) (1,f1) with derivative fp0 at z=0. The interpolating
136 * polynomial is q(z) = f0 + fp0 * z + (f1-f0-fp0) * z^2
// Minimises the quadratic q(z) = f0 + fp0*z + (f1-f0-fp0)*z^2 over [zl, zh].
// It starts from the lower of the two endpoints. When the curvature c is
// positive and the stationary point -fp0/c lies strictly inside (zl, zh),
// that point is taken if its value is lower.
// NOTE(review): the return type and the final return are not visible in this
// view; zmin is the location being tracked.
140 interp_quad (double f0, double fp0, double f1, double zl, double zh)
142 double fl = f0 + zl*(fp0 + zl*(f1 - f0 -fp0));
143 double fh = f0 + zh*(fp0 + zh*(f1 - f0 -fp0));
144 double c = 2 * (f1 - f0 - fp0); /* curvature */
146 double zmin = zl, fmin = fl;
148 if (fh < fmin) { zmin = zh; fmin = fh; }
150 if (c > 0) /* positive curvature required for a minimum */
152 double z = -fp0 / c; /* location of minimum */
153 if (z > zl && z < zh) {
154 double f = f0 + z*(fp0 + z*(f1 - f0 -fp0));
155 if (f < fmin) { zmin = z; fmin = f; };
162 /**************************************************************************************************/
164 /* Find the minimum over [zl, zh] of the cubic interpolating
165 * (0,f0) (1,f1) with derivatives fp0 at z=0 and fp1 at z=1.
167 * The interpolating polynomial is:
169 * c(z) = f0 + fp0 * z + eta * z^2 + xi * z^3
171 * where eta=3*(f1-f0)-2*fp0-fp1, xi=fp0+fp1-2*(f1-f0). The minimisation itself is done in interp_cubic(); cubic() below only evaluates this polynomial.
// Evaluates c0 + c1*z + c2*z^2 + c3*z^3 using Horner's rule.
174 double cubic (double c0, double c1, double c2, double c3, double z){
175 return c0 + z * (c1 + z * (c2 + z * c3));
178 /**************************************************************************************************/
// Evaluates the cubic with coefficients c0..c3 at z. When z is accepted as the
// new best point, it is recorded in *zmin.
// NOTE(review): the comparison against *fmin and the *fmin update are on lines
// not visible in this view.
180 void check_extremum (double c0, double c1, double c2, double c3, double z,
181 double *zmin, double *fmin){
182 /* could make an early return by testing curvature >0 for minimum */
184 double y = cubic (c0, c1, c2, c3, z);
188 *zmin = z; /* accepted new point*/
193 /**************************************************************************************************/
// Real roots of a*x^2 + b*x + c = 0, written to *x0 and *x1. The int result is
// presumably the number of real roots, as in GSL's function of the same name.
// The return statements and the root-ordering code are not visible here.
195 int gsl_poly_solve_quadratic (double a, double b, double c,
196 double *x0, double *x1)
199 double disc = b * b - 4 * a * c;
201 if (a == 0) /* Handle linear case */
// Half-distance between the two roots, |sqrt(disc)/(2a)|; presumably used for
// the b == 0 case (the guarding condition is not visible).
218 double r = fabs (0.5 * sqrt (disc) / a);
// Cancellation-free form: one root is temp/a and the other is c/temp (their
// product is c/a). Note that b == 0 takes sgnb = -1.
224 double sgnb = (b > 0 ? 1 : -1);
225 double temp = -0.5 * (b + sgnb * sqrt (disc));
226 double r1 = temp / a ;
227 double r2 = c / temp ;
255 /**************************************************************************************************/
// Minimises over [zl, zh] the cubic c(z) = f0 + fp0*z + eta*z^2 + xi*z^3, which
// has value f0 and slope fp0 at z=0, and value f1 and slope fp1 at z=1.
// Candidates are both endpoints and any roots of c'(z) lying strictly inside
// (zl, zh).
// NOTE(review): the declarations of zmin/fmin/z0/z1 and the return are not
// visible in this view.
257 double interp_cubic (double f0, double fp0, double f1, double fp1, double zl, double zh){
258 double eta = 3 * (f1 - f0) - 2 * fp0 - fp1;
259 double xi = fp0 + fp1 - 2 * (f1 - f0);
260 double c0 = f0, c1 = fp0, c2 = eta, c3 = xi;
// Seed the search with the endpoint zl, then compare against zh.
264 zmin = zl; fmin = cubic(c0, c1, c2, c3, zl);
265 check_extremum (c0, c1, c2, c3, zh, &zmin, &fmin);
// Stationary points are the roots of c'(z) = 3*c3*z^2 + 2*c2*z + c1.
268 int n = gsl_poly_solve_quadratic (3 * c3, 2 * c2, c1, &z0, &z1);
270 if (n == 2) /* found 2 roots */
272 if (z0 > zl && z0 < zh)
273 check_extremum (c0, c1, c2, c3, z0, &zmin, &fmin);
274 if (z1 > zl && z1 < zh)
275 check_extremum (c0, c1, c2, c3, z1, &zmin, &fmin);
277 else if (n == 1) /* found 1 root */
279 if (z0 > zl && z0 < zh)
280 check_extremum (c0, c1, c2, c3, z0, &zmin, &fmin);
287 /**************************************************************************************************/
// Picks a trial step inside [xmin, xmax]. It fits a polynomial to the values
// (fa, fb) and slopes (fpa, fpb) at the interval ends a and b, then minimises it.
// The work is done on z in [0,1] with z = (x - a)/(b - a); slopes are scaled by
// (b - a) to match.
// NOTE(review): the condition that chooses interp_cubic over interp_quad is not
// visible here. lineMinimizeFletcher passes fpb = 0.0/0.0 (NaN) when the slope
// at b is unknown, which presumably selects the quadratic.
289 double interpolate (double a, double fa, double fpa,
290 double b, double fb, double fpb, double xmin, double xmax){
291 /* Map [a,b] to [0,1] */
292 double z, alpha, zmin, zmax;
294 zmin = (xmin - a) / (b - a);
295 zmax = (xmax - a) / (b - a);
305 z = interp_cubic (fa, fpa * (b - a), fb, fpb * (b - a), zmin, zmax);
308 z = interp_quad(fa, fpa * (b - a), fb, zmin, zmax);
// Map the chosen z back to the original step scale.
312 alpha = a + z * (b - a);
317 /**************************************************************************************************/
// Fletcher-style line search along direction p from x, following the GSL
// bfgs2 scheme. f0 and df0 are the objective value and directional derivative
// at x; alpha1 is the initial trial step. Each trial point x + alpha*p is
// written to `xalpha`, its objective value to `fAlpha`, and (when evaluated)
// its gradient to `gradient`.
// Acceptance uses the strong Wolfe tests visible below:
//   sufficient decrease:  fAlpha <= f0 + rho*alpha*df0
//   curvature:            |dfalpha| <= -sigma*df0
// NOTE(review): rho, sigma, tau1-tau3, maxIters, iter and the return paths
// (including how alphaNew is set) are on lines not visible in this view.
319 int qFinderDMM::lineMinimizeFletcher(vector<double>& x, vector<double>& p, double f0, double df0, double alpha1, double& alphaNew, double& fAlpha, vector<double>& xalpha, vector<double>& gradient ){
328 double alpha = alpha1;
329 double alpha_prev = 0.0000;
331 xalpha.resize(numOTUs, 0.0000);
333 double falpha_prev = f0;
334 double dfalpha_prev = df0;
// Bracket [a, b]. dfb = 0.0/0.0 (NaN) marks the slope at b as unknown.
// NOTE(review): std::numeric_limits<double>::quiet_NaN() would state this more clearly.
336 double a = 0.0000; double b = alpha;
337 double fa = f0; double fb = 0.0000;
338 double dfa = df0; double dfb = 0.0/0.0;
// Phase 1 (bracketing): grow the step until an interval holding an acceptable point is found.
342 while(iter++ < maxIters){
343 if (m->control_pressed) { break; }
345 for(int i=0;i<numOTUs;i++){
346 xalpha[i] = x[i] + alpha * p[i];
349 fAlpha = negativeLogEvidenceLambdaPi(xalpha);
// Sufficient decrease failed, or f did not improve on the previous step:
// bracket the minimum in [alpha_prev, alpha], with the slope at b unknown.
351 if(fAlpha > f0 + alpha * rho * df0 || fAlpha >= falpha_prev){
352 a = alpha_prev; b = alpha;
353 fa = falpha_prev; fb = fAlpha;
354 dfa = dfalpha_prev; dfb = 0.0/0.0;
358 negativeLogDerivEvidenceLambdaPi(xalpha, gradient);
359 double dfalpha = 0.0000;
360 for(int i=0;i<numOTUs;i++){ dfalpha += gradient[i] * p[i]; }
// Curvature test. NOTE(review): relies on std::abs(double); C abs(int) would truncate.
362 if(abs(dfalpha) <= -sigma * df0){
// Bracket with a = alpha and b = alpha_prev (the guarding condition is not visible here).
368 a = alpha; b = alpha_prev;
369 fa = fAlpha; fb = falpha_prev;
370 dfa = dfalpha; dfb = dfalpha_prev;
// Extrapolate: choose the next trial step in [alpha + delta, alpha + tau1*delta].
374 double delta = alpha - alpha_prev;
376 double lower = alpha + delta;
377 double upper = alpha + tau1 * delta;
379 double alphaNext = interpolate(alpha_prev, falpha_prev, dfalpha_prev, alpha, fAlpha, dfalpha, lower, upper);
382 falpha_prev = fAlpha;
383 dfalpha_prev = dfalpha;
// Phase 2 (sectioning): shrink the bracket [a, b] until the Wolfe tests pass.
388 while(iter++ < maxIters){
389 if (m->control_pressed) { break; }
391 double delta = b - a;
// Keep the trial step away from both ends, by fractions tau2 and tau3 of the bracket.
393 double lower = a + tau2 * delta;
394 double upper = b - tau3 * delta;
396 alpha = interpolate(a, fa, dfa, b, fb, dfb, lower, upper);
398 for(int i=0;i<numOTUs;i++){
399 xalpha[i] = x[i] + alpha * p[i];
402 fAlpha = negativeLogEvidenceLambdaPi(xalpha);
// Roundoff guard: the predicted change (a - alpha)*dfa is not above EPSILON
// (EPSILON is defined outside this view).
404 if((a - alpha) * dfa <= EPSILON){
408 if(fAlpha > f0 + rho * alpha * df0 || fAlpha >= fa){
414 double dfalpha = 0.0000;
416 negativeLogDerivEvidenceLambdaPi(xalpha, gradient);
418 for(int i=0;i<numOTUs;i++){ dfalpha += gradient[i] * p[i]; }
420 if(abs(dfalpha) <= -sigma * df0){
// If the slope at alpha has the same sign as (b - a), the minimum lies between
// a and alpha. The old a becomes the far end b, and alpha becomes the new a.
425 if(((b-a >= 0 && dfalpha >= 0) || ((b-a) <= 0.000 && dfalpha <= 0))){
426 b = a; fb = fa; dfb = dfa;
427 a = alpha; fa = fAlpha; dfa = dfalpha;
441 catch(exception& e) {
442 m->errorOut(e, "qFinderDMM", "lineMinimizeFletcher");
447 /**************************************************************************************************/
// Minimises negativeLogEvidenceLambdaPi(x) for the current partition with a
// BFGS quasi-Newton method modelled on GSL's bfgs2. `x` is updated in place
// through lineMinimizeFletcher. Iteration continues while the gradient norm
// g0norm is above 0.001 and bfgsIter stays below maxIter. Returns 0 immediately
// if the user aborts (m->control_pressed).
// NOTE(review): several statements (declarations of g0norm/pNorm/f0/bfgsIter/
// maxIter, the normalisation of p, and the bookkeeping after each line search)
// are on lines not visible in this view.
449 int qFinderDMM::bfgs2_Solver(vector<double>& x){
451 // cout << "bfgs2_Solver" << endl;
453 double step = 1.0e-6;
454 double delta_f = 0.0000;//f-f0;
456 vector<double> gradient;
457 double f = negativeLogEvidenceLambdaPi(x);
459 // cout << "after negLE" << endl;
461 negativeLogDerivEvidenceLambdaPi(x, gradient);
463 // cout << "after negLDE" << endl;
// x0 and g0 hold the point and gradient at the start of an iteration.
465 vector<double> x0 = x;
466 vector<double> g0 = gradient;
469 for(int i=0;i<numOTUs;i++){
470 g0norm += g0[i] * g0[i];
472 g0norm = sqrt(g0norm);
474 vector<double> p = gradient;
476 for(int i=0;i<numOTUs;i++){
478 pNorm += p[i] * p[i];
// -g0norm is the directional derivative along the unit steepest-descent
// direction -g/|g| (the rescaling of p is not visible here).
481 double df0 = -g0norm;
485 // cout << "before while" << endl;
487 while(g0norm > 0.001 && bfgsIter++ < maxIter){
488 if (m->control_pressed) { return 0; }
491 vector<double> dx(numOTUs, 0.0000);
493 double alphaOld, alphaNew;
// Degenerate search direction or zero gradient: the step dx is cleared.
495 if(pNorm == 0 || g0norm == 0 || df0 == 0){
496 dx.assign(numOTUs, 0.0000);
// Initial trial step is min(1, 2*delta/(-df0)). delta is the larger of the last
// decrease in f and a roundoff floor of 10*EPSILON*|f0|.
500 double delta = max(-delta_f, 10 * EPSILON * abs(f0));
501 alphaOld = min(1.0, 2.0 * delta / (-df0));
507 int success = lineMinimizeFletcher(x0, p, f0, df0, alphaOld, alphaNew, f, x, gradient);
// Step taken (dx0) and gradient change (dg0) over this iteration, used for the BFGS update.
516 vector<double> dx0(numOTUs);
517 vector<double> dg0(numOTUs);
519 for(int i=0;i<numOTUs;i++){
520 dx0[i] = x[i] - x0[i];
521 dg0[i] = gradient[i] - g0[i];
529 for(int i=0;i<numOTUs;i++){
530 dxg += dx0[i] * gradient[i];
531 dgg += dg0[i] * gradient[i];
532 dxdg += dx0[i] * dg0[i];
533 dgnorm += dg0[i] * dg0[i];
535 dgnorm = sqrt(dgnorm);
// BFGS direction update coefficients, as in GSL bfgs2. B is set on a line not
// shown (GSL uses B = dxg/dxdg). The new direction is p = g - A*dx0 - B*dg0.
541 A = -(1.0 + dgnorm*dgnorm /dxdg) * B + dgg / dxdg;
548 for(int i=0;i<numOTUs;i++){ p[i] = gradient[i] - A * dx0[i] - B * dg0[i]; }
558 for(int i=0;i<numOTUs;i++){
559 pg += p[i] * gradient[i];
560 pNorm += p[i] * p[i];
561 g0norm += g0[i] * g0[i];
564 g0norm = sqrt(g0norm);
// Flip p so it points opposite to the gradient (a descent direction), and
// scale it by 1/pNorm.
566 double dir = (pg >= 0.0) ? -1.0 : +1.0;
568 for(int i=0;i<numOTUs;i++){ p[i] *= dir / pNorm; }
572 for(int i=0;i<numOTUs;i++){
573 pNorm += p[i] * p[i];
580 // cout << "bfgsIter:\t" << bfgsIter << endl;
585 m->errorOut(e, "qFinderDMM", "bfgs2_Solver");
591 /**************************************************************************************************/
// Objective minimised by bfgs2_Solver: the negative log evidence of partition
// `currentPartition` plus prior-like terms, as a function of the
// log-parameters x (alpha = exp(x)). Each sample j is weighted by
// z_j = zMatrix[currentPartition][j], and W = sum_j z_j:
//   sum_j z_j [ lgamma(sum_i(alpha_i + X_ji)) - sum_i lgamma(alpha_i + X_ji) ]
//   + W * [ sum_i lgamma(alpha_i) - lgamma(sumAlpha) ]
//   + nu * sumAlpha - eta * sumLambda
// Returns 0 if the user aborts.
// NOTE(review): sumAlpha and sumLambda are presumably accumulated from exp(x[i])
// and x[i] on lines not visible here. eta = 0.1 is hard-coded; the same value
// appears in getNegativeLogLikelihood and must stay in sync.
593 double qFinderDMM::negativeLogEvidenceLambdaPi(vector<double>& x){
595 vector<double> sumAlphaX(numSamples, 0.0000);
597 double logEAlpha = 0.0000;
598 double sumLambda = 0.0000;
599 double sumAlpha = 0.0000;
600 double logE = 0.0000;
602 double eta = 0.10000;
// W: total membership of the current partition.
604 double weight = 0.00000;
605 for(int i=0;i<numSamples;i++){
606 weight += zMatrix[currentPartition][i];
609 for(int i=0;i<numOTUs;i++){
610 if (m->control_pressed) { return 0; }
// x holds log-parameters; alpha = exp(x) keeps every alpha strictly positive.
611 double lambda = x[i];
612 double alpha = exp(x[i]);
613 logEAlpha += lgamma(alpha);
// sumAlphaX[j] accumulates sum_i (alpha_i + X_ji) for sample j.
617 for(int j=0;j<numSamples;j++){
618 double X = countMatrix[j][i];
619 double alphaX = alpha + X;
620 sumAlphaX[j] += alphaX;
622 logE -= zMatrix[currentPartition][j] * lgamma(alphaX);
626 logEAlpha -= lgamma(sumAlpha);
628 for(int i=0;i<numSamples;i++){
629 logE += zMatrix[currentPartition][i] * lgamma(sumAlphaX[i]);
632 return logE + weight * logEAlpha + nu * sumAlpha - eta * sumLambda;
635 m->errorOut(e, "qFinderDMM", "negativeLogEvidenceLambdaPi");
640 /**************************************************************************************************/
// Gradient of negativeLogEvidenceLambdaPi with respect to x (alpha = exp(x)),
// written into `df`, which is resized to numOTUs. By the chain rule each
// component is
//   df_i = alpha_i * ( nu + W*psi(alpha_i) - sum_j z_j*psi(alpha_i + X_ji)
//                     - W*psi(sum alpha) + sum_j z_j*psi(sum_k(alpha_k + X_jk)) ) - eta
// with z_j = zMatrix[currentPartition][j] and W = sum_j z_j. psi() is defined
// elsewhere; presumably it is the digamma function, which this formula needs
// to be the exact derivative of lgamma.
// If the user aborts, this returns before touching `df`.
// NOTE(review): `store` presumably accumulates sum alpha, and `eta` is declared,
// on lines not visible here. eta must match the 0.1 used in
// negativeLogEvidenceLambdaPi.
642 void qFinderDMM::negativeLogDerivEvidenceLambdaPi(vector<double>& x, vector<double>& df){
644 // cout << "\tstart negativeLogDerivEvidenceLambdaPi" << endl;
646 vector<double> storeVector(numSamples, 0.0000);
647 vector<double> derivative(numOTUs, 0.0000);
648 vector<double> alpha(numOTUs, 0.0000);
650 double store = 0.0000;
// W: total membership of the current partition.
654 double weight = 0.0000;
655 for(int i=0;i<numSamples;i++){
656 weight += zMatrix[currentPartition][i];
660 for(int i=0;i<numOTUs;i++){
661 if (m->control_pressed) { return; }
662 // cout << "start i loop" << endl;
664 // cout << i << '\t' << alpha[i] << '\t' << x[i] << '\t' << exp(x[i]) << '\t' << store << endl;
666 alpha[i] = exp(x[i]);
669 // cout << "before derivative" << endl;
671 derivative[i] = weight * psi(alpha[i]);
673 // cout << "after derivative" << endl;
675 // cout << i << '\t' << alpha[i] << '\t' << psi(alpha[i]) << '\t' << derivative[i] << endl;
// storeVector[j] accumulates sum_i (alpha_i + X_ji) for sample j.
677 for(int j=0;j<numSamples;j++){
678 double X = countMatrix[j][i];
679 double alphaX = X + alpha[i];
681 derivative[i] -= zMatrix[currentPartition][j] * psi(alphaX);
682 storeVector[j] += alphaX;
684 // cout << "end i loop" << endl;
687 double sumStore = 0.0000;
688 for(int i=0;i<numSamples;i++){
689 sumStore += zMatrix[currentPartition][i] * psi(storeVector[i]);
692 store = weight * psi(store);
// resize() only initialises newly added elements; every entry is overwritten below.
694 df.resize(numOTUs, 0.0000);
696 for(int i=0;i<numOTUs;i++){
697 df[i] = alpha[i] * (nu + derivative[i] - store + sumStore) - eta;
698 // cout << i << '\t' << df[i] << endl;
700 // cout << df.size() << endl;
701 // cout << "\tend negativeLogDerivEvidenceLambdaPi" << endl;
704 m->errorOut(e, "qFinderDMM", "negativeLogDerivEvidenceLambdaPi");
709 /**************************************************************************************************/
// Negative log evidence of one sample (row `group` of countMatrix) under the
// log-parameters `lambda` (alpha = exp(lambda)):
//   sum_i lgamma(alpha_i) - lgamma(sum alpha) - sum_i lgamma(alpha_i + X_i) + lgamma(sum(alpha + X))
// Despite its name, `group` is a sample (row) index; calculatePiK passes the
// sample index here. Returns 0 if the user aborts.
// NOTE(review): sumAlpha and sumAlphaX are accumulated on lines not visible here.
711 double qFinderDMM::getNegativeLogEvidence(vector<double>& lambda, int group){
713 double sumAlpha = 0.0000;
714 double sumAlphaX = 0.0000;
715 double sumLnGamAlpha = 0.0000;
716 double logEvidence = 0.0000;
718 for(int i=0;i<numOTUs;i++){
719 if (m->control_pressed) { return 0; }
720 double alpha = exp(lambda[i]);
721 double X = countMatrix[group][i];
722 double alphaX = alpha + X;
724 sumLnGamAlpha += lgamma(alpha);
728 logEvidence -= lgamma(alphaX);
731 sumLnGamAlpha -= lgamma(sumAlpha);
732 logEvidence += lgamma(sumAlphaX);
734 return logEvidence + sumLnGamAlpha;
737 m->errorOut(e, "qFinderDMM", "getNegativeLogEvidence");
742 /**************************************************************************************************/
// Re-fits the log-parameters of every partition. For each partition it sets
// the member currentPartition, which the evidence and derivative functions
// read, and runs bfgs2_Solver on lambdaMatrix[currentPartition] in place.
// Stops early if the user aborts.
// NOTE(review): bfgs2_Solver's return value is ignored.
744 void qFinderDMM::optimizeLambda(){
746 for(currentPartition=0;currentPartition<numPartitions;currentPartition++){
747 if (m->control_pressed) { return; }
748 bfgs2_Solver(lambdaMatrix[currentPartition]);
752 m->errorOut(e, "qFinderDMM", "optimizeLambda");
756 /**************************************************************************************************/
// Recomputes the membership matrix zMatrix. For each sample i,
// zMatrix[j][i] is proportional to weights[j] * exp(-negLogEvidence_j(i)),
// normalised so the memberships sum to one over partitions. The smallest
// negative log evidence is subtracted before exponentiating, so the best
// partition's factor is exp(0) = 1 and the others cannot all underflow.
// Stops early if the user aborts.
// NOTE(review): `sum` is declared on a line not visible here; it must be reset
// to 0 for every sample.
758 void qFinderDMM::calculatePiK(){
760 vector<double> store(numPartitions);
762 for(int i=0;i<numSamples;i++){
763 if (m->control_pressed) { return; }
765 double minNegLogEvidence =numeric_limits<double>::max();
767 for(int j=0;j<numPartitions;j++){
768 double negLogEvidenceJ = getNegativeLogEvidence(lambdaMatrix[j], i);
770 if(negLogEvidenceJ < minNegLogEvidence){
771 minNegLogEvidence = negLogEvidenceJ;
773 store[j] = negLogEvidenceJ;
776 for(int j=0;j<numPartitions;j++){
777 if (m->control_pressed) { return; }
778 zMatrix[j][i] = weights[j] * exp(-(store[j] - minNegLogEvidence));
779 sum += zMatrix[j][i];
782 for(int j=0;j<numPartitions;j++){
783 zMatrix[j][i] /= sum;
789 m->errorOut(e, "qFinderDMM", "calculatePiK");
795 /**************************************************************************************************/
// Negative log-likelihood of the whole mixture, plus prior terms on lambda.
//   pi_k = weights[k] / numSamples
//   logStore[k] = logB(alpha_k + x) - logB(alpha_k) - factor, where
//     logB(a) = sum_j lgamma(a_j) - lgamma(sum_j a_j) and
//     factor  = sum_j lgamma(x_j + 1) - lgamma(N + 1), N = the sample's total count.
// For each sample, log sum_k pi_k*exp(logStore[k]) is computed with the
// log-sum-exp trick (offset = max_k logStore[k]). The L5/L6/alphaSum/lambdaSum
// terms add the prior contribution using eta = 0.1 and nu.
// Returns 0 if the user aborts.
// NOTE(review): the per-sample count total `sum`, the accumulation of
// sumAlphaK, and any scaling of alphaSum/lambdaSum by the prior constants are
// on lines not visible here. eta must match negativeLogEvidenceLambdaPi.
797 double qFinderDMM::getNegativeLogLikelihood(){
799 double eta = 0.10000;
802 vector<double> pi(numPartitions, 0.0000);
803 vector<double> logBAlpha(numPartitions, 0.0000);
805 double doubleSum = 0.0000;
// Per-partition mixing proportion pi_k and logB(alpha_k).
807 for(int i=0;i<numPartitions;i++){
808 if (m->control_pressed) { return 0; }
809 double sumAlphaK = 0.0000;
811 pi[i] = weights[i] / (double)numSamples;
813 for(int j=0;j<numOTUs;j++){
814 double alpha = exp(lambdaMatrix[i][j]);
817 logBAlpha[i] += lgamma(alpha);
819 logBAlpha[i] -= lgamma(sumAlphaK);
// Per-sample mixture log-likelihood.
822 for(int i=0;i<numSamples;i++){
823 if (m->control_pressed) { return 0; }
825 double probability = 0.0000;
826 double factor = 0.0000;
828 vector<double> logStore(numPartitions, 0.0000);
829 double offset = -numeric_limits<double>::max();
831 for(int j=0;j<numOTUs;j++){
832 sum += countMatrix[i][j];
833 factor += lgamma(countMatrix[i][j] + 1.0000);
835 factor -= lgamma(sum + 1.0);
837 for(int k=0;k<numPartitions;k++){
839 double sumAlphaKX = 0.0000;
840 double logBAlphaX = 0.0000;
842 for(int j=0;j<numOTUs;j++){
843 double alphaX = exp(lambdaMatrix[k][j]) + (double)countMatrix[i][j];
845 sumAlphaKX += alphaX;
846 logBAlphaX += lgamma(alphaX);
849 logBAlphaX -= lgamma(sumAlphaKX);
851 logStore[k] = logBAlphaX - logBAlpha[k] - factor;
852 if(logStore[k] > offset){
853 offset = logStore[k];
// Log-sum-exp: log(sum_k pi_k*exp(logStore[k] - offset)) + offset.
858 for(int k=0;k<numPartitions;k++){
859 probability += pi[k] * exp(-offset + logStore[k]);
861 doubleSum += log(probability) + offset;
// Prior constants over all numPartitions*numOTUs lambdas (L5/L6 are variable names).
865 double L5 = - numOTUs * numPartitions * lgamma(eta);
866 double L6 = eta * numPartitions * numOTUs * log(nu);
868 double alphaSum, lambdaSum;
869 alphaSum = lambdaSum = 0.0000;
871 for(int i=0;i<numPartitions;i++){
872 for(int j=0;j<numOTUs;j++){
873 if (m->control_pressed) { return 0; }
874 alphaSum += exp(lambdaMatrix[i][j]);
875 lambdaSum += lambdaMatrix[i][j];
881 return (-doubleSum - L5 - L6 - alphaSum - lambdaSum);
884 m->errorOut(e, "qFinderDMM", "getNegativeLogLikelihood");
890 /**************************************************************************************************/