]> git.donarmstrong.com Git - mothur.git/blob - communitytype.cpp
added Jensen-Shannon calc. working on get.communitytype command. fixed bug in get...
[mothur.git] / communitytype.cpp
1 //
2 //  communitytype.cpp
3 //  Mothur
4 //
5 //  Created by SarahsWork on 12/3/13.
6 //  Copyright (c) 2013 Schloss Lab. All rights reserved.
7 //
8
9 #include "communitytype.h"
10
11 /**************************************************************************************************/
12
//can we get these psi/psi1 calculations into their own math class?
//psi calculations swiped from gsl library...
15
// Chebyshev series coefficients (taken from the GSL library) that psi()
// feeds to cheb_eval() with order 22 whenever its argument is below 2.0,
// so all 23 entries (indices 0..22) are read.
static const double psi_cs[23] = {
    -0.038057080835217922,
     0.491415393029387130,
    -0.056815747821244730,
     0.008357821225914313,
    -0.001333232857994342,
     0.000220313287069308,
    -0.000037040238178456,
     0.000006283793654854,
    -0.000001071263908506,
     0.000000183128394654,
    -0.000000031353509361,
     0.000000005372808776,
    -0.000000000921168141,
     0.000000000157981265,
    -0.000000000027098646,
     0.000000000004648722,
    -0.000000000000797527,
     0.000000000000136827,
    -0.000000000000023475,
     0.000000000000004027,
    -0.000000000000000691,
     0.000000000000000118,
    -0.000000000000000020
};
41
// Chebyshev series coefficients (taken from the GSL library) that psi()
// feeds to cheb_eval() with order 15 when its argument is 2.0 or larger,
// so all 16 entries (indices 0..15) are read.
// Declared const, like psi_cs: this is a read-only lookup table and nothing
// in this file writes to it.
static const double apsi_cs[16] = {
    -.0204749044678185,
    -.0101801271534859,
    .0000559718725387,
    -.0000012917176570,
    .0000000572858606,
    -.0000000038213539,
    .0000000003397434,
    -.0000000000374838,
    .0000000000048990,
    -.0000000000007344,
    .0000000000001233,
    -.0000000000000228,
    .0000000000000045,
    -.0000000000000009,
    .0000000000000002,
    -.0000000000000000
};
60
61 /**************************************************************************************************/
/* coefficients for Maclaurin summation in hzeta()
 * B_{2j}/(2j)!
 *
 * psi1() reads entries 1..13 of this table (index j+1 for j = 0..12).
 * Declared const, like psi_cs: this is a read-only lookup table and nothing
 * in this file writes to it.
 */
static const double hzeta_c[15] = {
    1.00000000000000000000000000000,
    0.083333333333333333333333333333,
    -0.00138888888888888888888888888889,
    0.000033068783068783068783068783069,
    -8.2671957671957671957671957672e-07,
    2.0876756987868098979210090321e-08,
    -5.2841901386874931848476822022e-10,
    1.3382536530684678832826980975e-11,
    -3.3896802963225828668301953912e-13,
    8.5860620562778445641359054504e-15,
    -2.1748686985580618730415164239e-16,
    5.5090028283602295152026526089e-18,
    -1.3954464685812523340707686264e-19,
    3.5347070396294674716932299778e-21,
    -8.9535174270375468504026113181e-23
};
82
83 /**************************************************************************************************/
84 void CommunityTypeFinder::printSilData(ofstream& out, double chi, vector<double> sils){
85     try {
86         out << setprecision (6) << numPartitions << '\t'  << chi << '\t';
87         for (int i = 0; i < sils.size(); i++) {
88             out << sils[i] << '\t';
89         }
90         out << endl;
91         
92         return;
93     }
94     catch(exception& e){
95         m->errorOut(e, "CommunityTypeFinder", "printSilData");
96         exit(1);
97     }
98 }
99 /**************************************************************************************************/
100 void CommunityTypeFinder::printSilData(ostream& out, double chi, vector<double> sils){
101     try {
102         out << setprecision (6) << numPartitions << '\t'  << chi << '\t';
103         m->mothurOutJustToLog(toString(numPartitions) + '\t' + toString(chi) + '\t');
104         for (int i = 0; i < sils.size(); i++) {
105             out << sils[i] << '\t';
106             m->mothurOutJustToLog(toString(sils[i]) + '\t');
107         }
108         out << endl;
109         m->mothurOutJustToLog("\n");
110         
111         return;
112     }
113     catch(exception& e){
114         m->errorOut(e, "CommunityTypeFinder", "printSilData");
115         exit(1);
116     }
117 }
118 /**************************************************************************************************/
119
120 void CommunityTypeFinder::printZMatrix(string fileName, vector<string> sampleName){
121     try {
122         ofstream printMatrix;
123         m->openOutputFile(fileName, printMatrix); //(fileName.c_str());
124         printMatrix.setf(ios::fixed, ios::floatfield);
125         printMatrix.setf(ios::showpoint);
126         
127         for(int i=0;i<numPartitions;i++){   printMatrix << "\tPartition_" << i+1;   }   printMatrix << endl;
128         
129         for(int i=0;i<numSamples;i++){
130             printMatrix << sampleName[i];
131             for(int j=0;j<numPartitions;j++){
132                 printMatrix << setprecision(4) << '\t' << zMatrix[j][i];
133             }
134             printMatrix << endl;
135         }
136         printMatrix.close();
137     }
138         catch(exception& e) {
139                 m->errorOut(e, "CommunityTypeFinder", "printZMatrix");
140                 exit(1);
141         }
142 }
143
144 /**************************************************************************************************/
145
146 void CommunityTypeFinder::printRelAbund(string fileName, vector<string> otuNames){
147     try {
148         ofstream printRA;
149         m->openOutputFile(fileName, printRA); //(fileName.c_str());
150         printRA.setf(ios::fixed, ios::floatfield);
151         printRA.setf(ios::showpoint);
152         
153         vector<double> totals(numPartitions, 0.0000);
154         for(int i=0;i<numPartitions;i++){
155             for(int j=0;j<numOTUs;j++){
156                 totals[i] += exp(lambdaMatrix[i][j]);
157             }
158         }
159         
160         printRA << "Taxon";
161         for(int i=0;i<numPartitions;i++){
162             printRA << "\tPartition_" << i+1 << '_' << setprecision(4) << totals[i];
163             printRA << "\tPartition_" << i+1 <<"_LCI" << "\tPartition_" << i+1 << "_UCI";
164         }
165         printRA << endl;
166         
167         for(int i=0;i<numOTUs;i++){
168             
169             if (m->control_pressed) { break; }
170             
171             printRA << otuNames[i];
172             for(int j=0;j<numPartitions;j++){
173                 
174                 if(error[j][i] >= 0.0000){
175                     double std = sqrt(error[j][i]);
176                     printRA << '\t' << 100 * exp(lambdaMatrix[j][i]) / totals[j];
177                     printRA << '\t' << 100 * exp(lambdaMatrix[j][i] - 2.0 * std) / totals[j];
178                     printRA << '\t' << 100 * exp(lambdaMatrix[j][i] + 2.0 * std) / totals[j];
179                 }
180                 else{
181                     printRA << '\t' << 100 * exp(lambdaMatrix[j][i]) / totals[j];
182                     printRA << '\t' << "NA";
183                     printRA << '\t' << "NA";
184                 }
185             }
186             printRA << endl;
187         }
188         
189         printRA.close();
190     }
191         catch(exception& e) {
192                 m->errorOut(e, "CommunityTypeFinder", "printRelAbund");
193                 exit(1);
194         }
195 }
196
197 /**************************************************************************************************/
198
199 vector<vector<double> > CommunityTypeFinder::getHessian(){
200     try {
201         vector<double> alpha(numOTUs, 0.0000);
202         double alphaSum = 0.0000;
203         
204         vector<double> pi = zMatrix[currentPartition];
205         vector<double> psi_ajk(numOTUs, 0.0000);
206         vector<double> psi_cjk(numOTUs, 0.0000);
207         vector<double> psi1_ajk(numOTUs, 0.0000);
208         vector<double> psi1_cjk(numOTUs, 0.0000);
209         
210         for(int j=0;j<numOTUs;j++){
211             
212             if (m->control_pressed) {  break; }
213             
214             alpha[j] = exp(lambdaMatrix[currentPartition][j]);
215             alphaSum += alpha[j];
216             
217             for(int i=0;i<numSamples;i++){
218                 double X = (double) countMatrix[i][j];
219                 
220                 psi_ajk[j] += pi[i] * psi(alpha[j]);
221                 psi1_ajk[j] += pi[i] * psi1(alpha[j]);
222                 
223                 psi_cjk[j] += pi[i] * psi(alpha[j] + X);
224                 psi1_cjk[j] += pi[i] * psi1(alpha[j] + X);
225             }
226         }
227         
228         
229         double psi_Ck = 0.0000;
230         double psi1_Ck = 0.0000;
231         
232         double weight = 0.0000;
233         
234         for(int i=0;i<numSamples;i++){
235             if (m->control_pressed) {  break; }
236             weight += pi[i];
237             double sum = 0.0000;
238             for(int j=0;j<numOTUs;j++){     sum += alpha[j] + countMatrix[i][j];    }
239             
240             psi_Ck += pi[i] * psi(sum);
241             psi1_Ck += pi[i] * psi1(sum);
242         }
243         
244         double psi_Ak = weight * psi(alphaSum);
245         double psi1_Ak = weight * psi1(alphaSum);
246         
247         vector<vector<double> > hessian(numOTUs);
248         for(int i=0;i<numOTUs;i++){ hessian[i].assign(numOTUs, 0.0000); }
249         
250         for(int i=0;i<numOTUs;i++){
251             if (m->control_pressed) {  break; }
252             double term1 = -alpha[i] * (- psi_ajk[i] + psi_Ak + psi_cjk[i] - psi_Ck);
253             double term2 = -alpha[i] * alpha[i] * (-psi1_ajk[i] + psi1_Ak + psi1_cjk[i] - psi1_Ck);
254             double term3 = 0.1 * alpha[i];
255             
256             hessian[i][i] = term1 + term2 + term3;
257             
258             for(int j=0;j<i;j++){
259                 hessian[i][j] = - alpha[i] * alpha[j] * (psi1_Ak - psi1_Ck);
260                 hessian[j][i] = hessian[i][j];
261             }
262         }
263         
264         return hessian;
265     }
266     catch(exception& e){
267         m->errorOut(e, "CommunityTypeFinder", "getHessian");
268         exit(1);
269     }
270 }
271 /**************************************************************************************************/
272
273 double CommunityTypeFinder::psi1(double xx){
274     try {
275         
276         /* Euler-Maclaurin summation formula
277          * [Moshier, p. 400, with several typo corrections]
278          */
279         
280         double s = 2.0000;
281         const int jmax = 12;
282         const int kmax = 10;
283         int j, k;
284         const double pmax  = pow(kmax + xx, -s);
285         double scp = s;
286         double pcp = pmax / (kmax + xx);
287         double value = pmax*((kmax+xx)/(s-1.0) + 0.5);
288         
289         for(k=0; k<kmax; k++) {
290             if (m->control_pressed) {  return 0; }
291             value += pow(k + xx, -s);
292         }
293         
294         for(j=0; j<=jmax; j++) {
295             if (m->control_pressed) {  return 0; }
296             double delta = hzeta_c[j+1] * scp * pcp;
297             value += delta;
298             
299             if(fabs(delta/value) < 0.5*EPSILON) break;
300             
301             scp *= (s+2*j+1)*(s+2*j+2);
302             pcp /= (kmax + xx)*(kmax + xx);
303         }
304         
305         return value;
306     }
307     catch(exception& e){
308         m->errorOut(e, "CommunityTypeFinder", "psi1");
309         exit(1);
310     }
311 }
312
313 /**************************************************************************************************/
314
315 double CommunityTypeFinder::psi(double xx){
316     try {
317         double psiX = 0.0000;
318         
319         if(xx < 1.0000){
320             
321             double t1 = 1.0 / xx;
322             psiX = cheb_eval(psi_cs, 22, 2.0*xx-1.0);
323             psiX = -t1 + psiX;
324             
325         }
326         else if(xx < 2.0000){
327             
328             const double v = xx - 1.0;
329             psiX = cheb_eval(psi_cs, 22, 2.0*v-1.0);
330             
331         }
332         else{
333             const double t = 8.0/(xx*xx)-1.0;
334             psiX = cheb_eval(apsi_cs, 15, t);
335             psiX += log(xx) - 0.5/xx;
336         }
337         
338         return psiX;
339     }
340     catch(exception& e){
341         m->errorOut(e, "CommunityTypeFinder", "psi");
342         exit(1);
343     }
344 }
345 /**************************************************************************************************/
346
347 double CommunityTypeFinder::cheb_eval(const double seriesData[], int order, double xx){
348     try {
349         double d = 0.0000;
350         double dd = 0.0000;
351         
352         double x2 = xx * 2.0000;
353         
354         for(int j=order;j>=1;j--){
355             if (m->control_pressed) {  return 0; }
356             double temp = d;
357             d = x2 * d - dd + seriesData[j];
358             dd = temp;
359         }
360         
361         d = xx * d - dd + 0.5 * seriesData[0];
362         return d;
363     }
364     catch(exception& e){
365         m->errorOut(e, "CommunityTypeFinder", "cheb_eval");
366         exit(1);
367     }
368 }
369 /**************************************************************************************************/
370
371 int CommunityTypeFinder::findkMeans(){
372     try {
373         error.resize(numPartitions); for (int i = 0; i < numPartitions; i++) { error[i].resize(numOTUs, 0.0); }
374         vector<vector<double> > relativeAbundance(numSamples);
375         vector<vector<double> > alphaMatrix;
376         
377         alphaMatrix.resize(numPartitions);
378         lambdaMatrix.resize(numPartitions);
379         for(int i=0;i<numPartitions;i++){
380             alphaMatrix[i].assign(numOTUs, 0);
381             lambdaMatrix[i].assign(numOTUs, 0);
382         }
383         
384         //get relative abundance
385         for(int i=0;i<numSamples;i++){
386             if (m->control_pressed) {  return 0; }
387             int groupTotal = 0;
388             
389             relativeAbundance[i].assign(numOTUs, 0.0);
390             
391             for(int j=0;j<numOTUs;j++){
392                 groupTotal += countMatrix[i][j];
393             }
394             for(int j=0;j<numOTUs;j++){
395                 relativeAbundance[i][j] = countMatrix[i][j] / (double)groupTotal;
396             }
397         }
398         
399         //randomly assign samples into partitions
400         zMatrix.resize(numPartitions);
401         for(int i=0;i<numPartitions;i++){
402             zMatrix[i].assign(numSamples, 0);
403         }
404         
405         for(int i=0;i<numSamples;i++){
406             zMatrix[rand()%numPartitions][i] = 1;
407         }
408         
409         double maxChange = 1;
410         int maxIters = 1000;
411         int iteration = 0;
412         
413         weights.assign(numPartitions, 0);
414         
415         while(maxChange > 1e-6 && iteration < maxIters){
416             
417             if (m->control_pressed) {  return 0; }
418             //calcualte average relative abundance
419             maxChange = 0.0000;
420             for(int i=0;i<numPartitions;i++){
421                 
422                 double normChange = 0.0;
423                 
424                 weights[i] = 0;
425                 
426                 for(int j=0;j<numSamples;j++){
427                     weights[i] += (double)zMatrix[i][j];
428                 }
429                 
430                 vector<double> averageRelativeAbundance(numOTUs, 0);
431                 for(int j=0;j<numOTUs;j++){
432                     for(int k=0;k<numSamples;k++){
433                         averageRelativeAbundance[j] += zMatrix[i][k] * relativeAbundance[k][j];
434                     }
435                 }
436                 
437                 for(int j=0;j<numOTUs;j++){
438                     averageRelativeAbundance[j] /= weights[i];
439                     double difference = averageRelativeAbundance[j] - alphaMatrix[i][j];
440                     normChange += difference * difference;
441                     alphaMatrix[i][j] = averageRelativeAbundance[j];
442                 }
443                 
444                 normChange = sqrt(normChange);
445                 
446                 if(normChange > maxChange){ maxChange = normChange; }
447             }
448             
449             
450             //calcualte distance between each sample in partition and the average relative abundance
451             for(int i=0;i<numSamples;i++){
452                 if (m->control_pressed) {  return 0; }
453                 
454                 double normalizationFactor = 0;
455                 vector<double> totalDistToPartition(numPartitions, 0);
456                 
457                 for(int j=0;j<numPartitions;j++){
458                     for(int k=0;k<numOTUs;k++){
459                         double difference = alphaMatrix[j][k] - relativeAbundance[i][k];
460                         totalDistToPartition[j] += difference * difference;
461                     }
462                     totalDistToPartition[j] = sqrt(totalDistToPartition[j]);
463                     normalizationFactor += exp(-50.0 * totalDistToPartition[j]);
464                 }
465                 
466                 
467                 for(int j=0;j<numPartitions;j++){
468                     zMatrix[j][i] = exp(-50.0 * totalDistToPartition[j]) / normalizationFactor;
469                 }
470                 
471             }
472             
473             iteration++;
474             //        cout << "K means: " << iteration << '\t' << maxChange << endl;
475             
476         }
477         
478         //    cout << "Iter:-1";
479         for(int i=0;i<numPartitions;i++){
480             weights[i] = 0.0000;
481             
482             for(int j=0;j<numSamples;j++){
483                 weights[i] += zMatrix[i][j];
484             }
485             //        printf("\tw_%d=%.3f", i, weights[i]);
486         }
487         //    cout << endl;
488         
489         
490         for(int i=0;i<numOTUs;i++){
491             if (m->control_pressed) {  return 0; }
492             for(int j=0;j<numPartitions;j++){
493                 if(alphaMatrix[j][i] > 0){
494                     lambdaMatrix[j][i] = log(alphaMatrix[j][i]);
495                 }
496                 else{
497                     lambdaMatrix[j][i] = -10.0;
498                 }
499             }
500         }
501         
502         return 0;
503     }
504     catch(exception& e){
505         m->errorOut(e, "CommunityTypeFinder", "kMeans");
506         exit(1);
507     }
508 }
509
510
511 /**************************************************************************************************/
512