]> git.donarmstrong.com Git - mothur.git/blob - forest.cpp
working on pam
[mothur.git] / forest.cpp
1 //
2 //  forest.cpp
3 //  Mothur
4 //
5 //  Created by Kathryn Iverson on 10/26/12.
6 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
7 //
8
9 #include "forest.h"
10
11 /***********************************************************************/
12 Forest::Forest(const std::vector < std::vector<int> > dataSet,
13                const int numDecisionTrees,
14                const string treeSplitCriterion = "gainratio",
15                const bool doPruning = false,
16                const float pruneAggressiveness = 0.9,
17                const bool discardHighErrorTrees = true,
18                const float highErrorTreeDiscardThreshold = 0.4,
19                const string optimumFeatureSubsetSelectionCriteria = "log2",
20                const float featureStandardDeviationThreshold = 0.0)
21       : dataSet(dataSet),
22         numDecisionTrees(numDecisionTrees),
23         numSamples((int)dataSet.size()),
24         numFeatures((int)(dataSet[0].size() - 1)),
25         globalVariableImportanceList(numFeatures, 0),
26         treeSplitCriterion(treeSplitCriterion),
27         doPruning(doPruning),
28         pruneAggressiveness(pruneAggressiveness),
29         discardHighErrorTrees(discardHighErrorTrees),
30         highErrorTreeDiscardThreshold(highErrorTreeDiscardThreshold),
31         optimumFeatureSubsetSelectionCriteria(optimumFeatureSubsetSelectionCriteria),
32         featureStandardDeviationThreshold(featureStandardDeviationThreshold)
33         {
34         
35     m = MothurOut::getInstance();
36     globalDiscardedFeatureIndices = getGlobalDiscardedFeatureIndices();
37     // TODO: double check if the implemenatation of 'globalOutOfBagEstimates' is correct
38 }
39
40 /***********************************************************************/
41
42 vector<int> Forest::getGlobalDiscardedFeatureIndices() {
43     try {
44         //vector<int> globalDiscardedFeatureIndices;
45         //globalDiscardedFeatureIndices.push_back(1);
46         
47         // calculate feature vectors
48         vector< vector<int> > featureVectors(numFeatures, vector<int>(numSamples, 0) );
49         for (int i = 0; i < numSamples; i++) {
50             if (m->control_pressed) { return globalDiscardedFeatureIndices; }
51             for (int j = 0; j < numFeatures; j++) { featureVectors[j][i] = dataSet[i][j]; }
52         }
53         
54         for (int i = 0; i < featureVectors.size(); i++) {
55             if (m->control_pressed) { return globalDiscardedFeatureIndices; }
56             double standardDeviation = m->getStandardDeviation(featureVectors[i]);
57             if (standardDeviation <= featureStandardDeviationThreshold){ globalDiscardedFeatureIndices.push_back(i); }
58         }
59         
60         if (m->debug) {
61             m->mothurOut("number of global discarded features:  " + toString(globalDiscardedFeatureIndices.size())+ "\n");
62             m->mothurOut("total features: " + toString(featureVectors.size())+ "\n");
63         }
64         
65         return globalDiscardedFeatureIndices;
66     }
67         catch(exception& e) {
68                 m->errorOut(e, "Forest", "getGlobalDiscardedFeatureIndices");
69                 exit(1);
70         }
71 }
72
73 /***********************************************************************/
74