5 // Created by Kathryn Iverson on 10/26/12.
6 // Copyright (c) 2012 Schloss Lab. All rights reserved.
11 /***********************************************************************/
// Forest constructor: stores the training matrix and forest options, then computes the
// feature indices that are discarded globally (before any tree is built).
//   dataSet - sample-major matrix: dataSet[i][j] is feature j of sample i. The last column
//             is excluded from numFeatures. NOTE(review): presumably it holds the class
//             label — confirm against callers.
//   the remaining parameters are forest options; each visible one initializes the
//   same-named member in the initializer list below.
// NOTE(review): dataSet[0] is read unconditionally, so an empty dataSet is undefined behavior.
//   A zero-width first row makes size() - 1 wrap around before the (int) cast.
// NOTE(review): dataSet is passed by const value, so the whole matrix is copied into the
//   parameter. const& would avoid the copy, but the header declaration must change with it.
// NOTE(review): default arguments are given on this out-of-class definition. They are only
//   usable by code that sees this definition, and they are ill-formed if the header
//   declaration repeats them — confirm.
12 Forest::Forest(const std::vector < std::vector<int> > dataSet,
13 const int numDecisionTrees,
14 const string treeSplitCriterion = "gainratio",
15 const bool doPruning = false,
16 const float pruneAggressiveness = 0.9,
17 const bool discardHighErrorTrees = true,
18 const float highErrorTreeDiscardThreshold = 0.4,
19 const string optimumFeatureSubsetSelectionCriteria = "log2",
20 const float featureStandardDeviationThreshold = 0.0)
22 numDecisionTrees(numDecisionTrees),
23 numSamples((int)dataSet.size()),
24 numFeatures((int)(dataSet[0].size() - 1)),
// NOTE(review): this reads the *member* numFeatures, not a parameter. Members are initialized
// in header-declaration order, so this is only correct if numFeatures is declared before
// globalVariableImportanceList in the class.
25 globalVariableImportanceList(numFeatures, 0),
26 treeSplitCriterion(treeSplitCriterion),
28 pruneAggressiveness(pruneAggressiveness),
29 discardHighErrorTrees(discardHighErrorTrees),
30 highErrorTreeDiscardThreshold(highErrorTreeDiscardThreshold),
31 optimumFeatureSubsetSelectionCriteria(optimumFeatureSubsetSelectionCriteria),
32 featureStandardDeviationThreshold(featureStandardDeviationThreshold)
// m must be assigned before the call below, which reads m->control_pressed and logs through m.
35 m = MothurOut::getInstance();
// The callee appends into globalDiscardedFeatureIndices itself and returns a copy, which is
// assigned back to that same member here.
36 globalDiscardedFeatureIndices = getGlobalDiscardedFeatureIndices();
// TODO: double check if the implementation of 'globalOutOfBagEstimates' is correct
40 /***********************************************************************/
// Returns the indices of features whose standard deviation across all samples is
// <= featureStandardDeviationThreshold (low-variance features).
// Side effects:
//   - appends directly to the member globalDiscardedFeatureIndices (the local declaration
//     below is commented out);
//   - logs summary counts through m->mothurOut.
// NOTE(review): because it appends to the member, calling this more than once repeats indices.
// NOTE(review): if m->control_pressed is set, it returns early with a partially filled list.
42 vector<int> Forest::getGlobalDiscardedFeatureIndices() {
44 //vector<int> globalDiscardedFeatureIndices;
45 //globalDiscardedFeatureIndices.push_back(1);
47 // calculate feature vectors
// Transpose the sample-major dataSet into one vector per feature: featureVectors[j][i] = dataSet[i][j].
// Column numFeatures (the last column of each row) is not copied.
48 vector< vector<int> > featureVectors(numFeatures, vector<int>(numSamples, 0) );
49 for (int i = 0; i < numSamples; i++) {
50 if (m->control_pressed) { return globalDiscardedFeatureIndices; }
51 for (int j = 0; j < numFeatures; j++) { featureVectors[j][i] = dataSet[i][j]; }
// Mark every feature whose spread is at or below the threshold.
// NOTE(review): int i is compared against size_t size() (signed/unsigned comparison).
54 for (int i = 0; i < featureVectors.size(); i++) {
55 if (m->control_pressed) { return globalDiscardedFeatureIndices; }
56 double standardDeviation = m->getStandardDeviation(featureVectors[i]);
57 if (standardDeviation <= featureStandardDeviationThreshold){ globalDiscardedFeatureIndices.push_back(i); }
// Report how many features were discarded out of the total.
61 m->mothurOut("number of global discarded features: " + toString(globalDiscardedFeatureIndices.size())+ "\n");
62 m->mothurOut("total features: " + toString(featureVectors.size())+ "\n");
65 return globalDiscardedFeatureIndices;
// Report exception e through MothurOut, tagged with the class and method names.
68 m->errorOut(e, "Forest", "getGlobalDiscardedFeatureIndices");
73 /***********************************************************************/