]> git.donarmstrong.com Git - mothur.git/blobdiff - randomforest.cpp
changed random forest output filename
[mothur.git] / randomforest.cpp
index 36a2c1a261f27514c394a370d0d97387ac9cbfac..2ae0eb595c1c8d5f989c44a5e561a0ff3dac2d2b 100644 (file)
 
 /***********************************************************************/
 
-RandomForest::RandomForest(const vector <vector<int> > dataSet,const int numDecisionTrees,
-             const string treeSplitCriterion = "informationGain") : AbstractRandomForest(dataSet, numDecisionTrees, treeSplitCriterion) {
+RandomForest::RandomForest(const vector <vector<int> > dataSet,
+                           const int numDecisionTrees,
+                           const string treeSplitCriterion = "gainratio",
+                           const bool doPruning = false,
+                           const float pruneAggressiveness = 0.9,
+                           const bool discardHighErrorTrees = true,
+                           const float highErrorTreeDiscardThreshold = 0.4,
+                           const string optimumFeatureSubsetSelectionCriteria = "log2",
+                           const float featureStandardDeviationThreshold = 0.0)
+            : Forest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold) {
     m = MothurOut::getInstance();
 }
 
@@ -48,44 +56,50 @@ int RandomForest::calcForrestErrorRate() {
 }
 
 /***********************************************************************/
-// DONE
 int RandomForest::calcForrestVariableImportance(string filename) {
     try {
     
-    // TODO: need to add try/catch operators to fix this
-    // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
+        // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
         //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all?
         //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree.
-    for (int i = 0; i < decisionTrees.size(); i++) {
-        if (m->control_pressed) { return 0; }
-        DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
+        for (int i = 0; i < decisionTrees.size(); i++) {
+            if (m->control_pressed) { return 0; }
+            
+            DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
+            
+            for (int j = 0; j < numFeatures; j++) {
+                globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
+            }
+        }
         
-        for (int j = 0; j < numFeatures; j++) {
-            globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
+        for (int i = 0;  i < numFeatures; i++) {
+            globalVariableImportanceList[i] /= (double)numDecisionTrees;
         }
-    }
-    
-    for (int i = 0;  i < numFeatures; i++) {
-        cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
-        globalVariableImportanceList[i] /= (double)numDecisionTrees;
-    }
-    
-    vector< vector<double> > globalVariableRanks;
-    for (int i = 0; i < globalVariableImportanceList.size(); i++) {
-        if (globalVariableImportanceList[i] > 0) {
-            vector<double> globalVariableRank(2, 0);
-            globalVariableRank[0] = i; globalVariableRank[1] = globalVariableImportanceList[i];
-            globalVariableRanks.push_back(globalVariableRank);
+        
+        vector< pair<int, double> > globalVariableRanks;
+        for (int i = 0; i < globalVariableImportanceList.size(); i++) {
+            //cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
+            if (globalVariableImportanceList[i] > 0) {
+                pair<int, double> globalVariableRank(0, 0.0);
+                globalVariableRank.first = i;
+                globalVariableRank.second = globalVariableImportanceList[i];
+                globalVariableRanks.push_back(globalVariableRank);
+            }
         }
-    }
-    
-    VariableRankDescendingSorterDouble variableRankDescendingSorter;
-    sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
+        
+//        for (int i = 0; i < globalVariableRanks.size(); i++) {
+//            cout << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
+//        }
+
+        
+        VariableRankDescendingSorterDouble variableRankDescendingSorter;
+        sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
+        
         ofstream out;
         m->openOutputFile(filename, out);
         out <<"OTU\tRank\n";
         for (int i = 0; i < globalVariableRanks.size(); i++) {
-            out << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
+            out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
         }
         out.close();
         return 0;
@@ -96,27 +110,89 @@ int RandomForest::calcForrestVariableImportance(string filename) {
        }  
 }
 /***********************************************************************/
-// DONE
 int RandomForest::populateDecisionTrees() {
     try {
         
+        vector<double> errorRateImprovements;
+        
         for (int i = 0; i < numDecisionTrees; i++) {
+          
             if (m->control_pressed) { return 0; }
-            if (((i+1) % 10) == 0) {  m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n");  }
+            if (((i+1) % 100) == 0) {  m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n");  }
+          
             // TODO: need to first fix if we are going to use pointer based system or anything else
-            DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector("log2"), treeSplitCriterion);
-            decisionTree->calcTreeVariableImportanceAndError();
-            if (m->control_pressed) { return 0; }
-            updateGlobalOutOfBagEstimates(decisionTree);
-            if (m->control_pressed) { return 0; }
-            decisionTree->purgeDataSetsFromTree();
-            if (m->control_pressed) { return 0; }
-            decisionTrees.push_back(decisionTree);
+            DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold);
+          
+            if (m->debug && doPruning) {
+                m->mothurOut("Before pruning\n");
+                decisionTree->printTree(decisionTree->rootNode, "ROOT");
+            }
+            
+            int numCorrect;
+            double treeErrorRate;
+            
+            decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
+            double prePrunedErrorRate = treeErrorRate;
+            
+            if (m->debug) {
+                m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
+            }
+            
+            if (doPruning) {
+                decisionTree->pruneTree(pruneAggressiveness);
+                if (m->debug) {
+                    m->mothurOut("After pruning\n");
+                    decisionTree->printTree(decisionTree->rootNode, "ROOT");
+                }
+                decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
+            }
+            double postPrunedErrorRate = treeErrorRate;
+            
+          
+            decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate);
+            double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate;
+
+            if (m->debug) {
+                m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
+                if (doPruning) {
+                    m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n");
+                }
+            }
+            
+            
+            if (discardHighErrorTrees) {
+                if (treeErrorRate < highErrorTreeDiscardThreshold) {
+                    updateGlobalOutOfBagEstimates(decisionTree);
+                    decisionTree->purgeDataSetsFromTree();
+                    decisionTrees.push_back(decisionTree);
+                    if (doPruning) {
+                        errorRateImprovements.push_back(errorRateImprovement);
+                    }
+                } else {
+                    delete decisionTree;
+                }
+            } else {
+                updateGlobalOutOfBagEstimates(decisionTree);
+                decisionTree->purgeDataSetsFromTree();
+                decisionTrees.push_back(decisionTree);
+                if (doPruning) {
+                    errorRateImprovements.push_back(errorRateImprovement);
+                }
+            }          
+        }
+        
+        double avgErrorRateImprovement = -1.0;
+        if (errorRateImprovements.size() > 0) {
+            avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0);
+//            cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl;
+            avgErrorRateImprovement /= errorRateImprovements.size();
         }
         
-        if (m->debug) {
-            // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
+        if (m->debug && doPruning) {
+            m->mothurOut("avgErrorRateImprovement:" + toString(avgErrorRateImprovement) + "\n");
         }
+        // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
+
         
         return 0;
     }
@@ -127,7 +203,7 @@ int RandomForest::populateDecisionTrees() {
 }
 /***********************************************************************/
 // TODO: need to finalize bettween reference and pointer for DecisionTree [partially solved]
-// TODO: make this pure virtual in superclass
+// DONE: make this pure virtual in superclass
 // DONE
 int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) {
     try {