5 // Created by Sarah Westcott on 10/2/12.
6 // Copyright (c) 2012 Schloss Lab. All rights reserved.
#include "randomforest.hpp"

#include <cstdlib>
11 /***********************************************************************/
13 RandomForest::RandomForest(const vector <vector<int> > dataSet,const int numDecisionTrees,
14 const string treeSplitCriterion = "informationGain") : Forest(dataSet, numDecisionTrees, treeSplitCriterion) {
15 m = MothurOut::getInstance();
18 /***********************************************************************/
20 int RandomForest::calcForrestErrorRate() {
23 for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
25 if (m->control_pressed) { return 0; }
27 int indexOfSample = it->first;
28 vector<int> predictedOutComes = it->second;
29 vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
30 int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
31 int realOutcome = dataSet[indexOfSample][numFeatures];
33 if (majorityVotedOutcome == realOutcome) { numCorrect++; }
36 // TODO: save or return forrestErrorRate for future use;
37 double forrestErrorRate = 1 - ((double)numCorrect / (double)globalOutOfBagEstimates.size());
39 m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n");
40 m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n");
45 m->errorOut(e, "RandomForest", "calcForrestErrorRate");
50 /***********************************************************************/
52 int RandomForest::calcForrestVariableImportance(string filename) {
55 // TODO: need to add try/catch operators to fix this
56 // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
57 //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all?
58 //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree.
59 for (int i = 0; i < decisionTrees.size(); i++) {
60 if (m->control_pressed) { return 0; }
62 DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
64 for (int j = 0; j < numFeatures; j++) {
65 globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
69 for (int i = 0; i < numFeatures; i++) {
70 cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
71 globalVariableImportanceList[i] /= (double)numDecisionTrees;
74 vector< vector<double> > globalVariableRanks;
75 for (int i = 0; i < globalVariableImportanceList.size(); i++) {
76 if (globalVariableImportanceList[i] > 0) {
77 vector<double> globalVariableRank(2, 0);
78 globalVariableRank[0] = i; globalVariableRank[1] = globalVariableImportanceList[i];
79 globalVariableRanks.push_back(globalVariableRank);
83 VariableRankDescendingSorterDouble variableRankDescendingSorter;
84 sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
86 m->openOutputFile(filename, out);
88 for (int i = 0; i < globalVariableRanks.size(); i++) {
89 out << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
95 m->errorOut(e, "RandomForest", "calcForrestVariableImportance");
99 /***********************************************************************/
101 int RandomForest::populateDecisionTrees() {
104 for (int i = 0; i < numDecisionTrees; i++) {
105 if (m->control_pressed) { return 0; }
106 if (((i+1) % 10) == 0) { m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n"); }
107 // TODO: need to first fix if we are going to use pointer based system or anything else
108 DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector("log2"), treeSplitCriterion);
109 decisionTree->calcTreeVariableImportanceAndError();
110 if (m->control_pressed) { return 0; }
111 updateGlobalOutOfBagEstimates(decisionTree);
112 if (m->control_pressed) { return 0; }
113 decisionTree->purgeDataSetsFromTree();
114 if (m->control_pressed) { return 0; }
115 decisionTrees.push_back(decisionTree);
119 // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
124 catch(exception& e) {
125 m->errorOut(e, "RandomForest", "populateDecisionTrees");
129 /***********************************************************************/
130 // TODO: need to finalize bettween reference and pointer for DecisionTree [partially solved]
131 // DONE: make this pure virtual in superclass
133 int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) {
135 for (map<int, int>::iterator it = decisionTree->outOfBagEstimates.begin(); it != decisionTree->outOfBagEstimates.end(); it++) {
137 if (m->control_pressed) { return 0; }
139 int indexOfSample = it->first;
140 int predictedOutcomeOfSample = it->second;
142 if (globalOutOfBagEstimates.count(indexOfSample) == 0) {
143 globalOutOfBagEstimates[indexOfSample] = vector<int>(decisionTree->numOutputClasses, 0);
146 globalOutOfBagEstimates[indexOfSample][predictedOutcomeOfSample] += 1;
150 catch(exception& e) {
151 m->errorOut(e, "RandomForest", "updateGlobalOutOfBagEstimates");
155 /***********************************************************************/