]> git.donarmstrong.com Git - mothur.git/blobdiff - forest.cpp
changes while testing
[mothur.git] / forest.cpp
index 8ac1b79ca12c540ccb8c38498cc62b0201bc2f41..3cfb5b9054cf5d43fbbb35eb7b81a6a2ba0f5833 100644 (file)
 
 /***********************************************************************/
 Forest::Forest(const std::vector < std::vector<int> > dataSet,
-                                           const int numDecisionTrees,
-                                           const string treeSplitCriterion = "informationGain")
-: dataSet(dataSet),
-numDecisionTrees(numDecisionTrees),
-numSamples((int)dataSet.size()),
-numFeatures((int)(dataSet[0].size() - 1)),
-globalDiscardedFeatureIndices(getGlobalDiscardedFeatureIndices()),
-globalVariableImportanceList(numFeatures, 0),
-treeSplitCriterion(treeSplitCriterion) {
+               const int numDecisionTrees,
+               const string treeSplitCriterion = "gainratio",
+               const bool doPruning = false,
+               const float pruneAggressiveness = 0.9,
+               const bool discardHighErrorTrees = true,
+               const float highErrorTreeDiscardThreshold = 0.4,
+               const string optimumFeatureSubsetSelectionCriteria = "log2",
+               const float featureStandardDeviationThreshold = 0.0)
+      : dataSet(dataSet),
+        numDecisionTrees(numDecisionTrees),
+        numSamples((int)dataSet.size()),
+        numFeatures((int)(dataSet[0].size() - 1)),
+        globalVariableImportanceList(numFeatures, 0),
+        treeSplitCriterion(treeSplitCriterion),
+        doPruning(doPruning),
+        pruneAggressiveness(pruneAggressiveness),
+        discardHighErrorTrees(discardHighErrorTrees),
+        highErrorTreeDiscardThreshold(highErrorTreeDiscardThreshold),
+        optimumFeatureSubsetSelectionCriteria(optimumFeatureSubsetSelectionCriteria),
+        featureStandardDeviationThreshold(featureStandardDeviationThreshold)
+        {
+        
     m = MothurOut::getInstance();
+    globalDiscardedFeatureIndices = getGlobalDiscardedFeatureIndices();
     // TODO: double check if the implemenatation of 'globalOutOfBagEstimates' is correct
 }
 
@@ -27,10 +41,11 @@ treeSplitCriterion(treeSplitCriterion) {
 
 vector<int> Forest::getGlobalDiscardedFeatureIndices() {
     try {
-        vector<int> globalDiscardedFeatureIndices;
+        //vector<int> globalDiscardedFeatureIndices;
+        //globalDiscardedFeatureIndices.push_back(1);
         
         // calculate feature vectors
-        vector< vector<int> > featureVectors(numFeatures, vector<int>(numSamples, 0));
+        vector< vector<int> > featureVectors(numFeatures, vector<int>(numSamples, 0) );
         for (int i = 0; i < numSamples; i++) {
             if (m->control_pressed) { return globalDiscardedFeatureIndices; }
             for (int j = 0; j < numFeatures; j++) { featureVectors[j][i] = dataSet[i][j]; }
@@ -39,7 +54,7 @@ vector<int> Forest::getGlobalDiscardedFeatureIndices() {
         for (int i = 0; i < featureVectors.size(); i++) {
             if (m->control_pressed) { return globalDiscardedFeatureIndices; }
             double standardDeviation = m->getStandardDeviation(featureVectors[i]);
-            if (standardDeviation <= 0){ globalDiscardedFeatureIndices.push_back(i); }
+            if (standardDeviation <= featureStandardDeviationThreshold){ globalDiscardedFeatureIndices.push_back(i); }
         }
         
         if (m->debug) {