]> git.donarmstrong.com Git - mothur.git/blob - randomforest.cpp
added classify.shared command and random forest files. added count file to pcr.seqs...
[mothur.git] / randomforest.cpp
1 //
2 //  randomforest.cpp
3 //  Mothur
4 //
5 //  Created by Sarah Westcott on 10/2/12.
6 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
7 //
8
9 #include "randomforest.hpp" 
10
11 /***********************************************************************/
12
13 RandomForest::RandomForest(const vector <vector<int> > dataSet,const int numDecisionTrees,
14              const string treeSplitCriterion = "informationGain") : AbstractRandomForest(dataSet, numDecisionTrees, treeSplitCriterion) {
15     m = MothurOut::getInstance();
16 }
17
18 /***********************************************************************/
19 // DONE
20 int RandomForest::calcForrestErrorRate() {
21     try {
22         int numCorrect = 0;
23         for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
24             
25             if (m->control_pressed) { return 0; }
26             
27             int indexOfSample = it->first;
28             vector<int> predictedOutComes = it->second;
29             vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
30             int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
31             int realOutcome = dataSet[indexOfSample][numFeatures];
32             
33             if (majorityVotedOutcome == realOutcome) { numCorrect++; }
34         }
35         
36         // TODO: save or return forrestErrorRate for future use;
37         double forrestErrorRate = 1 - ((double)numCorrect / (double)globalOutOfBagEstimates.size());
38         
39         m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n");
40         m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n");
41     
42         return 0;
43     }
44         catch(exception& e) {
45                 m->errorOut(e, "RandomForest", "calcForrestErrorRate");
46                 exit(1);
47         } 
48 }
49
50 /***********************************************************************/
51 // DONE
52 int RandomForest::calcForrestVariableImportance(string filename) {
53     try {
54     
55     // TODO: need to add try/catch operators to fix this
56     // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
57         //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all?
58         //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree.
59     for (int i = 0; i < decisionTrees.size(); i++) {
60         if (m->control_pressed) { return 0; }
61         DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
62         
63         for (int j = 0; j < numFeatures; j++) {
64             globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
65         }
66     }
67     
68     for (int i = 0;  i < numFeatures; i++) {
69         cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
70         globalVariableImportanceList[i] /= (double)numDecisionTrees;
71     }
72     
73     vector< vector<double> > globalVariableRanks;
74     for (int i = 0; i < globalVariableImportanceList.size(); i++) {
75         if (globalVariableImportanceList[i] > 0) {
76             vector<double> globalVariableRank(2, 0);
77             globalVariableRank[0] = i; globalVariableRank[1] = globalVariableImportanceList[i];
78             globalVariableRanks.push_back(globalVariableRank);
79         }
80     }
81     
82     VariableRankDescendingSorterDouble variableRankDescendingSorter;
83     sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
84         ofstream out;
85         m->openOutputFile(filename, out);
86         out <<"OTU\tRank\n";
87         for (int i = 0; i < globalVariableRanks.size(); i++) {
88             out << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
89         }
90         out.close();
91         return 0;
92     }
93         catch(exception& e) {
94                 m->errorOut(e, "RandomForest", "calcForrestVariableImportance");
95                 exit(1);
96         }  
97 }
98 /***********************************************************************/
99 // DONE
100 int RandomForest::populateDecisionTrees() {
101     try {
102         
103         for (int i = 0; i < numDecisionTrees; i++) {
104             if (m->control_pressed) { return 0; }
105             if (((i+1) % 10) == 0) {  m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n");  }
106             // TODO: need to first fix if we are going to use pointer based system or anything else
107             DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector("log2"), treeSplitCriterion);
108             decisionTree->calcTreeVariableImportanceAndError();
109             if (m->control_pressed) { return 0; }
110             updateGlobalOutOfBagEstimates(decisionTree);
111             if (m->control_pressed) { return 0; }
112             decisionTree->purgeDataSetsFromTree();
113             if (m->control_pressed) { return 0; }
114             decisionTrees.push_back(decisionTree);
115         }
116         
117         if (m->debug) {
118             // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
119         }
120         
121         return 0;
122     }
123     catch(exception& e) {
124         m->errorOut(e, "RandomForest", "populateDecisionTrees");
125         exit(1);
126     }  
127 }
128 /***********************************************************************/
129 // TODO: need to finalize between reference and pointer for DecisionTree [partially solved]
130 // TODO: make this pure virtual in superclass
131 // DONE
132 int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) {
133     try {
134         for (map<int, int>::iterator it = decisionTree->outOfBagEstimates.begin(); it != decisionTree->outOfBagEstimates.end(); it++) {
135             
136             if (m->control_pressed) { return 0; }
137             
138             int indexOfSample = it->first;
139             int predictedOutcomeOfSample = it->second;
140             
141             if (globalOutOfBagEstimates.count(indexOfSample) == 0) {
142                 globalOutOfBagEstimates[indexOfSample] = vector<int>(decisionTree->numOutputClasses, 0);
143             };
144             
145             globalOutOfBagEstimates[indexOfSample][predictedOutcomeOfSample] += 1;
146         }
147         return 0;
148     }
149     catch(exception& e) {
150         m->errorOut(e, "RandomForest", "updateGlobalOutOfBagEstimates");
151         exit(1);
152     }  
153 }
154 /***********************************************************************/
155
156