//
//  randomforest.cpp
//  Mothur
//
//  Created by Sarah Westcott on 10/2/12.
//  Copyright (c) 2012 Schloss Lab. All rights reserved.
//

#include "randomforest.hpp"

/***********************************************************************/

RandomForest::RandomForest(const vector <vector<int> > dataSet,
                           const int numDecisionTrees,
                           const string treeSplitCriterion = "gainratio",
                           const bool doPruning = false,
                           const float pruneAggressiveness = 0.9,
                           const bool discardHighErrorTrees = true,
                           const float highErrorTreeDiscardThreshold = 0.4,
                           const string optimumFeatureSubsetSelectionCriteria = "log2",
                           const float featureStandardDeviationThreshold = 0.0)
            : Forest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold) {
    m = MothurOut::getInstance();
}

/***********************************************************************/
// DONE
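// Computes the forest-wide out-of-bag (OOB) error rate: for each sample, the class with the
// most OOB votes across trees is compared against the sample's true class, which is stored
// in the last column of dataSet.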
int RandomForest::calcForrestErrorRate() {
    try {
        int numCorrect = 0;
        for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
            
            if (m->control_pressed) { return 0; }
            
            int indexOfSample = it->first;
            vector<int> predictedOutComes = it->second;
            vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
            int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
            int realOutcome = dataSet[indexOfSample][numFeatures];
            
            if (majorityVotedOutcome == realOutcome) { numCorrect++; }
        }
        
        // TODO: save or return forrestErrorRate for future use;
        double forrestErrorRate = 1 - ((double)numCorrect / (double)globalOutOfBagEstimates.size());
        
        m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n");
        m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n");
        
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "calcForrestErrorRate");
        exit(1);
    }
}
/***********************************************************************/

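// Prints an out-of-bag confusion matrix: rows are the actual treatments, columns are the
// majority-vote OOB predictions, so cm[i][j] counts samples of treatment i predicted as j.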
int RandomForest::printConfusionMatrix(map<int, string> intToTreatmentMap) {
    try {
        int numGroups = intToTreatmentMap.size();
        vector<vector<int> > cm(numGroups, vector<int>(numGroups, 0));
        
        for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
            
            if (m->control_pressed) { return 0; }
            
            int indexOfSample = it->first; //key
            vector<int> predictedOutComes = it->second; //value, vector of all predicted classes
            vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
            int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
            int realOutcome = dataSet[indexOfSample][numFeatures];
            cm[realOutcome][majorityVotedOutcome]++;
        }
        
        m->mothurOut("confusion matrix:\n\t\t");
        for (int i = 0; i < numGroups; i++) {
            m->mothurOut(intToTreatmentMap[i] + "\t");
        }
        for (int i = 0; i < numGroups; i++) {
            m->mothurOut("\n" + intToTreatmentMap[i] + "\t"); //start each row on a new line, labeled with its actual treatment
            if (m->control_pressed) { return 0; }
            for (int j = 0; j < numGroups; j++) {
                m->mothurOut(toString(cm[i][j]) + "\t"); //toString is required: adding an int to a string literal is pointer arithmetic, not concatenation
            }
        }
        m->mothurOut("\n");
        
        return 0;
    }
    
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "printConfusionMatrix");
        exit(1);
    }
}

/***********************************************************************/
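// Averages each feature's importance score over all kept trees and writes the non-zero
// scores, sorted in descending order, to the output file as "mean decrease accuracy".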
int RandomForest::calcForrestVariableImportance(string filename) {
    try {
    
        // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
        // if you are going to dynamically cast, aren't you undoing the advantage of abstraction? Why abstract at all?
        // could cause maintenance issues later if other types of abstract decision trees are created that cannot be cast as a DecisionTree.
        for (int i = 0; i < decisionTrees.size(); i++) {
            if (m->control_pressed) { return 0; }
            
            DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
            
            for (int j = 0; j < numFeatures; j++) {
                globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
            }
        }
        
        for (int i = 0; i < numFeatures; i++) {
            globalVariableImportanceList[i] /= (double)numDecisionTrees;
        }
        
        vector< pair<int, double> > globalVariableRanks;
        for (int i = 0; i < globalVariableImportanceList.size(); i++) {
            //cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
            if (globalVariableImportanceList[i] > 0) {
                pair<int, double> globalVariableRank(0, 0.0);
                globalVariableRank.first = i;
                globalVariableRank.second = globalVariableImportanceList[i];
                globalVariableRanks.push_back(globalVariableRank);
            }
        }
        
//        for (int i = 0; i < globalVariableRanks.size(); i++) {
//            cout << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
//        }

        
        VariableRankDescendingSorterDouble variableRankDescendingSorter;
        sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
        
        ofstream out;
        m->openOutputFile(filename, out);
        out << "OTU\tMean decrease accuracy\n";
        for (int i = 0; i < globalVariableRanks.size(); i++) {
            out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
        }
        out.close();
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "calcForrestVariableImportance");
        exit(1);
    }
}
/***********************************************************************/
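// Builds the forest: each iteration grows one bootstrap decision tree, optionally prunes it,
// optionally discards it when its error rate exceeds highErrorTreeDiscardThreshold, and folds
// its out-of-bag predictions into the forest-wide tallies.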
int RandomForest::populateDecisionTrees() {
    try {
        
        vector<double> errorRateImprovements;
        
        for (int i = 0; i < numDecisionTrees; i++) {
            
            if (m->control_pressed) { return 0; }
            if (((i+1) % 100) == 0) {  m->mothurOut("Creating decision tree " + toString(i+1) + "\n");  }
            
            // TODO: need to decide first whether we are going to use a pointer-based system or something else
            DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold);
            
            if (m->debug && doPruning) {
                m->mothurOut("Before pruning\n");
                decisionTree->printTree(decisionTree->rootNode, "ROOT");
            }
            
            int numCorrect;
            double treeErrorRate;
            
            decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
            double prePrunedErrorRate = treeErrorRate;
            
            if (m->debug) {
                m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
            }
            
            if (doPruning) {
                decisionTree->pruneTree(pruneAggressiveness);
                if (m->debug) {
                    m->mothurOut("After pruning\n");
                    decisionTree->printTree(decisionTree->rootNode, "ROOT");
                }
                decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
            }
            double postPrunedErrorRate = treeErrorRate;
            
            
            decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate);
            double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate;

            if (m->debug) {
                m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
                if (doPruning) {
                    m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n");
                }
            }
            
            
            if (discardHighErrorTrees) {
                if (treeErrorRate < highErrorTreeDiscardThreshold) {
                    updateGlobalOutOfBagEstimates(decisionTree);
                    decisionTree->purgeDataSetsFromTree();
                    decisionTrees.push_back(decisionTree);
                    if (doPruning) {
                        errorRateImprovements.push_back(errorRateImprovement);
                    }
                } else {
                    delete decisionTree;
                }
            } else {
                updateGlobalOutOfBagEstimates(decisionTree);
                decisionTree->purgeDataSetsFromTree();
                decisionTrees.push_back(decisionTree);
                if (doPruning) {
                    errorRateImprovements.push_back(errorRateImprovement);
                }
            }
        }
        
        double avgErrorRateImprovement = -1.0;
        if (errorRateImprovements.size() > 0) {
            avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0);
//            cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl;
            avgErrorRateImprovement /= errorRateImprovements.size();
        }
        
        if (m->debug && doPruning) {
            m->mothurOut("avgErrorRateImprovement:" + toString(avgErrorRateImprovement) + "\n");
        }
        // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");

        
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "populateDecisionTrees");
        exit(1);
    }
}
/***********************************************************************/
// TODO: need to finalize between reference and pointer for DecisionTree [partially solved]
// DONE: make this pure virtual in superclass
// DONE
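// Folds one tree's out-of-bag predictions into the forest-wide tally: for each OOB sample,
// the vote count of the predicted class is incremented, creating the per-class vector the
// first time that sample is seen.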
int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) {
    try {
        for (map<int, int>::iterator it = decisionTree->outOfBagEstimates.begin(); it != decisionTree->outOfBagEstimates.end(); it++) {
            
            if (m->control_pressed) { return 0; }
            
            int indexOfSample = it->first;
            int predictedOutcomeOfSample = it->second;
            
            if (globalOutOfBagEstimates.count(indexOfSample) == 0) {
                globalOutOfBagEstimates[indexOfSample] = vector<int>(decisionTree->numOutputClasses, 0);
            }
            
            globalOutOfBagEstimates[indexOfSample][predictedOutcomeOfSample] += 1;
        }
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "RandomForest", "updateGlobalOutOfBagEstimates");
        exit(1);
    }
}
/***********************************************************************/
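/*
 A minimal usage sketch of the public API defined in this file. This is an illustration only,
 not mothur's actual call path (the real driver lives in the classify command code), and the
 variable names dataSet, treatmentMap, and the output filename are hypothetical:

     vector< vector<int> > dataSet;    // one row per sample; last column holds the class label
     map<int, string> treatmentMap;    // class index -> treatment name
     // ... fill dataSet and treatmentMap ...
     RandomForest forest(dataSet, 100, "gainratio", false, 0.9, true, 0.4, "log2", 0.0);
     forest.populateDecisionTrees();                                  // grow the bootstrap trees
     forest.calcForrestErrorRate();                                   // report the OOB error rate
     forest.printConfusionMatrix(treatmentMap);                       // report the OOB confusion matrix
     forest.calcForrestVariableImportance("rf.importance.summary");   // write per-OTU importance scores
*/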