//  Created by Sarah Westcott on 10/2/12.
//  Copyright (c) 2012 Schloss Lab. All rights reserved.

#include "randomforest.hpp"

/***********************************************************************/
13 RandomForest::RandomForest(const vector <vector<int> > dataSet,
14 const int numDecisionTrees,
15 const string treeSplitCriterion = "gainratio",
16 const bool doPruning = false,
17 const float pruneAggressiveness = 0.9,
18 const bool discardHighErrorTrees = true,
19 const float highErrorTreeDiscardThreshold = 0.4,
20 const string optimumFeatureSubsetSelectionCriteria = "log2",
21 const float featureStandardDeviationThreshold = 0.0)
22 : Forest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold) {
23 m = MothurOut::getInstance();
/***********************************************************************/
28 int RandomForest::calcForrestErrorRate() {
31 for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
33 if (m->control_pressed) { return 0; }
35 int indexOfSample = it->first;
36 vector<int> predictedOutComes = it->second;
37 vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
38 int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
39 int realOutcome = dataSet[indexOfSample][numFeatures];
41 if (majorityVotedOutcome == realOutcome) { numCorrect++; }
44 // TODO: save or return forrestErrorRate for future use;
45 double forrestErrorRate = 1 - ((double)numCorrect / (double)globalOutOfBagEstimates.size());
47 m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n");
48 m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n");
53 m->errorOut(e, "RandomForest", "calcForrestErrorRate");
/***********************************************************************/
59 int RandomForest::printConfusionMatrix(map<int, string> intToTreatmentMap) {
61 int numGroups = intToTreatmentMap.size();
62 vector<vector<int> > cm(numGroups, vector<int>(numGroups, 0));
64 for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
66 if (m->control_pressed) { return 0; }
68 int indexOfSample = it->first; //key
69 vector<int> predictedOutComes = it->second; //value, vector of all predicted classes
70 vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
71 int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
72 int realOutcome = dataSet[indexOfSample][numFeatures];
73 cm[realOutcome][majorityVotedOutcome] = cm[realOutcome][majorityVotedOutcome] + 1;
77 for (int w = 0; w <numGroups; w++) {
78 fw.push_back(intToTreatmentMap[w].length());
81 m->mothurOut("confusion matrix:\n\t\t");
82 for (int k = 0; k < numGroups; k++) {
83 //m->mothurOut(intToTreatmentMap[k] + "\t");
84 cout << setw(fw[k]) << intToTreatmentMap[k] << "\t";
86 for (int i = 0; i < numGroups; i++) {
87 cout << "\n" << setw(fw[i]) << intToTreatmentMap[i] << "\t";
88 //m->mothurOut("\n" + intToTreatmentMap[i] + "\t");
89 if (m->control_pressed) { return 0; }
90 for (int j = 0; j < numGroups; j++) {
91 //m->mothurOut(toString(cm[i][j]) + "\t");
92 cout << setw(fw[i]) << cm[i][j] << "\t";
101 catch(exception& e) {
102 m->errorOut(e, "RandomForest", "printConfusionMatrix");
/***********************************************************************/
108 int RandomForest::calcForrestVariableImportance(string filename) {
111 // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
112 //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all?
113 //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree.
114 for (int i = 0; i < decisionTrees.size(); i++) {
115 if (m->control_pressed) { return 0; }
117 DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
119 for (int j = 0; j < numFeatures; j++) {
120 globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
124 for (int i = 0; i < numFeatures; i++) {
125 globalVariableImportanceList[i] /= (double)numDecisionTrees;
128 vector< pair<int, double> > globalVariableRanks;
129 for (int i = 0; i < globalVariableImportanceList.size(); i++) {
130 //cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
131 if (globalVariableImportanceList[i] > 0) {
132 pair<int, double> globalVariableRank(0, 0.0);
133 globalVariableRank.first = i;
134 globalVariableRank.second = globalVariableImportanceList[i];
135 globalVariableRanks.push_back(globalVariableRank);
139 // for (int i = 0; i < globalVariableRanks.size(); i++) {
140 // cout << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
144 VariableRankDescendingSorterDouble variableRankDescendingSorter;
145 sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
148 m->openOutputFile(filename, out);
149 out <<"OTU\tMean decrease accuracy\n";
150 for (int i = 0; i < globalVariableRanks.size(); i++) {
151 out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
156 catch(exception& e) {
157 m->errorOut(e, "RandomForest", "calcForrestVariableImportance");
/***********************************************************************/
162 int RandomForest::populateDecisionTrees() {
165 vector<double> errorRateImprovements;
167 for (int i = 0; i < numDecisionTrees; i++) {
169 if (m->control_pressed) { return 0; }
170 if (((i+1) % 100) == 0) { m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n"); }
172 // TODO: need to first fix if we are going to use pointer based system or anything else
173 DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold);
175 if (m->debug && doPruning) {
176 m->mothurOut("Before pruning\n");
177 decisionTree->printTree(decisionTree->rootNode, "ROOT");
181 double treeErrorRate;
183 decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
184 double prePrunedErrorRate = treeErrorRate;
187 m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
191 decisionTree->pruneTree(pruneAggressiveness);
193 m->mothurOut("After pruning\n");
194 decisionTree->printTree(decisionTree->rootNode, "ROOT");
196 decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
198 double postPrunedErrorRate = treeErrorRate;
201 decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate);
202 double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate;
205 m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
207 m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n");
212 if (discardHighErrorTrees) {
213 if (treeErrorRate < highErrorTreeDiscardThreshold) {
214 updateGlobalOutOfBagEstimates(decisionTree);
215 decisionTree->purgeDataSetsFromTree();
216 decisionTrees.push_back(decisionTree);
218 errorRateImprovements.push_back(errorRateImprovement);
224 updateGlobalOutOfBagEstimates(decisionTree);
225 decisionTree->purgeDataSetsFromTree();
226 decisionTrees.push_back(decisionTree);
228 errorRateImprovements.push_back(errorRateImprovement);
233 double avgErrorRateImprovement = -1.0;
234 if (errorRateImprovements.size() > 0) {
235 avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0);
236 // cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl;
237 avgErrorRateImprovement /= errorRateImprovements.size();
240 if (m->debug && doPruning) {
241 m->mothurOut("avgErrorRateImprovement:" + toString(avgErrorRateImprovement) + "\n");
243 // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
248 catch(exception& e) {
249 m->errorOut(e, "RandomForest", "populateDecisionTrees");
/***********************************************************************/
// TODO: need to finalize between reference and pointer for DecisionTree [partially solved]
// DONE: make this pure virtual in superclass
257 int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) {
259 for (map<int, int>::iterator it = decisionTree->outOfBagEstimates.begin(); it != decisionTree->outOfBagEstimates.end(); it++) {
261 if (m->control_pressed) { return 0; }
263 int indexOfSample = it->first;
264 int predictedOutcomeOfSample = it->second;
266 if (globalOutOfBagEstimates.count(indexOfSample) == 0) {
267 globalOutOfBagEstimates[indexOfSample] = vector<int>(decisionTree->numOutputClasses, 0);
270 globalOutOfBagEstimates[indexOfSample][predictedOutcomeOfSample] += 1;
274 catch(exception& e) {
275 m->errorOut(e, "RandomForest", "updateGlobalOutOfBagEstimates");
/***********************************************************************/