]> git.donarmstrong.com Git - mothur.git/blob - abstractrandomforest.cpp
changing command name classify.shared to classifyrf.shared
[mothur.git] / abstractrandomforest.cpp
1 //
2 //  abstractrandomforest.cpp
3 //  Mothur
4 //
5 //  Created by Sarah Westcott on 10/1/12.
6 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
7 //
8
9 #include "abstractrandomforest.hpp"
10
11 /***********************************************************************/
12 AbstractRandomForest::AbstractRandomForest(const std::vector < std::vector<int> > dataSet, 
13                      const int numDecisionTrees, 
14                      const string treeSplitCriterion = "informationGain")
15 : dataSet(dataSet), 
16 numDecisionTrees(numDecisionTrees),
17 numSamples((int)dataSet.size()),
18 numFeatures((int)(dataSet[0].size() - 1)),
19 globalDiscardedFeatureIndices(getGlobalDiscardedFeatureIndices()),
20 globalVariableImportanceList(numFeatures, 0),
21 treeSplitCriterion(treeSplitCriterion) {
22     m = MothurOut::getInstance();
23     // TODO: double check if the implemenatation of 'globalOutOfBagEstimates' is correct
24 }
25
26 /***********************************************************************/
27
28 vector<int> AbstractRandomForest::getGlobalDiscardedFeatureIndices() {
29     try {
30         vector<int> globalDiscardedFeatureIndices;
31         
32         // calculate feature vectors
33         vector< vector<int> > featureVectors(numFeatures, vector<int>(numSamples, 0));
34         for (int i = 0; i < numSamples; i++) {
35             if (m->control_pressed) { return globalDiscardedFeatureIndices; }
36             for (int j = 0; j < numFeatures; j++) { featureVectors[j][i] = dataSet[i][j]; }
37         }
38         
39         for (int i = 0; i < featureVectors.size(); i++) {
40             if (m->control_pressed) { return globalDiscardedFeatureIndices; }
41             double standardDeviation = m->getStandardDeviation(featureVectors[i]);
42             if (standardDeviation <= 0){ globalDiscardedFeatureIndices.push_back(i); }
43         }
44         
45         if (m->debug) {
46             m->mothurOut("number of global discarded features:  " + toString(globalDiscardedFeatureIndices.size())+ "\n");
47             m->mothurOut("total features: " + toString(featureVectors.size())+ "\n");
48         }
49         
50         return globalDiscardedFeatureIndices;
51     }
52         catch(exception& e) {
53                 m->errorOut(e, "AbstractRandomForest", "getGlobalDiscardedFeatureIndices");
54                 exit(1);
55         } 
56 }
57
58 /***********************************************************************/
59