]> git.donarmstrong.com Git - mothur.git/blob - randomforest.cpp
changed random forest output filename
[mothur.git] / randomforest.cpp
1 //
2 //  randomforest.cpp
3 //  Mothur
4 //
5 //  Created by Sarah Westcott on 10/2/12.
6 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
7 //
8
9 #include "randomforest.hpp" 
10
11 /***********************************************************************/
12
13 RandomForest::RandomForest(const vector <vector<int> > dataSet,
14                            const int numDecisionTrees,
15                            const string treeSplitCriterion = "gainratio",
16                            const bool doPruning = false,
17                            const float pruneAggressiveness = 0.9,
18                            const bool discardHighErrorTrees = true,
19                            const float highErrorTreeDiscardThreshold = 0.4,
20                            const string optimumFeatureSubsetSelectionCriteria = "log2",
21                            const float featureStandardDeviationThreshold = 0.0)
22             : Forest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold) {
23     m = MothurOut::getInstance();
24 }
25
26 /***********************************************************************/
27 // DONE
28 int RandomForest::calcForrestErrorRate() {
29     try {
30         int numCorrect = 0;
31         for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
32             
33             if (m->control_pressed) { return 0; }
34             
35             int indexOfSample = it->first;
36             vector<int> predictedOutComes = it->second;
37             vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
38             int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
39             int realOutcome = dataSet[indexOfSample][numFeatures];
40             
41             if (majorityVotedOutcome == realOutcome) { numCorrect++; }
42         }
43         
44         // TODO: save or return forrestErrorRate for future use;
45         double forrestErrorRate = 1 - ((double)numCorrect / (double)globalOutOfBagEstimates.size());
46         
47         m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n");
48         m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n");
49     
50         return 0;
51     }
52         catch(exception& e) {
53                 m->errorOut(e, "RandomForest", "calcForrestErrorRate");
54                 exit(1);
55         } 
56 }
57
58 /***********************************************************************/
59 int RandomForest::calcForrestVariableImportance(string filename) {
60     try {
61     
62         // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
63         //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all?
64         //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree.
65         for (int i = 0; i < decisionTrees.size(); i++) {
66             if (m->control_pressed) { return 0; }
67             
68             DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
69             
70             for (int j = 0; j < numFeatures; j++) {
71                 globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
72             }
73         }
74         
75         for (int i = 0;  i < numFeatures; i++) {
76             globalVariableImportanceList[i] /= (double)numDecisionTrees;
77         }
78         
79         vector< pair<int, double> > globalVariableRanks;
80         for (int i = 0; i < globalVariableImportanceList.size(); i++) {
81             //cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
82             if (globalVariableImportanceList[i] > 0) {
83                 pair<int, double> globalVariableRank(0, 0.0);
84                 globalVariableRank.first = i;
85                 globalVariableRank.second = globalVariableImportanceList[i];
86                 globalVariableRanks.push_back(globalVariableRank);
87             }
88         }
89         
90 //        for (int i = 0; i < globalVariableRanks.size(); i++) {
91 //            cout << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
92 //        }
93
94         
95         VariableRankDescendingSorterDouble variableRankDescendingSorter;
96         sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
97         
98         ofstream out;
99         m->openOutputFile(filename, out);
100         out <<"OTU\tRank\n";
101         for (int i = 0; i < globalVariableRanks.size(); i++) {
102             out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
103         }
104         out.close();
105         return 0;
106     }
107         catch(exception& e) {
108                 m->errorOut(e, "RandomForest", "calcForrestVariableImportance");
109                 exit(1);
110         }  
111 }
112 /***********************************************************************/
113 int RandomForest::populateDecisionTrees() {
114     try {
115         
116         vector<double> errorRateImprovements;
117         
118         for (int i = 0; i < numDecisionTrees; i++) {
119           
120             if (m->control_pressed) { return 0; }
121             if (((i+1) % 100) == 0) {  m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n");  }
122           
123             // TODO: need to first fix if we are going to use pointer based system or anything else
124             DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold);
125           
126             if (m->debug && doPruning) {
127                 m->mothurOut("Before pruning\n");
128                 decisionTree->printTree(decisionTree->rootNode, "ROOT");
129             }
130             
131             int numCorrect;
132             double treeErrorRate;
133             
134             decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
135             double prePrunedErrorRate = treeErrorRate;
136             
137             if (m->debug) {
138                 m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
139             }
140             
141             if (doPruning) {
142                 decisionTree->pruneTree(pruneAggressiveness);
143                 if (m->debug) {
144                     m->mothurOut("After pruning\n");
145                     decisionTree->printTree(decisionTree->rootNode, "ROOT");
146                 }
147                 decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
148             }
149             double postPrunedErrorRate = treeErrorRate;
150             
151           
152             decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate);
153             double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate;
154
155             if (m->debug) {
156                 m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
157                 if (doPruning) {
158                     m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n");
159                 }
160             }
161             
162             
163             if (discardHighErrorTrees) {
164                 if (treeErrorRate < highErrorTreeDiscardThreshold) {
165                     updateGlobalOutOfBagEstimates(decisionTree);
166                     decisionTree->purgeDataSetsFromTree();
167                     decisionTrees.push_back(decisionTree);
168                     if (doPruning) {
169                         errorRateImprovements.push_back(errorRateImprovement);
170                     }
171                 } else {
172                     delete decisionTree;
173                 }
174             } else {
175                 updateGlobalOutOfBagEstimates(decisionTree);
176                 decisionTree->purgeDataSetsFromTree();
177                 decisionTrees.push_back(decisionTree);
178                 if (doPruning) {
179                     errorRateImprovements.push_back(errorRateImprovement);
180                 }
181             }          
182         }
183         
184         double avgErrorRateImprovement = -1.0;
185         if (errorRateImprovements.size() > 0) {
186             avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0);
187 //            cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl;
188             avgErrorRateImprovement /= errorRateImprovements.size();
189         }
190         
191         if (m->debug && doPruning) {
192             m->mothurOut("avgErrorRateImprovement:" + toString(avgErrorRateImprovement) + "\n");
193         }
194         // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
195
196         
197         return 0;
198     }
199     catch(exception& e) {
200         m->errorOut(e, "RandomForest", "populateDecisionTrees");
201         exit(1);
202     }  
203 }
204 /***********************************************************************/
205 // TODO: need to finalize bettween reference and pointer for DecisionTree [partially solved]
206 // DONE: make this pure virtual in superclass
207 // DONE
208 int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) {
209     try {
210         for (map<int, int>::iterator it = decisionTree->outOfBagEstimates.begin(); it != decisionTree->outOfBagEstimates.end(); it++) {
211             
212             if (m->control_pressed) { return 0; }
213             
214             int indexOfSample = it->first;
215             int predictedOutcomeOfSample = it->second;
216             
217             if (globalOutOfBagEstimates.count(indexOfSample) == 0) {
218                 globalOutOfBagEstimates[indexOfSample] = vector<int>(decisionTree->numOutputClasses, 0);
219             };
220             
221             globalOutOfBagEstimates[indexOfSample][predictedOutcomeOfSample] += 1;
222         }
223         return 0;
224     }
225     catch(exception& e) {
226         m->errorOut(e, "RandomForest", "updateGlobalOutOfBagEstimates");
227         exit(1);
228     }  
229 }
230 /***********************************************************************/
231
232