]> git.donarmstrong.com Git - mothur.git/blob - bayesian.cpp
Fixed Craig Nelson's weighted bug and parallelized parsimony
[mothur.git] / bayesian.cpp
1 /*
2  *  bayesian.cpp
3  *  Mothur
4  *
5  *  Created by westcott on 11/3/09.
6  *  Copyright 2009 Schloss Lab. All rights reserved.
7  *
8  */
9
10 #include "bayesian.h"
11 #include "kmer.hpp"
12 #include "phylosummary.h"
13
14 /**************************************************************************************************/
15 Bayesian::Bayesian(string tfile, string tempFile, string method, int ksize, int cutoff, int i) : 
16 Classify(), kmerSize(ksize), confidenceThreshold(cutoff), iters(i)  {
17         try {
18                 
19                 /************calculate the probablity that each word will be in a specific taxonomy*************/
20                 string tfileroot = tfile.substr(0,tfile.find_last_of(".")+1);
21                 string tempfileroot = m->getRootName(m->getSimpleName(tempFile));
22                 string phyloTreeName = tfileroot + "tree.train";
23                 string phyloTreeSumName = tfileroot + "tree.sum";
24                 string probFileName = tfileroot + tempfileroot + char('0'+ kmerSize) + "mer.prob";
25                 string probFileName2 = tfileroot + tempfileroot + char('0'+ kmerSize) + "mer.numNonZero";
26                 
27                 ofstream out;
28                 ofstream out2;
29                 
30                 ifstream phyloTreeTest(phyloTreeName.c_str());
31                 ifstream probFileTest2(probFileName2.c_str());
32                 ifstream probFileTest(probFileName.c_str());
33                 ifstream probFileTest3(phyloTreeSumName.c_str());
34                 
35                 int start = time(NULL);
36                 
37                 //if they are there make sure they were created after this release date
38                 bool FilesGood = false;
39                 if(probFileTest && probFileTest2 && phyloTreeTest && probFileTest3){
40                         FilesGood = checkReleaseDate(probFileTest, probFileTest2, phyloTreeTest, probFileTest3);
41                 }
42                 
43                 if(probFileTest && probFileTest2 && phyloTreeTest && probFileTest3 && FilesGood){       
44                         m->mothurOut("Reading template taxonomy...     "); cout.flush();
45                         
46                         phyloTree = new PhyloTree(phyloTreeTest, phyloTreeName);
47                         
48                         m->mothurOut("DONE."); m->mothurOutEndLine();
49                         
50                         genusNodes = phyloTree->getGenusNodes(); 
51                         genusTotals = phyloTree->getGenusTotals();
52                 
53                         m->mothurOut("Reading template probabilities...     "); cout.flush();
54                         readProbFile(probFileTest, probFileTest2, probFileName, probFileName2); 
55                         
56                 }else{
57                 
58                         //create search database and names vector
59                         generateDatabaseAndNames(tfile, tempFile, method, ksize, 0.0, 0.0, 0.0, 0.0);
60                         
61                         //prevents errors caused by creating shortcut files if you had an error in the sanity check.
62                         if (m->control_pressed) {  remove(phyloTreeName.c_str());  remove(probFileName.c_str()); remove(probFileName2.c_str()); }
63                         else{ 
64                                 genusNodes = phyloTree->getGenusNodes(); 
65                                 genusTotals = phyloTree->getGenusTotals();
66                                 
67                                 m->mothurOut("Calculating template taxonomy tree...     "); cout.flush();
68                                 
69                                 phyloTree->printTreeNodes(phyloTreeName);
70                                                         
71                                 m->mothurOut("DONE."); m->mothurOutEndLine();
72                                 
73                                 m->mothurOut("Calculating template probabilities...     "); cout.flush();
74                                 
75                                 numKmers = database->getMaxKmer() + 1;
76                         
77                                 //initialze probabilities
78                                 wordGenusProb.resize(numKmers);
79                         //cout << numKmers << '\t' << genusNodes.size() << endl;
80                                 for (int j = 0; j < wordGenusProb.size(); j++) {        wordGenusProb[j].resize(genusNodes.size());             }
81                         //cout << numKmers << '\t' << genusNodes.size() << endl;        
82                                 ofstream out;
83                                 ofstream out2;
84                                 
85                                 #ifdef USE_MPI
86                                         int pid;
87                                         MPI_Comm_rank(MPI_COMM_WORLD, &pid); //find out who we are
88
89                                         if (pid == 0) {  
90                                 #endif
91
92                                 
93                                 m->openOutputFile(probFileName, out);
94                                 
95                                 //output mothur version
96                                 out << "#" << m->getVersion() << endl;
97                                 
98                                 out << numKmers << endl;
99                                 
100                                 m->openOutputFile(probFileName2, out2);
101                                 
102                                 //output mothur version
103                                 out2 << "#" << m->getVersion() << endl;
104                                 
105                                 #ifdef USE_MPI
106                                         }
107                                 #endif
108
109                                 
110                                 //for each word
111                                 for (int i = 0; i < numKmers; i++) {
112                                         if (m->control_pressed) {  break; }
113                                         
114                                         #ifdef USE_MPI
115                                                 MPI_Comm_rank(MPI_COMM_WORLD, &pid); //find out who we are
116
117                                                 if (pid == 0) {  
118                                         #endif
119
120                                         out << i << '\t';
121                                         
122                                         #ifdef USE_MPI
123                                                 }
124                                         #endif
125                                         
126                                         vector<int> seqsWithWordi = database->getSequencesWithKmer(i);
127                                         
128                                         map<int, int> count;
129                                         for (int k = 0; k < genusNodes.size(); k++) {  count[genusNodes[k]] = 0;  }                     
130                                                         
131                                         //for each sequence with that word
132                                         for (int j = 0; j < seqsWithWordi.size(); j++) {
133                                                 int temp = phyloTree->getIndex(names[seqsWithWordi[j]]);
134                                                 count[temp]++;  //increment count of seq in this genus who have this word
135                                         }
136                                         
137                                         //probabilityInTemplate = (# of seqs with that word in template + 0.05) / (total number of seqs in template + 1);
138                                         float probabilityInTemplate = (seqsWithWordi.size() + 0.50) / (float) (names.size() + 1);
139                                         
140                                         int numNotZero = 0;
141                                         for (int k = 0; k < genusNodes.size(); k++) {
142                                                 //probabilityInThisTaxonomy = (# of seqs with that word in this taxonomy + probabilityInTemplate) / (total number of seqs in this taxonomy + 1);
143                                                 wordGenusProb[i][k] = log((count[genusNodes[k]] + probabilityInTemplate) / (float) (genusTotals[k] + 1));  
144                                                 if (count[genusNodes[k]] != 0) { 
145                                                         #ifdef USE_MPI
146                                                                 int pid;
147                                                                 MPI_Comm_rank(MPI_COMM_WORLD, &pid); //find out who we are
148                                                 
149                                                                 if (pid == 0) {  
150                                                         #endif
151
152                                                         out << k << '\t' << wordGenusProb[i][k] << '\t'; 
153                                                         
154                                                         #ifdef USE_MPI
155                                                                 }
156                                                         #endif
157
158                                                         numNotZero++;  
159                                                 }
160                                         }
161                                         
162                                         #ifdef USE_MPI
163                                                 MPI_Comm_rank(MPI_COMM_WORLD, &pid); //find out who we are
164                                 
165                                                 if (pid == 0) {  
166                                         #endif
167                                         
168                                         out << endl;
169                                         out2 << probabilityInTemplate << '\t' << numNotZero << endl;
170                                         
171                                         #ifdef USE_MPI
172                                                 }
173                                         #endif
174                                 }
175                                 
176                                 #ifdef USE_MPI
177                                         MPI_Comm_rank(MPI_COMM_WORLD, &pid); //find out who we are
178                                 
179                                         if (pid == 0) {  
180                                 #endif
181                                 
182                                 out.close();
183                                 out2.close();
184                                 
185                                 #ifdef USE_MPI
186                                         }
187                                 #endif
188                                 
189                                 //read in new phylotree with less info. - its faster
190                                 ifstream phyloTreeTest(phyloTreeName.c_str());
191                                 delete phyloTree;
192                                 
193                                 phyloTree = new PhyloTree(phyloTreeTest, phyloTreeName);
194                         }
195                 }
196         
197                 m->mothurOut("DONE."); m->mothurOutEndLine();
198                 m->mothurOut("It took " + toString(time(NULL) - start) + " seconds get probabilities. "); m->mothurOutEndLine();
199         }
200         catch(exception& e) {
201                 m->errorOut(e, "Bayesian", "Bayesian");
202                 exit(1);
203         }
204 }
205 /**************************************************************************************************/
206 Bayesian::~Bayesian() {
207         try {
208                 
209                  delete phyloTree; 
210                  if (database != NULL) {  delete database; }
211         }
212         catch(exception& e) {
213                 m->errorOut(e, "Bayesian", "~Bayesian");
214                 exit(1);
215         }
216 }
217
218 /**************************************************************************************************/
219 string Bayesian::getTaxonomy(Sequence* seq) {
220         try {
221                 string tax = "";
222                 Kmer kmer(kmerSize);
223                 
224                 //get words contained in query
225                 //getKmerString returns a string where the index in the string is hte kmer number 
226                 //and the character at that index can be converted to be the number of times that kmer was seen
227                 
228                 string queryKmerString = kmer.getKmerString(seq->getUnaligned()); 
229
230                 vector<int> queryKmers;
231                 for (int i = 0; i < queryKmerString.length(); i++) {
232                         if (queryKmerString[i] != '!') { //this kmer is in the query
233                                 queryKmers.push_back(i);
234                         }
235                 }
236                 
237                 if (queryKmers.size() == 0) {  m->mothurOut(seq->getName() + "is bad."); m->mothurOutEndLine(); return "bad seq"; }
238                 
239                 
240                 int index = getMostProbableTaxonomy(queryKmers);
241                 
242                 if (m->control_pressed) { return tax; }
243                                         
244                 //bootstrap - to set confidenceScore
245                 int numToSelect = queryKmers.size() / 8;
246         
247                 tax = bootstrapResults(queryKmers, index, numToSelect);
248                                 
249                 return tax;     
250         }
251         catch(exception& e) {
252                 m->errorOut(e, "Bayesian", "getTaxonomy");
253                 exit(1);
254         }
255 }
256 /**************************************************************************************************/
257 string Bayesian::bootstrapResults(vector<int> kmers, int tax, int numToSelect) {
258         try {
259                                 
260                 map<int, int> confidenceScores; 
261                 
262                 //initialize confidences to 0 
263                 int seqIndex = tax;
264                 TaxNode seq = phyloTree->get(tax);
265                 confidenceScores[tax] = 0;
266                 
267                 while (seq.level != 0) { //while you are not at the root
268                         seqIndex = seq.parent;
269                         confidenceScores[seqIndex] = 0;
270                         seq = phyloTree->get(seq.parent);
271                 }
272                                 
273                 map<int, int>::iterator itBoot;
274                 map<int, int>::iterator itBoot2;
275                 map<int, int>::iterator itConvert;
276                 
277                 for (int i = 0; i < iters; i++) {
278                         if (m->control_pressed) { return "control"; }
279                         
280                         vector<int> temp;
281                         for (int j = 0; j < numToSelect; j++) {
282                                 int index = int(rand() % kmers.size());
283                                 
284                                 //add word to temp
285                                 temp.push_back(kmers[index]);
286                         }
287                         
288                         //get taxonomy
289                         int newTax = getMostProbableTaxonomy(temp);
290                         //int newTax = 1;
291                         TaxNode taxonomyTemp = phyloTree->get(newTax);
292                         
293                         //add to confidence results
294                         while (taxonomyTemp.level != 0) { //while you are not at the root
295                                 itBoot2 = confidenceScores.find(newTax); //is this a classification we already have a count on
296                                 
297                                 if (itBoot2 != confidenceScores.end()) { //this is a classification we need a confidence for
298                                         (itBoot2->second)++;
299                                 }
300                                 
301                                 newTax = taxonomyTemp.parent;
302                                 taxonomyTemp = phyloTree->get(newTax);
303                         }
304         
305                 }
306                 
307                 string confidenceTax = "";
308                 simpleTax = "";
309                 
310                 int seqTaxIndex = tax;
311                 TaxNode seqTax = phyloTree->get(tax);
312                 
313                 while (seqTax.level != 0) { //while you are not at the root
314                                         
315                                 itBoot2 = confidenceScores.find(seqTaxIndex); //is this a classification we already have a count on
316                                 
317                                 int confidence = 0;
318                                 if (itBoot2 != confidenceScores.end()) { //already in confidence scores
319                                         confidence = itBoot2->second;
320                                 }
321                                 
322                                 if (((confidence/(float)iters) * 100) >= confidenceThreshold) {
323                                         confidenceTax = seqTax.name + "(" + toString(((confidence/(float)iters) * 100)) + ");" + confidenceTax;
324                                         simpleTax = seqTax.name + ";" + simpleTax;
325                                 }
326                                 
327                                 seqTaxIndex = seqTax.parent;
328                                 seqTax = phyloTree->get(seqTax.parent);
329                 }
330                 
331                 if (confidenceTax == "") { confidenceTax = "unclassified;"; simpleTax = "unclassified;"; }
332                 return confidenceTax;
333                 
334         }
335         catch(exception& e) {
336                 m->errorOut(e, "Bayesian", "bootstrapResults");
337                 exit(1);
338         }
339 }
340 /**************************************************************************************************/
341 int Bayesian::getMostProbableTaxonomy(vector<int> queryKmer) {
342         try {
343                 int indexofGenus = 0;
344                 
345                 double maxProbability = -1000000.0;
346                 //find taxonomy with highest probability that this sequence is from it
347                 for (int k = 0; k < genusNodes.size(); k++) {
348                         //for each taxonomy calc its probability
349                         double prob = 1.0;
350                         for (int i = 0; i < queryKmer.size(); i++) {
351                                 prob += wordGenusProb[queryKmer[i]][k];
352                         }
353                         
354                         //is this the taxonomy with the greatest probability?
355                         if (prob > maxProbability) { 
356                                 indexofGenus = genusNodes[k];
357                                 maxProbability = prob;
358                         }
359                 }
360
361                 return indexofGenus;
362         }
363         catch(exception& e) {
364                 m->errorOut(e, "Bayesian", "getMostProbableTaxonomy");
365                 exit(1);
366         }
367 }
368 /*************************************************************************************************
369 map<string, int> Bayesian::parseTaxMap(string newTax) {
370         try{
371         
372                 map<string, int> parsed;
373                 
374                 newTax = newTax.substr(0, newTax.length()-1);  //get rid of last ';'
375         
376                 //parse taxonomy
377                 string individual;
378                 while (newTax.find_first_of(';') != -1) {
379                         individual = newTax.substr(0,newTax.find_first_of(';'));
380                         newTax = newTax.substr(newTax.find_first_of(';')+1, newTax.length());
381                         parsed[individual] = 1;
382                 }
383                 
384                 //get last one
385                 parsed[newTax] = 1;
386
387                 return parsed;
388                 
389         }
390         catch(exception& e) {
391                 m->errorOut(e, "Bayesian", "parseTax");
392                 exit(1);
393         }
394 }
395 /**************************************************************************************************/
396 void Bayesian::readProbFile(ifstream& in, ifstream& inNum, string inName, string inNumName) {
397         try{
398                 
399                 #ifdef USE_MPI
400                         
401                         int pid, num, num2, processors;
402                         vector<unsigned long int> positions;
403                         vector<unsigned long int> positions2;
404                         
405                         MPI_Status status; 
406                         MPI_File inMPI;
407                         MPI_File inMPI2;
408                         MPI_Comm_rank(MPI_COMM_WORLD, &pid); //find out who we are
409                         MPI_Comm_size(MPI_COMM_WORLD, &processors);
410                         int tag = 2001;
411
412                         char inFileName[1024];
413                         strcpy(inFileName, inNumName.c_str());
414                         
415                         char inFileName2[1024];
416                         strcpy(inFileName2, inName.c_str());
417
418                         MPI_File_open(MPI_COMM_WORLD, inFileName, MPI_MODE_RDONLY, MPI_INFO_NULL, &inMPI);  //comm, filename, mode, info, filepointer
419                         MPI_File_open(MPI_COMM_WORLD, inFileName2, MPI_MODE_RDONLY, MPI_INFO_NULL, &inMPI2);  //comm, filename, mode, info, filepointer
420
421                         if (pid == 0) {
422                                 positions = m->setFilePosEachLine(inNumName, num);
423                                 positions2 = m->setFilePosEachLine(inName, num2);
424                                 
425                                 for(int i = 1; i < processors; i++) { 
426                                         MPI_Send(&num, 1, MPI_INT, i, tag, MPI_COMM_WORLD);
427                                         MPI_Send(&positions[0], (num+1), MPI_LONG, i, tag, MPI_COMM_WORLD);
428                                         
429                                         MPI_Send(&num2, 1, MPI_INT, i, tag, MPI_COMM_WORLD);
430                                         MPI_Send(&positions2[0], (num2+1), MPI_LONG, i, tag, MPI_COMM_WORLD);
431                                 }
432
433                         }else{
434                                 MPI_Recv(&num, 1, MPI_INT, 0, tag, MPI_COMM_WORLD, &status);
435                                 positions.resize(num+1);
436                                 MPI_Recv(&positions[0], (num+1), MPI_LONG, 0, tag, MPI_COMM_WORLD, &status);
437                                 
438                                 MPI_Recv(&num2, 1, MPI_INT, 0, tag, MPI_COMM_WORLD, &status);
439                                 positions2.resize(num2+1);
440                                 MPI_Recv(&positions2[0], (num2+1), MPI_LONG, 0, tag, MPI_COMM_WORLD, &status);
441                         }
442                         
443                         //read version
444                         int length = positions2[1] - positions2[0];
445                         char* buf5 = new char[length];
446
447                         MPI_File_read_at(inMPI2, positions2[0], buf5, length, MPI_CHAR, &status);
448                         delete buf5;
449
450                         //read numKmers
451                         length = positions2[2] - positions2[1];
452                         char* buf = new char[length];
453
454                         MPI_File_read_at(inMPI2, positions2[1], buf, length, MPI_CHAR, &status);
455
456                         string tempBuf = buf;
457                         if (tempBuf.length() > length) { tempBuf = tempBuf.substr(0, length); }
458                         delete buf;
459
460                         istringstream iss (tempBuf,istringstream::in);
461                         iss >> numKmers;  
462                         
463                         //initialze probabilities
464                         wordGenusProb.resize(numKmers);
465                         
466                         for (int j = 0; j < wordGenusProb.size(); j++) {        wordGenusProb[j].resize(genusNodes.size());             }
467                         
468                         int kmer, name;  
469                         vector<int> numbers; numbers.resize(numKmers);
470                         float prob;
471                         vector<float> zeroCountProb; zeroCountProb.resize(numKmers);    
472                         
473                         //read version
474                         length = positions[1] - positions[0];
475                         char* buf6 = new char[length];
476
477                         MPI_File_read_at(inMPI2, positions[0], buf6, length, MPI_CHAR, &status);
478                         delete buf6;
479                         
480                         //read file 
481                         for(int i=1;i<num;i++){
482                                 //read next sequence
483                                 length = positions[i+1] - positions[i];
484                                 char* buf4 = new char[length];
485
486                                 MPI_File_read_at(inMPI, positions[i], buf4, length, MPI_CHAR, &status);
487
488                                 tempBuf = buf4;
489                                 if (tempBuf.length() > length) { tempBuf = tempBuf.substr(0, length); }
490                                 delete buf4;
491
492                                 istringstream iss (tempBuf,istringstream::in);
493                                 iss >> zeroCountProb[i] >> numbers[i];  
494                         }
495                         
496                         MPI_File_close(&inMPI);
497                         
498                         for(int i=2;i<num2;i++){
499                                 //read next sequence
500                                 length = positions2[i+1] - positions2[i];
501                                 char* buf4 = new char[length];
502
503                                 MPI_File_read_at(inMPI2, positions2[i], buf4, length, MPI_CHAR, &status);
504
505                                 tempBuf = buf4;
506                                 if (tempBuf.length() > length) { tempBuf = tempBuf.substr(0, length); }
507                                 delete buf4;
508
509                                 istringstream iss (tempBuf,istringstream::in);
510                                 
511                                 iss >> kmer;
512                                 
513                                 //set them all to zero value
514                                 for (int i = 0; i < genusNodes.size(); i++) {
515                                         wordGenusProb[kmer][i] = log(zeroCountProb[kmer] / (float) (genusTotals[i]+1));
516                                 }
517                                 
518                                 //get probs for nonzero values
519                                 for (int i = 0; i < numbers[kmer]; i++) {
520                                         iss >> name >> prob;
521                                         wordGenusProb[kmer][name] = prob;
522                                 }
523                                 
524                         }
525                         MPI_File_close(&inMPI2);
526                         MPI_Barrier(MPI_COMM_WORLD); //make everyone wait - just in case
527                 #else
528                         //read version
529                         string line = m->getline(in); m->gobble(in);
530                         
531                         in >> numKmers; m->gobble(in);
532                         
533                         //initialze probabilities
534                         wordGenusProb.resize(numKmers);
535                         
536                         for (int j = 0; j < wordGenusProb.size(); j++) {        wordGenusProb[j].resize(genusNodes.size());             }
537                         
538                         int kmer, name, count;  count = 0;
539                         vector<int> num; num.resize(numKmers);
540                         float prob;
541                         vector<float> zeroCountProb; zeroCountProb.resize(numKmers);            
542                         
543                         //read version
544                         string line2 = m->getline(inNum); m->gobble(inNum);
545                         
546                         while (inNum) {
547                                 inNum >> zeroCountProb[count] >> num[count];  
548                                 count++;
549                                 m->gobble(inNum);
550                         }
551                         inNum.close();
552                 
553                         while(in) {
554                                 in >> kmer;
555                                 
556                                 //set them all to zero value
557                                 for (int i = 0; i < genusNodes.size(); i++) {
558                                         wordGenusProb[kmer][i] = log(zeroCountProb[kmer] / (float) (genusTotals[i]+1));
559                                 }
560                                 
561                                 //get probs for nonzero values
562                                 for (int i = 0; i < num[kmer]; i++) {
563                                         in >> name >> prob;
564                                         wordGenusProb[kmer][name] = prob;
565                                 }
566                                 
567                                 m->gobble(in);
568                         }
569                         in.close();
570                         
571                 #endif
572         }
573         catch(exception& e) {
574                 m->errorOut(e, "Bayesian", "readProbFile");
575                 exit(1);
576         }
577 }
578 /**************************************************************************************************/
579 bool Bayesian::checkReleaseDate(ifstream& file1, ifstream& file2, ifstream& file3, ifstream& file4) {
580         try {
581                 
582                 bool good = true;
583                 
584                 vector<string> lines;
585                 lines.push_back(m->getline(file1));  
586                 lines.push_back(m->getline(file2)); 
587                 lines.push_back(m->getline(file3)); 
588                 lines.push_back(m->getline(file4)); 
589
590                 //before we added this check
591                 if ((lines[0][0] != '#') || (lines[1][0] != '#') || (lines[2][0] != '#') || (lines[3][0] != '#')) {  good = false;  }
592                 else {
593                         //rip off #
594                         for (int i = 0; i < lines.size(); i++) { lines[i] = lines[i].substr(1);  }
595                         
596                         //get mothurs current version
597                         string version = m->getVersion();
598                         
599                         vector<string> versionVector;
600                         m->splitAtChar(version, versionVector, '.');
601                         
602                         //check each files version
603                         for (int i = 0; i < lines.size(); i++) { 
604                                 vector<string> linesVector;
605                                 m->splitAtChar(lines[i], linesVector, '.');
606                         
607                                 if (versionVector.size() != linesVector.size()) { good = false; break; }
608                                 else {
609                                         for (int j = 0; j < versionVector.size(); j++) {
610                                                 int num1, num2;
611                                                 convert(versionVector[j], num1);
612                                                 convert(linesVector[j], num2);
613                                                 
614                                                 //if mothurs version is newer than this files version, then we want to remake it
615                                                 if (num1 > num2) {  good = false; break;  }
616                                         }
617                                 }
618                                 
619                                 if (!good) { break; }
620                         }
621                 }
622                 
623                 if (!good) {  file1.close(); file2.close(); file3.close(); file4.close();  }
624                 else { file1.seekg(0);  file2.seekg(0);  file3.seekg(0);  file4.seekg(0);  }
625                 
626                 return good;
627         }
628         catch(exception& e) {
629                 m->errorOut(e, "Bayesian", "checkReleaseDate");
630                 exit(1);
631         }
632 }
633 /**************************************************************************************************/
634
635
636
637
638
639