//
//  randomforest.cpp
//  Mothur
//
//  Created by Sarah Westcott on 10/2/12.
//  Copyright (c) 2012 Schloss Lab. All rights reserved.
//

#include "randomforest.hpp"

/***********************************************************************/

RandomForest::RandomForest(const vector <vector<int> > dataSet,
                           const int numDecisionTrees,
                           const string treeSplitCriterion = "gainratio",
                           const bool doPruning = false,
                           const float pruneAggressiveness = 0.9,
                           const bool discardHighErrorTrees = true,
                           const float highErrorTreeDiscardThreshold = 0.4,
                           const string optimumFeatureSubsetSelectionCriteria = "log2",
                           const float featureStandardDeviationThreshold = 0.0)
            : Forest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold) {
    m = MothurOut::getInstance();
}

/***********************************************************************/
// DONE
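// Aggregate out-of-bag (OOB) error for the whole forest. globalOutOfBagEstimates
// maps a sample's index to a per-class vote tally built up from the trees whose
// bootstrap samples left that sample out; the predicted class is the index with
// the most votes (max_element), the real class is the last column of dataSet,
// and the error rate is 1 - numCorrect / (number of samples with OOB votes).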
int RandomForest::calcForrestErrorRate() {
    try {
        int numCorrect = 0;
        for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {

            if (m->control_pressed) { return 0; }

            int indexOfSample = it->first;
            vector<int> predictedOutComes = it->second;
            vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
            int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
            int realOutcome = dataSet[indexOfSample][numFeatures];

            if (majorityVotedOutcome == realOutcome) { numCorrect++; }
        }

        // TODO: save or return forrestErrorRate for future use;
        double forrestErrorRate = 1 - ((double)numCorrect / (double)globalOutOfBagEstimates.size());

        m->mothurOut("numCorrect = " + toString(numCorrect) + "\n");
        m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate) + "\n");

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "calcForrestErrorRate");
        exit(1);
    }
}
/***********************************************************************/

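// Prints an OOB confusion matrix to stdout: rows are the actual treatment labels,
// columns are the majority-voted OOB predictions, and cm[actual][predicted] counts
// the samples in each cell. Column widths (fw) come from the treatment-name lengths
// so the header and the counts line up.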
int RandomForest::printConfusionMatrix(map<int, string> intToTreatmentMap) {
    try {
        int numGroups = intToTreatmentMap.size();
        vector<vector<int> > cm(numGroups, vector<int>(numGroups, 0));

        for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {

            if (m->control_pressed) { return 0; }

            int indexOfSample = it->first; //key
            vector<int> predictedOutComes = it->second; //value, vector of all predicted classes
            vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
            int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
            int realOutcome = dataSet[indexOfSample][numFeatures];
            cm[realOutcome][majorityVotedOutcome]++;
        }

        vector<int> fw;
        for (int w = 0; w < numGroups; w++) {
            fw.push_back(intToTreatmentMap[w].length());
        }

        m->mothurOut("confusion matrix:\n\t\t");
        for (int k = 0; k < numGroups; k++) {
            //m->mothurOut(intToTreatmentMap[k] + "\t");
            cout << setw(fw[k]) << intToTreatmentMap[k] << "\t";
        }
        for (int i = 0; i < numGroups; i++) {
            cout << "\n" << setw(fw[i]) << intToTreatmentMap[i] << "\t";
            //m->mothurOut("\n" + intToTreatmentMap[i] + "\t");
            if (m->control_pressed) { return 0; }
            for (int j = 0; j < numGroups; j++) {
                //m->mothurOut(toString(cm[i][j]) + "\t");
                cout << setw(fw[j]) << cm[i][j] << "\t"; // width of column j keeps counts aligned under their headers
            }
        }
        //m->mothurOut("\n");
        cout << "\n";

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "printConfusionMatrix");
        exit(1);
    }
}

/***********************************************************************/

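// Writes a tab-delimited file of misclassified samples: sample name, the forest's
// majority-voted OOB prediction, and the actual classification. A sample is listed
// only when the OOB majority vote disagrees with the real outcome stored in the
// last column of dataSet.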
int RandomForest::getMissclassifications(string filename, map<int, string> intToTreatmentMap, vector<string> names) {
    try {
        ofstream out;
        m->openOutputFile(filename, out);
        out << "Sample\tRF classification\tActual classification\n";
        for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {

            if (m->control_pressed) { return 0; }

            int indexOfSample = it->first;
            vector<int> predictedOutComes = it->second;
            vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
            int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
            int realOutcome = dataSet[indexOfSample][numFeatures];

            if (majorityVotedOutcome != realOutcome) {
                //write to file
                //dataSet[indexOfSample][];

                out << names[indexOfSample] << "\t" << intToTreatmentMap[majorityVotedOutcome] << "\t" << intToTreatmentMap[realOutcome] << endl;
                //out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
            }
        }

        out.close();
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "getMissclassifications");
        exit(1);
    }
}

/***********************************************************************/
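// Forest-level variable importance: sums each tree's variableImportanceList,
// divides by numDecisionTrees to get the mean decrease in accuracy per feature,
// then writes the positively-scoring features to 'filename' in descending order
// of importance, labelled with the current OTU bin labels.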
int RandomForest::calcForrestVariableImportance(string filename) {
    try {

        // See: http://en.wikipedia.org/wiki/Dynamic_cast
        // If you are going to dynamically cast, aren't you undoing the advantage of abstraction? Why abstract at all?
        // This could cause maintenance issues later if other kinds of abstract decision trees are created that cannot be cast to DecisionTree.
        for (int i = 0; i < decisionTrees.size(); i++) {
            if (m->control_pressed) { return 0; }

            DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);

            for (int j = 0; j < numFeatures; j++) {
                globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
            }
        }

        for (int i = 0; i < numFeatures; i++) {
            globalVariableImportanceList[i] /= (double)numDecisionTrees;
        }

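        // Rank only the features with a positive averaged importance; features whose
        // averaged importance is zero or negative are omitted from the output file.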
        vector< pair<int, double> > globalVariableRanks;
        for (int i = 0; i < globalVariableImportanceList.size(); i++) {
            //cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
            if (globalVariableImportanceList[i] > 0) {
                pair<int, double> globalVariableRank(0, 0.0);
                globalVariableRank.first = i;
                globalVariableRank.second = globalVariableImportanceList[i];
                globalVariableRanks.push_back(globalVariableRank);
            }
        }

//        for (int i = 0; i < globalVariableRanks.size(); i++) {
//            cout << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
//        }

        VariableRankDescendingSorterDouble variableRankDescendingSorter;
        sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);

        ofstream out;
        m->openOutputFile(filename, out);
        out << "OTU\tMean decrease accuracy\n";
        for (int i = 0; i < globalVariableRanks.size(); i++) {
            out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
        }
        out.close();
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "calcForrestVariableImportance");
        exit(1);
    }
}
/***********************************************************************/
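// Builds the forest one tree at a time: construct a DecisionTree from dataSet,
// optionally prune it (pruneAggressiveness) and track the relative error improvement,
// then either keep the tree (fold its OOB votes into globalOutOfBagEstimates, purge
// its working copies of the data, store it in decisionTrees) or delete it when
// discardHighErrorTrees is set and its error rate is not below highErrorTreeDiscardThreshold.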
int RandomForest::populateDecisionTrees() {
    try {

        vector<double> errorRateImprovements;

        for (int i = 0; i < numDecisionTrees; i++) {

            if (m->control_pressed) { return 0; }
            if (((i+1) % 100) == 0) {  m->mothurOut("Creating decision tree " + toString(i+1) + "\n");  }

            // TODO: need to first fix if we are going to use pointer based system or anything else
            DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold);

            if (m->debug && doPruning) {
                m->mothurOut("Before pruning\n");
                decisionTree->printTree(decisionTree->rootNode, "ROOT");
            }

            int numCorrect;
            double treeErrorRate;

            decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
            double prePrunedErrorRate = treeErrorRate;

            if (m->debug) {
                m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
            }

            if (doPruning) {
                decisionTree->pruneTree(pruneAggressiveness);
                if (m->debug) {
                    m->mothurOut("After pruning\n");
                    decisionTree->printTree(decisionTree->rootNode, "ROOT");
                }
                decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
            }
            double postPrunedErrorRate = treeErrorRate;

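            // Relative drop in tree error produced by pruning: (prePruned - postPruned) / prePruned.
            // It is only pushed onto errorRateImprovements when doPruning is true; without
            // pruning the two rates are identical.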
            decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate);
            double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate;

            if (m->debug) {
                m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
                if (doPruning) {
                    m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n");
                }
            }

            if (discardHighErrorTrees) {
                if (treeErrorRate < highErrorTreeDiscardThreshold) {
                    updateGlobalOutOfBagEstimates(decisionTree);
                    decisionTree->purgeDataSetsFromTree();
                    decisionTrees.push_back(decisionTree);
                    if (doPruning) {
                        errorRateImprovements.push_back(errorRateImprovement);
                    }
                } else {
                    delete decisionTree;
                }
            } else {
                updateGlobalOutOfBagEstimates(decisionTree);
                decisionTree->purgeDataSetsFromTree();
                decisionTrees.push_back(decisionTree);
                if (doPruning) {
                    errorRateImprovements.push_back(errorRateImprovement);
                }
            }
        }

        double avgErrorRateImprovement = -1.0;
        if (errorRateImprovements.size() > 0) {
            avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0);
//            cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl;
            avgErrorRateImprovement /= errorRateImprovements.size();
        }

        if (m->debug && doPruning) {
            m->mothurOut("avgErrorRateImprovement:" + toString(avgErrorRateImprovement) + "\n");
        }
        // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "populateDecisionTrees");
        exit(1);
    }
}
/***********************************************************************/
// TODO: need to decide between reference and pointer for DecisionTree [partially solved]
// DONE: make this pure virtual in superclass
// DONE
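// Folds one tree's out-of-bag predictions into the forest-wide tally: for each
// (sample index, predicted class) pair in the tree's outOfBagEstimates, increment
// that class's vote count in globalOutOfBagEstimates, creating a zero-initialized
// vote vector (one slot per output class) the first time a sample is seen.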
int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) {
    try {
        for (map<int, int>::iterator it = decisionTree->outOfBagEstimates.begin(); it != decisionTree->outOfBagEstimates.end(); it++) {

            if (m->control_pressed) { return 0; }

            int indexOfSample = it->first;
            int predictedOutcomeOfSample = it->second;

            if (globalOutOfBagEstimates.count(indexOfSample) == 0) {
                globalOutOfBagEstimates[indexOfSample] = vector<int>(decisionTree->numOutputClasses, 0);
            }

            globalOutOfBagEstimates[indexOfSample][predictedOutcomeOfSample] += 1;
        }
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "updateGlobalOutOfBagEstimates");
        exit(1);
    }
}
/***********************************************************************/
