5 // Created by Sarah Westcott on 10/2/12.
6 // Copyright (c) 2012 Schloss Lab. All rights reserved.
9 #include "randomforest.hpp"
11 /***********************************************************************/
13 RandomForest::RandomForest(const vector <vector<int> > dataSet,
14 const int numDecisionTrees,
15 const string treeSplitCriterion = "gainratio",
16 const bool doPruning = false,
17 const float pruneAggressiveness = 0.9,
18 const bool discardHighErrorTrees = true,
19 const float highErrorTreeDiscardThreshold = 0.4,
20 const string optimumFeatureSubsetSelectionCriteria = "log2",
21 const float featureStandardDeviationThreshold = 0.0)
22 : Forest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold) {
23 m = MothurOut::getInstance();
26 /***********************************************************************/
28 int RandomForest::calcForrestErrorRate() {
31 for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
33 if (m->control_pressed) { return 0; }
35 int indexOfSample = it->first;
36 vector<int> predictedOutComes = it->second;
37 vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
38 int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
39 int realOutcome = dataSet[indexOfSample][numFeatures];
41 if (majorityVotedOutcome == realOutcome) { numCorrect++; }
44 // TODO: save or return forrestErrorRate for future use;
45 double forrestErrorRate = 1 - ((double)numCorrect / (double)globalOutOfBagEstimates.size());
47 m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n");
48 m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n");
53 m->errorOut(e, "RandomForest", "calcForrestErrorRate");
58 /***********************************************************************/
59 int RandomForest::calcForrestVariableImportance(string filename) {
62 // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
63 //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all?
64 //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree.
65 for (int i = 0; i < decisionTrees.size(); i++) {
66 if (m->control_pressed) { return 0; }
68 DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
70 for (int j = 0; j < numFeatures; j++) {
71 globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
75 for (int i = 0; i < numFeatures; i++) {
76 globalVariableImportanceList[i] /= (double)numDecisionTrees;
79 vector< pair<int, double> > globalVariableRanks;
80 for (int i = 0; i < globalVariableImportanceList.size(); i++) {
81 //cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
82 if (globalVariableImportanceList[i] > 0) {
83 pair<int, double> globalVariableRank(0, 0.0);
84 globalVariableRank.first = i;
85 globalVariableRank.second = globalVariableImportanceList[i];
86 globalVariableRanks.push_back(globalVariableRank);
90 // for (int i = 0; i < globalVariableRanks.size(); i++) {
91 // cout << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
95 VariableRankDescendingSorterDouble variableRankDescendingSorter;
96 sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
99 m->openOutputFile(filename, out);
101 for (int i = 0; i < globalVariableRanks.size(); i++) {
102 out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
107 catch(exception& e) {
108 m->errorOut(e, "RandomForest", "calcForrestVariableImportance");
112 /***********************************************************************/
113 int RandomForest::populateDecisionTrees() {
116 vector<double> errorRateImprovements;
118 for (int i = 0; i < numDecisionTrees; i++) {
120 if (m->control_pressed) { return 0; }
121 if (((i+1) % 100) == 0) { m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n"); }
123 // TODO: need to first fix if we are going to use pointer based system or anything else
124 DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold);
126 if (m->debug && doPruning) {
127 m->mothurOut("Before pruning\n");
128 decisionTree->printTree(decisionTree->rootNode, "ROOT");
132 double treeErrorRate;
134 decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
135 double prePrunedErrorRate = treeErrorRate;
138 m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
142 decisionTree->pruneTree(pruneAggressiveness);
144 m->mothurOut("After pruning\n");
145 decisionTree->printTree(decisionTree->rootNode, "ROOT");
147 decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
149 double postPrunedErrorRate = treeErrorRate;
152 decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate);
153 double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate;
156 m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
158 m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n");
163 if (discardHighErrorTrees) {
164 if (treeErrorRate < highErrorTreeDiscardThreshold) {
165 updateGlobalOutOfBagEstimates(decisionTree);
166 decisionTree->purgeDataSetsFromTree();
167 decisionTrees.push_back(decisionTree);
169 errorRateImprovements.push_back(errorRateImprovement);
175 updateGlobalOutOfBagEstimates(decisionTree);
176 decisionTree->purgeDataSetsFromTree();
177 decisionTrees.push_back(decisionTree);
179 errorRateImprovements.push_back(errorRateImprovement);
184 double avgErrorRateImprovement = -1.0;
185 if (errorRateImprovements.size() > 0) {
186 avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0);
187 // cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl;
188 avgErrorRateImprovement /= errorRateImprovements.size();
191 if (m->debug && doPruning) {
192 m->mothurOut("avgErrorRateImprovement:" + toString(avgErrorRateImprovement) + "\n");
194 // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
199 catch(exception& e) {
200 m->errorOut(e, "RandomForest", "populateDecisionTrees");
204 /***********************************************************************/
// TODO: need to finalize between reference and pointer for DecisionTree [partially solved]
206 // DONE: make this pure virtual in superclass
208 int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) {
210 for (map<int, int>::iterator it = decisionTree->outOfBagEstimates.begin(); it != decisionTree->outOfBagEstimates.end(); it++) {
212 if (m->control_pressed) { return 0; }
214 int indexOfSample = it->first;
215 int predictedOutcomeOfSample = it->second;
217 if (globalOutOfBagEstimates.count(indexOfSample) == 0) {
218 globalOutOfBagEstimates[indexOfSample] = vector<int>(decisionTree->numOutputClasses, 0);
221 globalOutOfBagEstimates[indexOfSample][predictedOutcomeOfSample] += 1;
225 catch(exception& e) {
226 m->errorOut(e, "RandomForest", "updateGlobalOutOfBagEstimates");
230 /***********************************************************************/