]> git.donarmstrong.com Git - mothur.git/blob - randomforest.cpp
sffinfo bug with flow grams right index when clipQualRight=0
[mothur.git] / randomforest.cpp
1 //
2 //  randomforest.cpp
3 //  Mothur
4 //
5 //  Created by Sarah Westcott on 10/2/12.
6 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
7 //
8
9 #include "randomforest.hpp" 
10
11 /***********************************************************************/
12
13 RandomForest::RandomForest(const vector <vector<int> > dataSet,const int numDecisionTrees,
14              const string treeSplitCriterion = "informationGain") : Forest(dataSet, numDecisionTrees, treeSplitCriterion) {
15     m = MothurOut::getInstance();
16 }
17
18 /***********************************************************************/
19 // DONE
20 int RandomForest::calcForrestErrorRate() {
21     try {
22         int numCorrect = 0;
23         for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
24             
25             if (m->control_pressed) { return 0; }
26             
27             int indexOfSample = it->first;
28             vector<int> predictedOutComes = it->second;
29             vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
30             int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
31             int realOutcome = dataSet[indexOfSample][numFeatures];
32             
33             if (majorityVotedOutcome == realOutcome) { numCorrect++; }
34         }
35         
36         // TODO: save or return forrestErrorRate for future use;
37         double forrestErrorRate = 1 - ((double)numCorrect / (double)globalOutOfBagEstimates.size());
38         
39         m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n");
40         m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n");
41     
42         return 0;
43     }
44         catch(exception& e) {
45                 m->errorOut(e, "RandomForest", "calcForrestErrorRate");
46                 exit(1);
47         } 
48 }
49
50 /***********************************************************************/
51 // DONE
52 int RandomForest::calcForrestVariableImportance(string filename) {
53     try {
54     
55     // TODO: need to add try/catch operators to fix this
56     // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
57         //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all?
58         //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree.
59     for (int i = 0; i < decisionTrees.size(); i++) {
60         if (m->control_pressed) { return 0; }
61         
62         DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
63         
64         for (int j = 0; j < numFeatures; j++) {
65             globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
66         }
67     }
68     
69     for (int i = 0;  i < numFeatures; i++) {
70         cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
71         globalVariableImportanceList[i] /= (double)numDecisionTrees;
72     }
73     
74     vector< vector<double> > globalVariableRanks;
75     for (int i = 0; i < globalVariableImportanceList.size(); i++) {
76         if (globalVariableImportanceList[i] > 0) {
77             vector<double> globalVariableRank(2, 0);
78             globalVariableRank[0] = i; globalVariableRank[1] = globalVariableImportanceList[i];
79             globalVariableRanks.push_back(globalVariableRank);
80         }
81     }
82     
83     VariableRankDescendingSorterDouble variableRankDescendingSorter;
84     sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
85         ofstream out;
86         m->openOutputFile(filename, out);
87         out <<"OTU\tRank\n";
88         for (int i = 0; i < globalVariableRanks.size(); i++) {
89             out << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
90         }
91         out.close();
92         return 0;
93     }
94         catch(exception& e) {
95                 m->errorOut(e, "RandomForest", "calcForrestVariableImportance");
96                 exit(1);
97         }  
98 }
99 /***********************************************************************/
100 // DONE
101 int RandomForest::populateDecisionTrees() {
102     try {
103         
104         for (int i = 0; i < numDecisionTrees; i++) {
105             if (m->control_pressed) { return 0; }
106             if (((i+1) % 10) == 0) {  m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n");  }
107             // TODO: need to first fix if we are going to use pointer based system or anything else
108             DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector("log2"), treeSplitCriterion);
109             decisionTree->calcTreeVariableImportanceAndError();
110             if (m->control_pressed) { return 0; }
111             updateGlobalOutOfBagEstimates(decisionTree);
112             if (m->control_pressed) { return 0; }
113             decisionTree->purgeDataSetsFromTree();
114             if (m->control_pressed) { return 0; }
115             decisionTrees.push_back(decisionTree);
116         }
117         
118         if (m->debug) {
119             // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
120         }
121         
122         return 0;
123     }
124     catch(exception& e) {
125         m->errorOut(e, "RandomForest", "populateDecisionTrees");
126         exit(1);
127     }  
128 }
129 /***********************************************************************/
130 // TODO: need to finalize bettween reference and pointer for DecisionTree [partially solved]
131 // DONE: make this pure virtual in superclass
132 // DONE
133 int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) {
134     try {
135         for (map<int, int>::iterator it = decisionTree->outOfBagEstimates.begin(); it != decisionTree->outOfBagEstimates.end(); it++) {
136             
137             if (m->control_pressed) { return 0; }
138             
139             int indexOfSample = it->first;
140             int predictedOutcomeOfSample = it->second;
141             
142             if (globalOutOfBagEstimates.count(indexOfSample) == 0) {
143                 globalOutOfBagEstimates[indexOfSample] = vector<int>(decisionTree->numOutputClasses, 0);
144             };
145             
146             globalOutOfBagEstimates[indexOfSample][predictedOutcomeOfSample] += 1;
147         }
148         return 0;
149     }
150     catch(exception& e) {
151         m->errorOut(e, "RandomForest", "updateGlobalOutOfBagEstimates");
152         exit(1);
153     }  
154 }
155 /***********************************************************************/
156
157