5 // Created by Patrick Schloss on 11/8/12.
6 // Copyright (c) 2012 University of Michigan. All rights reserved.
9 #include "qFinderDMM.h"
13 /**************************************************************************************************/
// qFinderDMM constructor: fits a Dirichlet-multinomial mixture to a sample-by-OTU
// count matrix, then computes model-selection statistics (laplace, bic, aic).
// NOTE(review): the statements that copy cm/p into members are on lines not
// visible here; presumably countMatrix = cm and numPartitions = p — confirm.
15 qFinderDMM::qFinderDMM(vector<vector<int> > cm, int p) : CommunityTypeFinder() {
17 //cout << "here" << endl;
// countMatrix is indexed [sample][otu] (see getNegativeLogEvidence).
// NOTE(review): countMatrix[0] is read without an emptiness check, so an empty
// count matrix is undefined behavior here.
20 numSamples = (int)countMatrix.size();
21 numOTUs = (int)countMatrix[0].size();
23 //cout << "before kmeans" <<endl;
25 //cout << "done kMeans" << endl;
30 //cout << "done optimizeLambda" << endl;
// Iterate until the negative log-likelihood changes by at most 1e-6 between
// iterations, or until 100 iterations have run.
32 double change = 1.0000;
37 while(change > 1.0e-6 && iter < 100){
44 //printf("Iter:%d\t",iter);
// weights[i] accumulates the soft membership mass of partition i over all samples
// (zMatrix columns are normalized per sample in calculatePiK).
46 for(int i=0;i<numPartitions;i++){
48 for(int j=0;j<numSamples;j++){
49 weights[i] += zMatrix[i][j];
51 // printf("w_%d=%.3f\t",i,weights[i]);
55 double nLL = getNegativeLogLikelihood();
// NOTE(review): this needs abs() to resolve to a floating-point overload
// (std::abs from <cmath>); the C abs(int) would truncate the difference.
57 change = abs(nLL - currNLL);
61 // printf("NLL=%.5f\tDelta=%.4e\n",currNLL, change);
// Per-partition uncertainty: error[k][i] is the i-th diagonal entry of the inverse
// of the Hessian returned by getHessian() while currentPartition == k.
66 error.resize(numPartitions);
68 logDeterminant = 0.0000;
// currentPartition is a member; getHessian() presumably reads it to select the
// partition — TODO confirm.
72 for(currentPartition=0;currentPartition<numPartitions;currentPartition++){
74 error[currentPartition].assign(numOTUs, 0.0000);
// Every partition after the first adds 2*log(N) - log(weights[k]) to the
// log-determinant (presumably the mixing-weight contribution — confirm).
76 if(currentPartition > 0){
77 logDeterminant += (2.0 * log(numSamples) - log(weights[currentPartition]));
79 vector<vector<double> > hessian = getHessian();
80 vector<vector<double> > invHessian = l.getInverse(hessian);
// Only the Hessian diagonal contributes to logDeterminant here.
82 for(int i=0;i<numOTUs;i++){
83 logDeterminant += log(abs(hessian[i][i]));
84 error[currentPartition][i] = invHessian[i][i];
// numPartitions * numOTUs lambda values plus numPartitions - 1 free mixing weights.
88 int numParameters = numPartitions * numOTUs + numPartitions - 1;
// Laplace approximation; 3.14159 approximates pi.
89 laplace = currNLL + 0.5 * logDeterminant - 0.5 * numParameters * log(2.0 * 3.14159);
// bic and aic are on the NLL scale, i.e. half of the conventional
// -2*logL + k*ln(N) and -2*logL + 2k respectively.
90 bic = currNLL + 0.5 * log(numSamples) * numParameters;
91 aic = currNLL + numParameters;
94 m->errorOut(e, "qFinderDMM", "qFinderDMM");
98 /**************************************************************************************************/
// Writes one tab-separated line of fit statistics to out:
// numPartitions, NLL, log-determinant, BIC, AIC, Laplace approximation.
// setprecision(2) without std::fixed gives two significant digits, and the
// precision setting persists on out after the call.
// NOTE(review): the errorOut label below says "CommunityTypeFinder" although
// other methods in this file report "qFinderDMM".
99 void qFinderDMM::printFitData(ofstream& out){
101 out << setprecision (2) << numPartitions << '\t' << getNLL() << '\t' << getLogDet() << '\t' << getBIC() << '\t' << getAIC() << '\t' << laplace << endl;
105 m->errorOut(e, "CommunityTypeFinder", "printFitData");
109 /**************************************************************************************************/
// Writes one tab-separated line of fit statistics to out:
// numPartitions, NLL, log-determinant, BIC, AIC, Laplace approximation.
// The line is suffixed with "***" when this fit's laplace is below minLaplace,
// which flags it as better than the best value seen so far.
// The same numbers (without the "***" marker, formatted by toString) are also
// written to the mothur log file.
110 void qFinderDMM::printFitData(ostream& out, double minLaplace){
112 if(laplace < minLaplace){
113 out << setprecision (2) << numPartitions << '\t' << getNLL() << '\t' << getLogDet() << '\t' << getBIC() << '\t' << getAIC() << '\t' << laplace << "***" << endl;
115 out << setprecision (2) << numPartitions << '\t' << getNLL() << '\t' << getLogDet() << '\t' << getBIC() << '\t' << getAIC() << '\t' << laplace << endl;
118 m->mothurOutJustToLog(toString(numPartitions) + '\t' + toString(getNLL()) + '\t' + toString(getLogDet()) + '\t');
119 m->mothurOutJustToLog(toString(getBIC()) + '\t' + toString(getAIC()) + '\t' + toString(laplace));
// NOTE(review): the errorOut label says "CommunityTypeFinder" rather than "qFinderDMM".
124 m->errorOut(e, "CommunityTypeFinder", "printFitData");
129 /**************************************************************************************************/
131 // These helpers for the bfgs2 solver were adapted from the GNU Scientific Library (GSL) source code.
133 /* Find the minimizer over [zl,zh] of the quadratic interpolating
134 * (0,f0) and (1,f1) with derivative fp0 at z=0. The interpolating
135 * polynomial is q(z) = f0 + fp0 * z + (f1-f0-fp0) * z^2
// Candidates are the endpoints zl and zh and, when the curvature c is positive,
// the stationary point z = -fp0/c if it lies strictly inside (zl, zh); the
// candidate with the smallest q(z) is kept in zmin.
139 interp_quad (double f0, double fp0, double f1, double zl, double zh)
141 double fl = f0 + zl*(fp0 + zl*(f1 - f0 -fp0));
142 double fh = f0 + zh*(fp0 + zh*(f1 - f0 -fp0));
143 double c = 2 * (f1 - f0 - fp0); /* curvature */
145 double zmin = zl, fmin = fl;
147 if (fh < fmin) { zmin = zh; fmin = fh; }
149 if (c > 0) /* positive curvature required for a minimum */
151 double z = -fp0 / c; /* location of minimum */
152 if (z > zl && z < zh) {
153 double f = f0 + z*(fp0 + z*(f1 - f0 -fp0));
154 if (f < fmin) { zmin = z; fmin = f; };
161 /**************************************************************************************************/
163 /* Find the minimizer over [zl,zh] of the cubic interpolating
164 * (0,f0) and (1,f1) with derivatives fp0 at z=0 and fp1 at z=1.
166 * The interpolating polynomial is:
168 * c(z) = f0 + fp0 * z + eta * z^2 + xi * z^3
170 * where eta=3*(f1-f0)-2*fp0-fp1, xi=fp0+fp1-2*(f1-f0).
// (That minimization is done by interp_cubic; cubic() only evaluates the polynomial.)
// cubic: evaluates c0 + c1*z + c2*z^2 + c3*z^3 using Horner's scheme.
173 double cubic (double c0, double c1, double c2, double c3, double z){
174 return c0 + z * (c1 + z * (c2 + z * c3));
// Evaluates the cubic c0 + c1*z + c2*z^2 + c3*z^3 at z and, if the value improves
// on the current best, records z as the new *zmin.
// NOTE(review): the comparison and the *fmin update are on lines not visible here;
// presumably `if (y < *fmin)` and `*fmin = y` — confirm.
179 void check_extremum (double c0, double c1, double c2, double c3, double z,
180 double *zmin, double *fmin){
181 /* could make an early return by testing curvature >0 for minimum */
183 double y = cubic (c0, c1, c2, c3, z);
187 *zmin = z; /* accepted new point*/
192 /**************************************************************************************************/
// Real roots of a*x^2 + b*x + c = 0 (copy of GSL's gsl_poly_solve_quadratic).
// The number of real roots is the return value and the roots are written to
// *x0 / *x1. GSL orders them so that x0 <= x1; presumably the same here — the
// ordering and counting code is on lines not visible here.
194 int gsl_poly_solve_quadratic (double a, double b, double c,
195 double *x0, double *x1)
198 double disc = b * b - 4 * a * c;
// a == 0 degenerates to the linear equation b*x + c = 0.
200 if (a == 0) /* Handle linear case */
// Presumably the b == 0 branch: the roots are -r and +r.
217 double r = fabs (0.5 * sqrt (disc) / a);
// Numerically stable form: computing temp with the sign of b avoids the
// cancellation of -b + sqrt(disc); the second root then comes from r1*r2 = c/a.
223 double sgnb = (b > 0 ? 1 : -1);
224 double temp = -0.5 * (b + sgnb * sqrt (disc));
225 double r1 = temp / a ;
226 double r2 = c / temp ;
254 /**************************************************************************************************/
// Minimizer over [zl, zh] of the cubic Hermite interpolant through (0,f0) and (1,f1)
// with slopes fp0 and fp1. Candidates are the endpoints zl and zh and any critical
// point strictly inside (zl, zh). The return of zmin is on a line not visible here.
256 double interp_cubic (double f0, double fp0, double f1, double fp1, double zl, double zh){
257 double eta = 3 * (f1 - f0) - 2 * fp0 - fp1;
258 double xi = fp0 + fp1 - 2 * (f1 - f0);
259 double c0 = f0, c1 = fp0, c2 = eta, c3 = xi;
// Seed the running best with the left endpoint, then test the right endpoint.
263 zmin = zl; fmin = cubic(c0, c1, c2, c3, zl);
264 check_extremum (c0, c1, c2, c3, zh, &zmin, &fmin);
// Critical points are the roots of c'(z) = 3*c3*z^2 + 2*c2*z + c1.
267 int n = gsl_poly_solve_quadratic (3 * c3, 2 * c2, c1, &z0, &z1);
269 if (n == 2) /* found 2 roots */
271 if (z0 > zl && z0 < zh)
272 check_extremum (c0, c1, c2, c3, z0, &zmin, &fmin);
273 if (z1 > zl && z1 < zh)
274 check_extremum (c0, c1, c2, c3, z1, &zmin, &fmin);
276 else if (n == 1) /* found 1 root */
278 if (z0 > zl && z0 < zh)
279 check_extremum (c0, c1, c2, c3, z0, &zmin, &fmin);
286 /**************************************************************************************************/
// Chooses a trial step inside [xmin, xmax] by interpolating the 1-D function known
// at a (value fa, slope fpa) and at b (value fb, slope fpb). The interval is mapped
// to z in [0,1] with alpha = a + z*(b - a).
// Either cubic or quadratic interpolation is used. NOTE(review): the selection
// condition is on lines not visible here; in GSL the cubic is used only when fpb is
// finite (callers in this file set dfb to NaN to mean "slope unknown").
// NOTE(review): if b == a, the divisions below produce inf/NaN.
288 double interpolate (double a, double fa, double fpa,
289 double b, double fb, double fpb, double xmin, double xmax){
290 /* Map [a,b] to [0,1] */
291 double z, alpha, zmin, zmax;
293 zmin = (xmin - a) / (b - a);
294 zmax = (xmax - a) / (b - a);
// Slopes are multiplied by (b - a) because d/dz = (b - a) * d/dalpha.
304 z = interp_cubic (fa, fpa * (b - a), fb, fpb * (b - a), zmin, zmax);
307 z = interp_quad(fa, fpa * (b - a), fb, zmin, zmax);
// Map the chosen z back to the original step scale.
311 alpha = a + z * (b - a);
316 /**************************************************************************************************/
// Fletcher-style line search along direction p from point x, adapted from GSL's
// bfgs2 line minimizer. It runs a bracketing phase, then a sectioning phase.
//   x, p     : start point and search direction (numOTUs entries).
//   f0, df0  : objective at x and directional derivative along p at x.
//   alpha1   : initial trial step.
//   alphaNew : [out] accepted step (assigned on lines not visible here).
//   fAlpha   : [out] objective at the last evaluated point x + alpha*p.
//   xalpha   : [out] the last evaluated point x + alpha*p (resized to numOTUs).
//   gradient : [out] gradient at xalpha, whenever it was evaluated.
// The objective is negativeLogEvidenceLambdaPi for the member currentPartition.
// NOTE(review): rho, sigma, tau1-tau3, maxIters, iter and the return values are on
// lines not visible here.
318 int qFinderDMM::lineMinimizeFletcher(vector<double>& x, vector<double>& p, double f0, double df0, double alpha1, double& alphaNew, double& fAlpha, vector<double>& xalpha, vector<double>& gradient ){
327 double alpha = alpha1;
328 double alpha_prev = 0.0000;
330 xalpha.resize(numOTUs, 0.0000);
332 double falpha_prev = f0;
333 double dfalpha_prev = df0;
// Bracket [a, b] with function values and slopes; a NaN slope (0.0/0.0) marks
// "derivative unknown" at that end.
335 double a = 0.0000; double b = alpha;
336 double fa = f0; double fb = 0.0000;
337 double dfa = df0; double dfb = 0.0/0.0;
// Phase 1: bracketing.
341 while(iter++ < maxIters){
342 if (m->control_pressed) { break; }
344 for(int i=0;i<numOTUs;i++){
345 xalpha[i] = x[i] + alpha * p[i];
348 fAlpha = negativeLogEvidenceLambdaPi(xalpha);
// Sufficient-decrease (Armijo) test fails, or no improvement over the previous
// trial: a minimum lies between alpha_prev and alpha.
350 if(fAlpha > f0 + alpha * rho * df0 || fAlpha >= falpha_prev){
351 a = alpha_prev; b = alpha;
352 fa = falpha_prev; fb = fAlpha;
353 dfa = dfalpha_prev; dfb = 0.0/0.0;
357 negativeLogDerivEvidenceLambdaPi(xalpha, gradient);
358 double dfalpha = 0.0000;
359 for(int i=0;i<numOTUs;i++){ dfalpha += gradient[i] * p[i]; }
// Strong Wolfe curvature condition.
361 if(abs(dfalpha) <= -sigma * df0){
// Presumably reached when the slope has turned non-negative (condition not
// visible): bracket with the ends reversed.
367 a = alpha; b = alpha_prev;
368 fa = fAlpha; fb = falpha_prev;
369 dfa = dfalpha; dfb = dfalpha_prev;
// Extrapolate: the next trial step is chosen in [alpha + delta, alpha + tau1*delta].
373 double delta = alpha - alpha_prev;
375 double lower = alpha + delta;
376 double upper = alpha + tau1 * delta;
378 double alphaNext = interpolate(alpha_prev, falpha_prev, dfalpha_prev, alpha, fAlpha, dfalpha, lower, upper);
381 falpha_prev = fAlpha;
382 dfalpha_prev = dfalpha;
// Phase 2: sectioning — shrink [a, b] until an acceptable step is found.
387 while(iter++ < maxIters){
388 if (m->control_pressed) { break; }
390 double delta = b - a;
// Keep the trial step away from the bracket ends by fractions tau2 and tau3.
392 double lower = a + tau2 * delta;
393 double upper = b - tau3 * delta;
395 alpha = interpolate(a, fa, dfa, b, fb, dfb, lower, upper);
397 for(int i=0;i<numOTUs;i++){
398 xalpha[i] = x[i] + alpha * p[i];
401 fAlpha = negativeLogEvidenceLambdaPi(xalpha);
// Roundoff prevents further progress (GSL reports GSL_ENOPROG in this case).
403 if((a - alpha) * dfa <= EPSILON){
// Sufficient decrease fails or no improvement over fa: alpha becomes the new far end.
407 if(fAlpha > f0 + rho * alpha * df0 || fAlpha >= fa){
413 double dfalpha = 0.0000;
415 negativeLogDerivEvidenceLambdaPi(xalpha, gradient);
417 for(int i=0;i<numOTUs;i++){ dfalpha += gradient[i] * p[i]; }
419 if(abs(dfalpha) <= -sigma * df0){
// If the slope points away from a, the old a becomes the far end b.
424 if(((b-a >= 0 && dfalpha >= 0) || ((b-a) <= 0.000 && dfalpha <= 0))){
425 b = a; fb = fa; dfb = dfa;
426 a = alpha; fa = fAlpha; dfa = dfalpha;
440 catch(exception& e) {
441 m->errorOut(e, "qFinderDMM", "lineMinimizeFletcher");
446 /**************************************************************************************************/
// Minimizes negativeLogEvidenceLambdaPi (for the member currentPartition) over x
// with a BFGS variant adapted from GSL's vector_bfgs2; x is updated in place.
// The loop runs while g0norm > 0.001 and fewer than maxIter iterations have run
// (maxIter is declared on a line not visible here).
// Returns 0 if m->control_pressed is set; other return values are not visible here.
448 int qFinderDMM::bfgs2_Solver(vector<double>& x){
450 // cout << "bfgs2_Solver" << endl;
// NOTE(review): step is not referenced on any visible line.
452 double step = 1.0e-6;
453 double delta_f = 0.0000;//f-f0;
455 vector<double> gradient;
456 double f = negativeLogEvidenceLambdaPi(x);
458 // cout << "after negLE" << endl;
460 negativeLogDerivEvidenceLambdaPi(x, gradient);
462 // cout << "after negLDE" << endl;
464 vector<double> x0 = x;
465 vector<double> g0 = gradient;
468 for(int i=0;i<numOTUs;i++){
469 g0norm += g0[i] * g0[i];
471 g0norm = sqrt(g0norm);
// Initial direction is built from the gradient. df0 = -g0norm matches GSL's
// p = -g/|g|; the negation/normalization is presumably on lines not visible here.
473 vector<double> p = gradient;
475 for(int i=0;i<numOTUs;i++){
477 pNorm += p[i] * p[i];
480 double df0 = -g0norm;
484 // cout << "before while" << endl;
486 while(g0norm > 0.001 && bfgsIter++ < maxIter){
487 if (m->control_pressed) { return 0; }
490 vector<double> dx(numOTUs, 0.0000);
492 double alphaOld, alphaNew;
// Degenerate direction or zero gradient: no step can be taken.
494 if(pNorm == 0 || g0norm == 0 || df0 == 0){
495 dx.assign(numOTUs, 0.0000);
// Initial trial step (GSL heuristic): min(1, 2*max(-delta_f, 10*eps*|f0|) / (-df0)).
499 double delta = max(-delta_f, 10 * EPSILON * abs(f0));
500 alphaOld = min(1.0, 2.0 * delta / (-df0));
// Line search from x0 along p; writes the new value f, point x and gradient.
506 int success = lineMinimizeFletcher(x0, p, f0, df0, alphaOld, alphaNew, f, x, gradient);
// Step and gradient differences for the BFGS update.
515 vector<double> dx0(numOTUs);
516 vector<double> dg0(numOTUs);
518 for(int i=0;i<numOTUs;i++){
519 dx0[i] = x[i] - x0[i];
520 dg0[i] = gradient[i] - g0[i];
528 for(int i=0;i<numOTUs;i++){
529 dxg += dx0[i] * gradient[i];
530 dgg += dg0[i] * gradient[i];
531 dxdg += dx0[i] * dg0[i];
532 dgnorm += dg0[i] * dg0[i];
534 dgnorm = sqrt(dgnorm);
// BFGS coefficients (B is computed on a line not visible here).
540 A = -(1.0 + dgnorm*dgnorm /dxdg) * B + dgg / dxdg;
// New direction p = g - A*dx0 - B*dg0, as in GSL bfgs2.
547 for(int i=0;i<numOTUs;i++){ p[i] = gradient[i] - A * dx0[i] - B * dg0[i]; }
557 for(int i=0;i<numOTUs;i++){
558 pg += p[i] * gradient[i];
559 pNorm += p[i] * p[i];
560 g0norm += g0[i] * g0[i];
563 g0norm = sqrt(g0norm);
// Flip p so it is a descent direction, and scale it by 1/pNorm.
565 double dir = (pg >= 0.0) ? -1.0 : +1.0;
567 for(int i=0;i<numOTUs;i++){ p[i] *= dir / pNorm; }
571 for(int i=0;i<numOTUs;i++){
572 pNorm += p[i] * p[i];
579 // cout << "bfgsIter:\t" << bfgsIter << endl;
584 m->errorOut(e, "qFinderDMM", "bfgs2_Solver");
590 /**************************************************************************************************/
// Objective minimized by bfgs2_Solver for partition k = currentPartition, as a
// function of x = log(alpha) (one entry per OTU):
//   sum_j z_kj * [lgamma(sum_i (alpha_i + X_ji)) - sum_i lgamma(alpha_i + X_ji)]
//   + weight * [sum_i lgamma(alpha_i) - lgamma(sumAlpha)]
//   + nu * sumAlpha - eta * sumLambda
// where weight = sum_j z_kj and eta = 0.1.
// NOTE(review): the sumAlpha / sumLambda accumulations are on lines not visible here;
// presumably sumAlpha = sum alpha_i and sumLambda = sum x_i. nu is declared elsewhere.
// Returns 0 if m->control_pressed is set.
592 double qFinderDMM::negativeLogEvidenceLambdaPi(vector<double>& x){
594 vector<double> sumAlphaX(numSamples, 0.0000);
596 double logEAlpha = 0.0000;
597 double sumLambda = 0.0000;
598 double sumAlpha = 0.0000;
599 double logE = 0.0000;
601 double eta = 0.10000;
// Total membership mass of this partition over all samples.
603 double weight = 0.00000;
604 for(int i=0;i<numSamples;i++){
605 weight += zMatrix[currentPartition][i];
608 for(int i=0;i<numOTUs;i++){
609 if (m->control_pressed) { return 0; }
610 double lambda = x[i];
611 double alpha = exp(x[i]);
612 logEAlpha += lgamma(alpha);
616 for(int j=0;j<numSamples;j++){
617 double X = countMatrix[j][i];
618 double alphaX = alpha + X;
619 sumAlphaX[j] += alphaX;
621 logE -= zMatrix[currentPartition][j] * lgamma(alphaX);
625 logEAlpha -= lgamma(sumAlpha);
627 for(int i=0;i<numSamples;i++){
628 logE += zMatrix[currentPartition][i] * lgamma(sumAlphaX[i]);
631 return logE + weight * logEAlpha + nu * sumAlpha - eta * sumLambda;
634 m->errorOut(e, "qFinderDMM", "negativeLogEvidenceLambdaPi");
639 /**************************************************************************************************/
// Gradient of negativeLogEvidenceLambdaPi with respect to x = log(alpha), written
// into df (resized to numOTUs; every entry is overwritten). Because
// alpha_i = exp(x_i), each term carries the chain-rule factor alpha_i:
//   df_i = alpha_i * (nu + weight*psi(alpha_i) - sum_j z_kj*psi(alpha_i + X_ji)
//                     - weight*psi(sum alpha) + sum_j z_kj*psi(sum_i (alpha_i + X_ji))) - eta
// NOTE(review): psi is presumably the digamma function (defined elsewhere). eta and
// the accumulation of store (presumably store += alpha[i]) are on lines not visible here.
// Returns early, leaving df untouched, if m->control_pressed is set.
641 void qFinderDMM::negativeLogDerivEvidenceLambdaPi(vector<double>& x, vector<double>& df){
643 // cout << "\tstart negativeLogDerivEvidenceLambdaPi" << endl;
645 vector<double> storeVector(numSamples, 0.0000);
646 vector<double> derivative(numOTUs, 0.0000);
647 vector<double> alpha(numOTUs, 0.0000);
649 double store = 0.0000;
// Total membership mass of this partition over all samples.
653 double weight = 0.0000;
654 for(int i=0;i<numSamples;i++){
655 weight += zMatrix[currentPartition][i];
659 for(int i=0;i<numOTUs;i++){
660 if (m->control_pressed) { return; }
661 // cout << "start i loop" << endl;
663 // cout << i << '\t' << alpha[i] << '\t' << x[i] << '\t' << exp(x[i]) << '\t' << store << endl;
665 alpha[i] = exp(x[i]);
668 // cout << "before derivative" << endl;
670 derivative[i] = weight * psi(alpha[i]);
672 // cout << "after derivative" << endl;
674 // cout << i << '\t' << alpha[i] << '\t' << psi(alpha[i]) << '\t' << derivative[i] << endl;
676 for(int j=0;j<numSamples;j++){
677 double X = countMatrix[j][i];
678 double alphaX = X + alpha[i];
680 derivative[i] -= zMatrix[currentPartition][j] * psi(alphaX);
681 storeVector[j] += alphaX;
683 // cout << "end i loop" << endl;
686 double sumStore = 0.0000;
687 for(int i=0;i<numSamples;i++){
688 sumStore += zMatrix[currentPartition][i] * psi(storeVector[i]);
691 store = weight * psi(store);
693 df.resize(numOTUs, 0.0000);
695 for(int i=0;i<numOTUs;i++){
696 df[i] = alpha[i] * (nu + derivative[i] - store + sumStore) - eta;
697 // cout << i << '\t' << df[i] << endl;
699 // cout << df.size() << endl;
700 // cout << "\tend negativeLogDerivEvidenceLambdaPi" << endl;
703 m->errorOut(e, "qFinderDMM", "negativeLogDerivEvidenceLambdaPi");
708 /**************************************************************************************************/
// Negative log Dirichlet-multinomial evidence of one sample (row `group` of
// countMatrix; despite the name, group is a sample index) under alpha = exp(lambda):
//   [sum lgamma(alpha) - lgamma(sum alpha)] - [sum lgamma(alpha+X) - lgamma(sum(alpha+X))]
// i.e. log B(alpha) - log B(alpha + X). The multinomial coefficient is omitted.
// NOTE(review): the sumAlpha / sumAlphaX accumulations are on lines not visible here.
// Returns 0 if m->control_pressed is set.
710 double qFinderDMM::getNegativeLogEvidence(vector<double>& lambda, int group){
712 double sumAlpha = 0.0000;
713 double sumAlphaX = 0.0000;
714 double sumLnGamAlpha = 0.0000;
715 double logEvidence = 0.0000;
717 for(int i=0;i<numOTUs;i++){
718 if (m->control_pressed) { return 0; }
719 double alpha = exp(lambda[i]);
720 double X = countMatrix[group][i];
721 double alphaX = alpha + X;
723 sumLnGamAlpha += lgamma(alpha);
727 logEvidence -= lgamma(alphaX);
730 sumLnGamAlpha -= lgamma(sumAlpha);
731 logEvidence += lgamma(sumAlphaX);
733 return logEvidence + sumLnGamAlpha;
736 m->errorOut(e, "qFinderDMM", "getNegativeLogEvidence");
741 /**************************************************************************************************/
// Runs bfgs2_Solver on each partition's row of lambdaMatrix, updating it in place.
// currentPartition is a member that the objective functions read (they index
// zMatrix[currentPartition]), so this loop sets it. After a normal finish it is
// left equal to numPartitions.
// Returns early if m->control_pressed is set.
743 void qFinderDMM::optimizeLambda(){
745 for(currentPartition=0;currentPartition<numPartitions;currentPartition++){
746 if (m->control_pressed) { return; }
747 bfgs2_Solver(lambdaMatrix[currentPartition]);
751 m->errorOut(e, "qFinderDMM", "optimizeLambda");
755 /**************************************************************************************************/
// Recomputes zMatrix[j][i], the membership probability of sample i in partition j:
//   weights[j] * exp(-negLogEvidence_j(i)), normalized over partitions.
// Subtracting the per-sample minimum before exp() avoids underflow. That shift,
// the multinomial coefficient omitted by getNegativeLogEvidence, and the fact that
// weights are unnormalized all cancel in the normalization.
// NOTE(review): `sum` is declared/reset on a line not visible here.
// Returns early if m->control_pressed is set.
757 void qFinderDMM::calculatePiK(){
759 vector<double> store(numPartitions);
761 for(int i=0;i<numSamples;i++){
762 if (m->control_pressed) { return; }
764 double minNegLogEvidence =numeric_limits<double>::max();
766 for(int j=0;j<numPartitions;j++){
767 double negLogEvidenceJ = getNegativeLogEvidence(lambdaMatrix[j], i);
769 if(negLogEvidenceJ < minNegLogEvidence){
770 minNegLogEvidence = negLogEvidenceJ;
772 store[j] = negLogEvidenceJ;
775 for(int j=0;j<numPartitions;j++){
776 if (m->control_pressed) { return; }
777 zMatrix[j][i] = weights[j] * exp(-(store[j] - minNegLogEvidence));
778 sum += zMatrix[j][i];
781 for(int j=0;j<numPartitions;j++){
782 zMatrix[j][i] /= sum;
788 m->errorOut(e, "qFinderDMM", "calculatePiK");
794 /**************************************************************************************************/
// Negative log-likelihood of the current mixture fit, plus prior terms.
// Per sample i it adds log sum_k pi_k * P(x_i | k), where pi_k = weights[k] / numSamples,
// alpha_k = exp(lambdaMatrix[k]), and P is the Dirichlet-multinomial probability.
// Returns 0 if m->control_pressed is set.
796 double qFinderDMM::getNegativeLogLikelihood(){
798 double eta = 0.10000;
801 vector<double> pi(numPartitions, 0.0000);
802 vector<double> logBAlpha(numPartitions, 0.0000);
804 double doubleSum = 0.0000;
// logBAlpha[k] = sum_j lgamma(alpha_kj) - lgamma(sum_j alpha_kj), the log multivariate
// Beta function (the sumAlphaK accumulation is on a line not visible here).
806 for(int i=0;i<numPartitions;i++){
807 if (m->control_pressed) { return 0; }
808 double sumAlphaK = 0.0000;
810 pi[i] = weights[i] / (double)numSamples;
812 for(int j=0;j<numOTUs;j++){
813 double alpha = exp(lambdaMatrix[i][j]);
816 logBAlpha[i] += lgamma(alpha);
818 logBAlpha[i] -= lgamma(sumAlphaK);
821 for(int i=0;i<numSamples;i++){
822 if (m->control_pressed) { return 0; }
824 double probability = 0.0000;
825 double factor = 0.0000;
827 vector<double> logStore(numPartitions, 0.0000);
828 double offset = -numeric_limits<double>::max();
// factor = sum_j lgamma(x_ij + 1) - lgamma(n_i + 1): the negative log multinomial
// coefficient of sample i (`sum`, the sample total n_i, is declared on a hidden line).
830 for(int j=0;j<numOTUs;j++){
831 sum += countMatrix[i][j];
832 factor += lgamma(countMatrix[i][j] + 1.0000);
834 factor -= lgamma(sum + 1.0);
836 for(int k=0;k<numPartitions;k++){
838 double sumAlphaKX = 0.0000;
839 double logBAlphaX = 0.0000;
841 for(int j=0;j<numOTUs;j++){
842 double alphaX = exp(lambdaMatrix[k][j]) + (double)countMatrix[i][j];
844 sumAlphaKX += alphaX;
845 logBAlphaX += lgamma(alphaX);
848 logBAlphaX -= lgamma(sumAlphaKX);
// log P(x_i | k) = log B(alpha_k + x_i) - log B(alpha_k) + log(multinomial coefficient).
850 logStore[k] = logBAlphaX - logBAlpha[k] - factor;
851 if(logStore[k] > offset){
852 offset = logStore[k];
// Log-sum-exp: subtract the largest term before exponentiating to avoid underflow.
857 for(int k=0;k<numPartitions;k++){
858 probability += pi[k] * exp(-offset + logStore[k]);
860 doubleSum += log(probability) + offset;
// Prior terms that depend only on eta, nu and the model size.
864 double L5 = - numOTUs * numPartitions * lgamma(eta);
865 double L6 = eta * numPartitions * numOTUs * log(nu);
867 double alphaSum, lambdaSum;
868 alphaSum = lambdaSum = 0.0000;
// NOTE(review): alphaSum and lambdaSum are plain sums on the visible lines; any
// rescaling (e.g. by -nu and eta - 1) before the return would be on hidden lines — confirm.
870 for(int i=0;i<numPartitions;i++){
871 for(int j=0;j<numOTUs;j++){
872 if (m->control_pressed) { return 0; }
873 alphaSum += exp(lambdaMatrix[i][j]);
874 lambdaSum += lambdaMatrix[i][j];
880 return (-doubleSum - L5 - L6 - alphaSum - lambdaSum);
883 m->errorOut(e, "qFinderDMM", "getNegativeLogLikelihood");
889 /**************************************************************************************************/