]> git.donarmstrong.com Git - mothur.git/commitdiff
Merge remote-tracking branch 'origin'
authorSarah Westcott <mothur.westcott@gmail.com>
Fri, 24 May 2013 14:16:22 +0000 (10:16 -0400)
committerSarah Westcott <mothur.westcott@gmail.com>
Fri, 24 May 2013 14:16:22 +0000 (10:16 -0400)
16 files changed:
.gitignore
Mothur.xcodeproj/project.pbxproj
abstractdecisiontree.cpp
abstractdecisiontree.hpp
classifysharedcommand.cpp
classifysharedcommand.h
decisiontree.cpp
decisiontree.hpp
forest.cpp
forest.h
macros.h
randomforest.cpp
randomforest.hpp
regularizedrandomforest.cpp
rftreenode.cpp
rftreenode.hpp

index 853fd84a4e280293f81727c25057644c758cd2c0..fb4ae5123e7cdc9f64ab449b496415f8edd27733 100644 (file)
@@ -1,4 +1,9 @@
-.DS_Store
-*.zip
+*.logfile
+*.o
 *.pbxproj
-*.xcuserdata
\ No newline at end of file
+*.xcuserdata
+*.zip
+.DS_Store
+.idea
+build
+xcuserdata
index d7474058b9a98dc706b1af88db7303d84dc0a074..0ac7ba85a278ffde691f78cf49f7101dd5aa51b1 100644 (file)
@@ -32,7 +32,6 @@
                A727864412E9E28C00F86ABA /* removerarecommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A727864312E9E28C00F86ABA /* removerarecommand.cpp */; };
                A7386C231619CCE600651424 /* classifysharedcommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A7386C211619CCE600651424 /* classifysharedcommand.cpp */; };
                A7386C251619E52300651424 /* abstractdecisiontree.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A7386C241619E52200651424 /* abstractdecisiontree.cpp */; };
-               A7386C27161A0F9D00651424 /* abstractrandomforest.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A7386C26161A0F9C00651424 /* abstractrandomforest.cpp */; };
                A7386C29161A110800651424 /* decisiontree.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A7386C28161A110700651424 /* decisiontree.cpp */; };
                A73901081588C40900ED2ED6 /* loadlogfilecommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A73901071588C40900ED2ED6 /* loadlogfilecommand.cpp */; };
                A73DDBBA13C4A0D1006AAE38 /* clearmemorycommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A73DDBB913C4A0D1006AAE38 /* clearmemorycommand.cpp */; };
                A727864212E9E28C00F86ABA /* removerarecommand.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = removerarecommand.h; sourceTree = "<group>"; };
                A727864312E9E28C00F86ABA /* removerarecommand.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = removerarecommand.cpp; sourceTree = "<group>"; };
                A7386C1B1619CACB00651424 /* abstractdecisiontree.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = abstractdecisiontree.hpp; sourceTree = "<group>"; };
-               A7386C1C1619CACB00651424 /* abstractrandomforest.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = abstractrandomforest.hpp; sourceTree = "<group>"; };
                A7386C1D1619CACB00651424 /* decisiontree.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = decisiontree.hpp; sourceTree = "<group>"; };
                A7386C1E1619CACB00651424 /* macros.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = macros.h; sourceTree = "<group>"; };
                A7386C1F1619CACB00651424 /* randomforest.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = randomforest.hpp; sourceTree = "<group>"; };
                A7386C211619CCE600651424 /* classifysharedcommand.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = classifysharedcommand.cpp; sourceTree = "<group>"; };
                A7386C221619CCE600651424 /* classifysharedcommand.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = classifysharedcommand.h; sourceTree = "<group>"; };
                A7386C241619E52200651424 /* abstractdecisiontree.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = abstractdecisiontree.cpp; sourceTree = "<group>"; };
-               A7386C26161A0F9C00651424 /* abstractrandomforest.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = abstractrandomforest.cpp; sourceTree = "<group>"; };
                A7386C28161A110700651424 /* decisiontree.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = decisiontree.cpp; sourceTree = "<group>"; };
                A73901051588C3EF00ED2ED6 /* loadlogfilecommand.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = loadlogfilecommand.h; sourceTree = "<group>"; };
                A73901071588C40900ED2ED6 /* loadlogfilecommand.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = loadlogfilecommand.cpp; sourceTree = "<group>"; };
                        children = (
                                A7386C1B1619CACB00651424 /* abstractdecisiontree.hpp */,
                                A7386C241619E52200651424 /* abstractdecisiontree.cpp */,
-                               A7386C1C1619CACB00651424 /* abstractrandomforest.hpp */,
-                               A7386C26161A0F9C00651424 /* abstractrandomforest.cpp */,
                                A7386C1D1619CACB00651424 /* decisiontree.hpp */,
                                A7386C28161A110700651424 /* decisiontree.cpp */,
                                A7386C1E1619CACB00651424 /* macros.h */,
                                A7C7DAB915DA758B0059B0CF /* sffmultiplecommand.cpp in Sources */,
                                A7386C231619CCE600651424 /* classifysharedcommand.cpp in Sources */,
                                A7386C251619E52300651424 /* abstractdecisiontree.cpp in Sources */,
-                               A7386C27161A0F9D00651424 /* abstractrandomforest.cpp in Sources */,
                                A7386C29161A110800651424 /* decisiontree.cpp in Sources */,
                                A77E1938161B201E00DB1A2A /* randomforest.cpp in Sources */,
                                A77E193B161B289600DB1A2A /* rftreenode.cpp in Sources */,
                                GCC_ENABLE_SSE3_EXTENSIONS = NO;
                                GCC_ENABLE_SSE41_EXTENSIONS = NO;
                                GCC_ENABLE_SSE42_EXTENSIONS = NO;
-                               GCC_OPTIMIZATION_LEVEL = 3;
+                               GCC_OPTIMIZATION_LEVEL = s;
                                GCC_PREPROCESSOR_DEFINITIONS = (
                                        "MOTHUR_FILES=\"\\\"../../release\\\"\"",
                                        "VERSION=\"\\\"1.29.2\\\"\"",
                                GCC_C_LANGUAGE_STANDARD = gnu99;
                                GCC_GENERATE_DEBUGGING_SYMBOLS = NO;
                                GCC_MODEL_TUNING = "";
-                               GCC_OPTIMIZATION_LEVEL = 3;
+                               GCC_OPTIMIZATION_LEVEL = s;
                                GCC_PREPROCESSOR_DEFINITIONS = (
                                        "VERSION=\"\\\"1.30.0\\\"\"",
                                        "RELEASE_DATE=\"\\\"4/01/2013\\\"\"",
index 085cd316225df92402d0a9fe86c623b45acdd553..595f483dc5ba6cde01b825f8ac44d9ceea4ece16 100644 (file)
 
 /**************************************************************************************************/
 
-AbstractDecisionTree::AbstractDecisionTree(vector<vector<int> >baseDataSet, 
-                     vector<int> globalDiscardedFeatureIndices, 
-                     OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, 
-                     string treeSplitCriterion) : baseDataSet(baseDataSet),
-numSamples((int)baseDataSet.size()),
-numFeatures((int)(baseDataSet[0].size() - 1)),
-numOutputClasses(0),
-rootNode(NULL),
-globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
-optimumFeatureSubsetSize(optimumFeatureSubsetSelector.getOptimumFeatureSubsetSize(numFeatures)),
-treeSplitCriterion(treeSplitCriterion) {
+AbstractDecisionTree::AbstractDecisionTree(vector<vector<int> >& baseDataSet,
+                                         vector<int> globalDiscardedFeatureIndices,
+                                         OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, 
+                                         string treeSplitCriterion)
+
+                    : baseDataSet(baseDataSet),
+                    numSamples((int)baseDataSet.size()),
+                    numFeatures((int)(baseDataSet[0].size() - 1)),
+                    numOutputClasses(0),
+                    rootNode(NULL),
+                    nodeIdCount(0),
+                    globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
+                    optimumFeatureSubsetSize(optimumFeatureSubsetSelector.getOptimumFeatureSubsetSize(numFeatures)),
+                    treeSplitCriterion(treeSplitCriterion) {
 
     try {
-    // TODO: istead of calculating this for every DecisionTree
-    // clacualte this once in the RandomForest class and pass the values
-    m = MothurOut::getInstance();
-    for (int i = 0;  i < numSamples; i++) {
-        if (m->control_pressed) { break; }
-        int outcome = baseDataSet[i][numFeatures];
-        vector<int>::iterator it = find(outputClasses.begin(), outputClasses.end(), outcome);
-        if (it == outputClasses.end()){       // find() will return classes.end() if the element is not found
-            outputClasses.push_back(outcome);
-            numOutputClasses++;
+        // TODO: istead of calculating this for every DecisionTree
+        // clacualte this once in the RandomForest class and pass the values
+        m = MothurOut::getInstance();
+        for (int i = 0;  i < numSamples; i++) {
+            if (m->control_pressed) { break; }
+            int outcome = baseDataSet[i][numFeatures];
+            vector<int>::iterator it = find(outputClasses.begin(), outputClasses.end(), outcome);
+            if (it == outputClasses.end()){       // find() will return classes.end() if the element is not found
+                outputClasses.push_back(outcome);
+                numOutputClasses++;
+            }
+        }
+        
+        if (m->debug) {
+            //m->mothurOut("outputClasses = " + toStringVectorInt(outputClasses));
+            m->mothurOut("numOutputClasses = " + toString(numOutputClasses) + '\n');
         }
-    }
-    
-    if (m->debug) {
-        //m->mothurOut("outputClasses = " + toStringVectorInt(outputClasses));
-        m->mothurOut("numOutputClasses = " + toString(numOutputClasses) + '\n');
-    }
 
     }
        catch(exception& e) {
@@ -50,25 +53,38 @@ treeSplitCriterion(treeSplitCriterion) {
 /**************************************************************************************************/
 int AbstractDecisionTree::createBootStrappedSamples(){
     try {    
-    vector<bool> isInTrainingSamples(numSamples, false);
-    
-    for (int i = 0; i < numSamples; i++) {
-        if (m->control_pressed) { return 0; }
-        // TODO: optimize the rand() function call + double check if it's working properly
-        int randomIndex = rand() % numSamples;
-        bootstrappedTrainingSamples.push_back(baseDataSet[randomIndex]);
-        isInTrainingSamples[randomIndex] = true;
-    }
-    
-    for (int i = 0; i < numSamples; i++) {
-        if (m->control_pressed) { return 0; }
-        if (isInTrainingSamples[i]){ bootstrappedTrainingSampleIndices.push_back(i); }
-        else{
-            bootstrappedTestSamples.push_back(baseDataSet[i]);
-            bootstrappedTestSampleIndices.push_back(i);
+        vector<bool> isInTrainingSamples(numSamples, false);
+        
+        for (int i = 0; i < numSamples; i++) {
+            if (m->control_pressed) { return 0; }
+            // TODO: optimize the rand() function call + double check if it's working properly
+            int randomIndex = rand() % numSamples;
+            bootstrappedTrainingSamples.push_back(baseDataSet[randomIndex]);
+            isInTrainingSamples[randomIndex] = true;
         }
-    }
-    
+        
+        for (int i = 0; i < numSamples; i++) {
+            if (m->control_pressed) { return 0; }
+            if (isInTrainingSamples[i]){ bootstrappedTrainingSampleIndices.push_back(i); }
+            else{
+                bootstrappedTestSamples.push_back(baseDataSet[i]);
+                bootstrappedTestSampleIndices.push_back(i);
+            }
+        }
+        
+            // do the transpose of Test Samples
+        for (int i = 0; i < bootstrappedTestSamples[0].size(); i++) {
+            if (m->control_pressed) { return 0; }
+            
+            vector<int> tmpFeatureVector(bootstrappedTestSamples.size(), 0);
+            for (int j = 0; j < bootstrappedTestSamples.size(); j++) {
+                if (m->control_pressed) { return 0; }
+                
+                tmpFeatureVector[j] = bootstrappedTestSamples[j][i];
+            }
+            testSampleFeatureVectors.push_back(tmpFeatureVector);
+        }
+        
         return 0;
     }
        catch(exception& e) {
@@ -77,25 +93,34 @@ int AbstractDecisionTree::createBootStrappedSamples(){
        } 
 }
 /**************************************************************************************************/
-int AbstractDecisionTree::getMinEntropyOfFeature(vector<int> featureVector, vector<int> outputVector, double& minEntropy, int& featureSplitValue, double& intrinsicValue){
+int AbstractDecisionTree::getMinEntropyOfFeature(vector<int> featureVector,
+                                                 vector<int> outputVector,
+                                                 double& minEntropy,
+                                                 int& featureSplitValue,
+                                                 double& intrinsicValue){
     try {
 
-        vector< vector<int> > featureOutputPair(featureVector.size(), vector<int>(2, 0));
+        vector< pair<int, int> > featureOutputPair(featureVector.size(), pair<int, int>(0, 0));
+        
         for (int i = 0; i < featureVector.size(); i++) { 
             if (m->control_pressed) { return 0; }
-            featureOutputPair[i][0] = featureVector[i];
-            featureOutputPair[i][1] = outputVector[i];
+            
+            featureOutputPair[i].first = featureVector[i];
+            featureOutputPair[i].second = outputVector[i];
         }
-        // TODO: using default behavior to sort(), need to specify the comparator for added safety and compiler portability
-        sort(featureOutputPair.begin(), featureOutputPair.end());
+        // TODO: using default behavior to sort(), need to specify the comparator for added safety and compiler portability,
         
+        IntPairVectorSorter intPairVectorSorter;
+        sort(featureOutputPair.begin(), featureOutputPair.end(), intPairVectorSorter);
         
         vector<int> splitPoints;
-        vector<int> uniqueFeatureValues(1, featureOutputPair[0][0]);
+        vector<int> uniqueFeatureValues(1, featureOutputPair[0].first);
         
         for (int i = 0; i < featureOutputPair.size(); i++) {
+
             if (m->control_pressed) { return 0; }
-            int featureValue = featureOutputPair[i][0];
+            int featureValue = featureOutputPair[i].first;
+
             vector<int>::iterator it = find(uniqueFeatureValues.begin(), uniqueFeatureValues.end(), featureValue);
             if (it == uniqueFeatureValues.end()){                 // NOT FOUND
                 uniqueFeatureValues.push_back(featureValue);
@@ -115,7 +140,7 @@ int AbstractDecisionTree::getMinEntropyOfFeature(vector<int> featureVector, vect
             featureSplitValue = -1;                                                   // OUTPUT
         }else{
             getBestSplitAndMinEntropy(featureOutputPair, splitPoints, minEntropy, bestSplitIndex, intrinsicValue);  // OUTPUT
-            featureSplitValue = featureOutputPair[splitPoints[bestSplitIndex]][0];    // OUTPUT
+            featureSplitValue = featureOutputPair[splitPoints[bestSplitIndex]].first;    // OUTPUT
         }
         
         return 0;
@@ -146,8 +171,9 @@ double AbstractDecisionTree::calcIntrinsicValue(int numLessThanValueAtSplitPoint
        } 
 }
 /**************************************************************************************************/
-int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< vector<int> > featureOutputPairs, vector<int> splitPoints,
-                               double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue){
+
+int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< pair<int, int> > featureOutputPairs, vector<int> splitPoints,
+                                                    double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue){
     try {
         
         int numSamples = (int)featureOutputPairs.size();
@@ -155,16 +181,17 @@ int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< vector<int> > featur
         vector<double> intrinsicValues;
         
         for (int i = 0; i < splitPoints.size(); i++) {
-             if (m->control_pressed) { return 0; }
+            if (m->control_pressed) { return 0; }
             int index = splitPoints[i];
-            int valueAtSplitPoint = featureOutputPairs[index][0];
+            int valueAtSplitPoint = featureOutputPairs[index].first;
+
             int numLessThanValueAtSplitPoint = 0;
             int numGreaterThanValueAtSplitPoint = 0;
             
             for (int j = 0; j < featureOutputPairs.size(); j++) {
-                 if (m->control_pressed) { return 0; }
-                vector<int> record = featureOutputPairs[j];
-                if (record[0] < valueAtSplitPoint){ numLessThanValueAtSplitPoint++; }
+                if (m->control_pressed) { return 0; }
+                pair<int, int> record = featureOutputPairs[j];
+                if (record.first < valueAtSplitPoint){ numLessThanValueAtSplitPoint++; }
                 else{ numGreaterThanValueAtSplitPoint++; }
             }
             
@@ -193,19 +220,19 @@ int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< vector<int> > featur
 }
 /**************************************************************************************************/
 
-double AbstractDecisionTree::calcSplitEntropy(vector< vector<int> > featureOutputPairs, int splitIndex, int numOutputClasses, bool isUpperSplit = true) {
+double AbstractDecisionTree::calcSplitEntropy(vector< pair<int, int> > featureOutputPairs, int splitIndex, int numOutputClasses, bool isUpperSplit = true) {
     try {
         vector<int> classCounts(numOutputClasses, 0);
         
         if (isUpperSplit) { 
-            for (int i = 0; i < splitIndex; i++) { 
+            for (int i = 0; i < splitIndex; i++) {
                 if (m->control_pressed) { return 0; }
-                classCounts[featureOutputPairs[i][1]]++; 
+                classCounts[featureOutputPairs[i].second]++;
             }
         } else {
             for (int i = splitIndex; i < featureOutputPairs.size(); i++) { 
                 if (m->control_pressed) { return 0; }
-                classCounts[featureOutputPairs[i][1]]++; 
+                classCounts[featureOutputPairs[i].second]++;
             }
         }
         
@@ -245,8 +272,9 @@ int AbstractDecisionTree::getSplitPopulation(RFTreeNode* node, vector< vector<in
             if (m->control_pressed) { return 0; }
             vector<int> sample =  node->getBootstrappedTrainingSamples()[i];
             if (m->control_pressed) { return 0; }
-            if (sample[splitFeatureGlobalIndex] < node->getSplitFeatureValue()){ leftChildSamples.push_back(sample); }
-            else{ rightChildSamples.push_back(sample); }
+            
+            if (sample[splitFeatureGlobalIndex] < node->getSplitFeatureValue()) { leftChildSamples.push_back(sample); }
+            else { rightChildSamples.push_back(sample); }
         }
         
         return 0;
index 3445db4a511e8705ed5b0bed14626eb869ce5043..cc238b4687e57c6cab840f5ec31ce10adbc8a655 100755 (executable)
@@ -6,8 +6,8 @@
 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
 //
 
-#ifndef rrf_fs_prototype_abstractdecisiontree_hpp
-#define rrf_fs_prototype_abstractdecisiontree_hpp
+#ifndef RF_ABSTRACTDECISIONTREE_HPP
+#define RF_ABSTRACTDECISIONTREE_HPP
 
 #include "mothurout.h"
 #include "macros.h"
 
 /**************************************************************************************************/
 
+struct IntPairVectorSorter{
+    bool operator() (const pair<int, int>& firstPair, const pair<int, int>& secondPair) {
+        return firstPair.first < secondPair.first;
+    }
+};
+
+/**************************************************************************************************/
+
 class AbstractDecisionTree{
   
 public:
   
-    AbstractDecisionTree(vector<vector<int> >baseDataSet, 
-                       vector<int> globalDiscardedFeatureIndices, 
-                       OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, 
-                       string treeSplitCriterion);    
+    AbstractDecisionTree(vector<vector<int> >& baseDataSet,
+                           vector<int> globalDiscardedFeatureIndices, 
+                           OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, 
+                           string treeSplitCriterion);    
     virtual ~AbstractDecisionTree(){}
     
   
@@ -32,23 +40,29 @@ protected:
   
     virtual int createBootStrappedSamples();
     virtual int getMinEntropyOfFeature(vector<int> featureVector, vector<int> outputVector, double& minEntropy, int& featureSplitValue, double& intrinsicValue);
-    virtual int getBestSplitAndMinEntropy(vector< vector<int> > featureOutputPairs, vector<int> splitPoints, double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue);
+        virtual int getBestSplitAndMinEntropy(vector< pair<int, int> > featureOutputPairs, vector<int> splitPoints, double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue);
     virtual double calcIntrinsicValue(int numLessThanValueAtSplitPoint, int numGreaterThanValueAtSplitPoint, int numSamples);
-    virtual double calcSplitEntropy(vector< vector<int> > featureOutputPairs, int splitIndex, int numOutputClasses, bool);
+    virtual double calcSplitEntropy(vector< pair<int, int> > featureOutputPairs, int splitIndex, int numOutputClasses, bool);
+
     virtual int getSplitPopulation(RFTreeNode* node, vector< vector<int> >& leftChildSamples, vector< vector<int> >& rightChildSamples);
     virtual bool checkIfAlreadyClassified(RFTreeNode* treeNode, int& outputClass);
 
-    vector< vector<int> > baseDataSet;
+    vector< vector<int> >& baseDataSet;
     int numSamples;
     int numFeatures;
     int numOutputClasses;
     vector<int> outputClasses;
+    
     vector< vector<int> > bootstrappedTrainingSamples;
     vector<int> bootstrappedTrainingSampleIndices;
     vector< vector<int> > bootstrappedTestSamples;
     vector<int> bootstrappedTestSampleIndices;
     
+    vector<vector<int> > testSampleFeatureVectors;
+    
     RFTreeNode* rootNode;
+    int nodeIdCount;
+    map<int, int> nodeMisclassificationCounts;
     vector<int> globalDiscardedFeatureIndices;
     int optimumFeatureSubsetSize;
     string treeSplitCriterion;
index 2dc963babfc0ebea56cb0505e647667d1e964170..6e32fd1e800023348e4a58eeaad678f262430d25 100755 (executable)
@@ -20,6 +20,14 @@ vector<string> ClassifySharedCommand::setParameters(){
         CommandParameter potupersplit("otupersplit", "Multiple", "log2-squareroot", "log2", "", "", "","",false,false); parameters.push_back(potupersplit);
         CommandParameter psplitcriteria("splitcriteria", "Multiple", "gainratio-infogain", "gainratio", "", "", "","",false,false); parameters.push_back(psplitcriteria);
                CommandParameter pnumtrees("numtrees", "Number", "", "100", "", "", "","",false,false); parameters.push_back(pnumtrees);
+        
+            // parameters related to pruning
+        CommandParameter pdopruning("prune", "Boolean", "", "T", "", "", "", "", false, false); parameters.push_back(pdopruning);
+        CommandParameter ppruneaggrns("pruneaggressiveness", "Number", "", "0.9", "", "", "", "", false, false); parameters.push_back(ppruneaggrns);
+        CommandParameter pdiscardhetrees("discarderrortrees", "Boolean", "", "T", "", "", "", "", false, false); parameters.push_back(pdiscardhetrees);
+        CommandParameter phetdiscardthreshold("errorthreshold", "Number", "", "0.4", "", "", "", "", false, false); parameters.push_back(phetdiscardthreshold);
+        CommandParameter psdthreshold("stdthreshold", "Number", "", "0.0", "", "", "", "", false, false); parameters.push_back(psdthreshold);
+            // pruning params end
 
         CommandParameter pgroups("groups", "String", "", "", "", "", "","",false,false); parameters.push_back(pgroups);
                CommandParameter plabel("label", "String", "", "", "", "", "","",false,false); parameters.push_back(plabel);
@@ -81,6 +89,7 @@ ClassifySharedCommand::ClassifySharedCommand() {
     exit(1);
   }
 }
+
 //**********************************************************************************************************************
 ClassifySharedCommand::ClassifySharedCommand(string option) {
   try {
@@ -104,7 +113,6 @@ ClassifySharedCommand::ClassifySharedCommand(string option) {
       for (it = parameters.begin(); it != parameters.end(); it++) {
         if (validParameter.isValidParameter(it->first, myArray, it->second) != true) {  abort = true;  }
       }
-        
         vector<string> tempOutNames;
         outputTypes["summary"] = tempOutNames;
       
@@ -130,7 +138,6 @@ ClassifySharedCommand::ClassifySharedCommand(string option) {
         }
         
       }
-       
         //check for parameters
         //get shared file, it is required
       sharedfile = validParameter.validFile(parameters, "shared", true);
@@ -158,25 +165,51 @@ ClassifySharedCommand::ClassifySharedCommand(string option) {
         outputDir = m->hasPath(sharedfile); //if user entered a file with a path then preserve it
       }
       
-    
         // NEW CODE for OTU per split selection criteria
-      otupersplit = validParameter.validFile(parameters, "otupersplit", false);
-      if (otupersplit == "not found") { otupersplit = "log2"; }
-      if ((otupersplit == "squareroot") || (otupersplit == "log2")) {
-        optimumFeatureSubsetSelectionCriteria = otupersplit;
-      }else { m->mothurOut("Not a valid OTU per split selection method. Valid OTU per split selection methods are 'log2' and 'squareroot'."); m->mothurOutEndLine(); abort = true; }
-      
-        // splitcriteria
-      splitcriteria = validParameter.validFile(parameters, "splitcriteria", false);
-      if (splitcriteria == "not found") { splitcriteria = "gainratio"; }
-      if ((splitcriteria == "gainratio") || (splitcriteria == "infogain")) {
-        treeSplitCriterion = splitcriteria;
-      }else { m->mothurOut("Not a valid tree splitting criterio. Valid tree splitting criteria are 'gainratio' and 'infogain'."); m->mothurOutEndLine(); abort = true; }
-      
-      
-      string temp = validParameter.validFile(parameters, "numtrees", false); if (temp == "not found"){ temp = "100";   }
-      m->mothurConvert(temp, numDecisionTrees);
-
+        string temp = validParameter.validFile(parameters, "splitcriteria", false);
+        if (temp == "not found") { temp = "gainratio"; }
+        if ((temp == "gainratio") || (temp == "infogain")) {
+            treeSplitCriterion = temp;
+        } else { m->mothurOut("Not a valid tree splitting criterio. Valid tree splitting criteria are 'gainratio' and 'infogain'.");
+            m->mothurOutEndLine();
+            abort = true;
+        }
+        
+        temp = validParameter.validFile(parameters, "numtrees", false); if (temp == "not found"){      temp = "100";   }
+        m->mothurConvert(temp, numDecisionTrees);
+        
+            // parameters for pruning
+        temp = validParameter.validFile(parameters, "prune", false);
+        if (temp == "not found") { temp = "f"; }
+        doPruning = m->isTrue(temp);
+        
+        temp = validParameter.validFile(parameters, "pruneaggressiveness", false);
+        if (temp == "not found") { temp = "0.9"; }
+        m->mothurConvert(temp, pruneAggressiveness);
+        
+        temp = validParameter.validFile(parameters, "discarderrortrees", false);
+        if (temp == "not found") { temp = "f"; }
+        discardHighErrorTrees = m->isTrue(temp);
+        
+        temp = validParameter.validFile(parameters, "errorthreshold", false);
+        if (temp == "not found") { temp = "0.4"; }
+        m->mothurConvert(temp, highErrorTreeDiscardThreshold);
+        
+        temp = validParameter.validFile(parameters, "otupersplit", false);
+        if (temp == "not found") { temp = "log2"; }
+        if ((temp == "squareroot") || (temp == "log2")) {
+            optimumFeatureSubsetSelectionCriteria = temp;
+        } else { m->mothurOut("Not a valid OTU per split selection method. Valid OTU per split selection methods are 'log2' and 'squareroot'.");
+            m->mothurOutEndLine();
+            abort = true;
+        }
+        
+        temp = validParameter.validFile(parameters, "stdthreshold", false);
+        if (temp == "not found") { temp = "0.0"; }
+        m->mothurConvert(temp, featureStandardDeviationThreshold);
+                        
+            // end of pruning params
+        
         //Groups must be checked later to make sure they are valid. SharedUtilities has functions of check the validity, just make to so m->setGroups() after the checks.  If you are using these with a shared file no need to check the SharedRAbundVector class will call SharedUtilites for you, kinda nice, huh?
       string groups = validParameter.validFile(parameters, "groups", false);
       if (groups == "not found") { groups = ""; }
@@ -235,7 +268,6 @@ int ClassifySharedCommand::execute() {
         for (int i = 0; i < lookup.size(); i++) {  delete lookup[i];  }
         lookup = input.getSharedRAbundVectors(lastLabel);
         m->mothurOut(lookup[0]->getLabel()); m->mothurOutEndLine();
-           
         processSharedAndDesignData(lookup);        
         
         processedLabels.insert(lookup[0]->getLabel());
@@ -339,7 +371,8 @@ void ClassifySharedCommand::processSharedAndDesignData(vector<SharedRAbundVector
             dataSet[i][j] = treatmentToIntMap[treatmentName];
         }
         
-        RandomForest randomForest(dataSet, numDecisionTrees, treeSplitCriterion);
+        RandomForest randomForest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold);
+        
         randomForest.populateDecisionTrees();
         randomForest.calcForrestErrorRate();
         
index fe0637494db78aa96ecbac56c2f82b08ba2f9c6d..276b71ccdfbeda299125477eefbc3796a2c9aa7b 100755 (executable)
@@ -20,10 +20,9 @@ public:
   
   vector<string> setParameters();
   string getCommandName()                      { return "classify.shared";     }
-   string getCommandCategory()         { return "OTU-Based Approaches";                }
-  
+  string getCommandCategory()          { return "OTU-Based Approaches";                }  
   string getHelpString();      
-    string getOutputPattern(string);
+  string getOutputPattern(string);
   string getCitation() { return "http://www.mothur.org/wiki/Classify.shared\n"; }
   string getDescription()              { return "description"; }
   int execute();
@@ -35,7 +34,7 @@ private:
     string outputDir;
     vector<string> outputNames, Groups;
   
-    string sharedfile, designfile, otupersplit, splitcriteria;
+    string sharedfile, designfile;
     set<string> labels;
     bool allLines;
   
index 99853f303c3562e41c43909948276dd82e0f9c6d..f5310bc1983b50eda0411c78061d7a58c6447ffd 100644 (file)
@@ -8,13 +8,18 @@
 
 #include "decisiontree.hpp"
 
-DecisionTree::DecisionTree(vector< vector<int> > baseDataSet,
-             vector<int> globalDiscardedFeatureIndices,
-             OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
-             string treeSplitCriterion) : AbstractDecisionTree(baseDataSet,
-                       globalDiscardedFeatureIndices,
-                       optimumFeatureSubsetSelector,
-                       treeSplitCriterion), variableImportanceList(numFeatures, 0){
+DecisionTree::DecisionTree(vector< vector<int> >& baseDataSet,
+                           vector<int> globalDiscardedFeatureIndices,
+                           OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
+                           string treeSplitCriterion,
+                           float featureStandardDeviationThreshold)
+            : AbstractDecisionTree(baseDataSet,
+                                   globalDiscardedFeatureIndices,
+                                   optimumFeatureSubsetSelector,
+                                   treeSplitCriterion),
+            variableImportanceList(numFeatures, 0),
+            featureStandardDeviationThreshold(featureStandardDeviationThreshold) {
+                
     try {
         m = MothurOut::getInstance();
         createBootStrappedSamples();
@@ -28,41 +33,57 @@ DecisionTree::DecisionTree(vector< vector<int> > baseDataSet,
 
 /***********************************************************************/
 
-int DecisionTree::calcTreeVariableImportanceAndError() {
+int DecisionTree::calcTreeVariableImportanceAndError(int& numCorrect, double& treeErrorRate) {
     try {
+        vector< vector<int> > randomlySampledTestData(bootstrappedTestSamples.size(), vector<int>(bootstrappedTestSamples[0].size(), 0));
         
-        int numCorrect;
-        double treeErrorRate;
-        calcTreeErrorRate(numCorrect, treeErrorRate);
+            // TODO: is is possible to further speed up the following O(N^2) by using std::copy?
+        for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
+            for (int j = 0; j < bootstrappedTestSamples[i].size(); j++) {
+                randomlySampledTestData[i][j] = bootstrappedTestSamples[i][j];
+            }
+        }
         
-        if (m->control_pressed) {return 0; }
-                
         for (int i = 0; i < numFeatures; i++) {
-            if (m->control_pressed) {return 0; }
-            // NOTE: only shuffle the features, never shuffle the output vector
-            // so i = 0 and i will be alwaays <= (numFeatures - 1) as the index at numFeatures will denote
-            // the feature vector
-            vector< vector<int> > randomlySampledTestData = randomlyShuffleAttribute(bootstrappedTestSamples, i);
+            if (m->control_pressed) { return 0; }
             
-            int numCorrectAfterShuffle = 0;
-            for (int j = 0; j < randomlySampledTestData.size(); j++) {
-                if (m->control_pressed) {return 0; }
-                vector<int> shuffledSample = randomlySampledTestData[j];
-                int actualSampleOutputClass = shuffledSample[numFeatures];
-                int predictedSampleOutputClass = evaluateSample(shuffledSample);
-                if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrectAfterShuffle++; }
+                // if the index is in globalDiscardedFeatureIndices (i.e, null feature) we don't want to shuffle them
+            vector<int>::iterator it = find(globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end(), i);
+            if (it == globalDiscardedFeatureIndices.end()) {        // NOT FOUND
+                // if the standard deviation is very low, we know it's not a good feature at all
+                // we can save some time here by discarding that feature
+                
+                vector<int> featureVector = testSampleFeatureVectors[i];
+                if (m->getStandardDeviation(featureVector) > featureStandardDeviationThreshold) {
+                    // NOTE: only shuffle the features, never shuffle the output vector
+                    // so i = 0 and i will be alwaays <= (numFeatures - 1) as the index at numFeatures will denote
+                    // the feature vector
+                    randomlyShuffleAttribute(bootstrappedTestSamples, i, i - 1, randomlySampledTestData);
+
+                    int numCorrectAfterShuffle = 0;
+                    for (int j = 0; j < randomlySampledTestData.size(); j++) {
+                        if (m->control_pressed) {return 0; }
+                        
+                        vector<int> shuffledSample = randomlySampledTestData[j];
+                        int actualSampleOutputClass = shuffledSample[numFeatures];
+                        int predictedSampleOutputClass = evaluateSample(shuffledSample);
+                        if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrectAfterShuffle++; }
+                    }
+                    variableImportanceList[i] += (numCorrect - numCorrectAfterShuffle);
+                }
             }
-            variableImportanceList[i] += (numCorrect - numCorrectAfterShuffle);
         }
         
         // TODO: do we need to save the variableRanks in the DecisionTree, do we need it later?
-        vector< vector<int> > variableRanks;
+        vector< pair<int, int> > variableRanks;
+        
         for (int i = 0; i < variableImportanceList.size(); i++) {
             if (m->control_pressed) {return 0; }
             if (variableImportanceList[i] > 0) {
                 // TODO: is there a way to optimize the follow line's code?
-                vector<int> variableRank(2, 0);
-                variableRank[0] = i; variableRank[1] = variableImportanceList[i];
+                pair<int, int> variableRank(0, 0);
+                variableRank.first = i;
+                variableRank.second = variableImportanceList[i];
                 variableRanks.push_back(variableRank);
             }
         }
@@ -84,8 +105,10 @@ int DecisionTree::evaluateSample(vector<int> testSample) {
     try {
         RFTreeNode *node = rootNode;
         while (true) {
-            if (m->control_pressed) {return 0; }
+            if (m->control_pressed) { return 0; }
+            
             if (node->checkIsLeaf()) { return node->getOutputClass(); }
+            
             int sampleSplitFeatureValue = testSample[node->getSplitFeatureIndex()];
             if (sampleSplitFeatureValue < node->getSplitFeatureValue()) { node = node->getLeftChildNode(); }
             else { node = node->getRightChildNode(); } 
@@ -101,8 +124,8 @@ int DecisionTree::evaluateSample(vector<int> testSample) {
 /***********************************************************************/
 
 int DecisionTree::calcTreeErrorRate(int& numCorrect, double& treeErrorRate){
+    numCorrect = 0;
     try {
-        numCorrect = 0;
         for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
              if (m->control_pressed) {return 0; }
             
@@ -128,35 +151,44 @@ int DecisionTree::calcTreeErrorRate(int& numCorrect, double& treeErrorRate){
 }
 
 /***********************************************************************/
-
 // TODO: optimize the algo, instead of transposing two time, we can extarct the feature,
 // shuffle it and then re-insert in the original place, thus iproving runnting time
 //This function randomize abundances for a given OTU/feature.
-vector< vector<int> > DecisionTree::randomlyShuffleAttribute(vector< vector<int> > samples, int featureIndex) {
+
+void DecisionTree::randomlyShuffleAttribute(const vector< vector<int> >& samples,
+                               const int featureIndex,
+                               const int prevFeatureIndex,
+                               vector< vector<int> >& shuffledSample) {
     try {
         // NOTE: we need (numFeatures + 1) featureVecotors, the last extra vector is actually outputVector
-        vector< vector<int> > shuffledSample = samples;
-        vector<int> featureVectors(samples.size(), 0);
         
+        // restore previously shuffled feature
+        if (prevFeatureIndex > -1) {
+            for (int j = 0; j < samples.size(); j++) {
+                if (m->control_pressed) { return; }
+                shuffledSample[j][prevFeatureIndex] = samples[j][prevFeatureIndex];
+            }
+        }
+        
+        // now do the shuffling
+        vector<int> featureVectors(samples.size(), 0);
         for (int j = 0; j < samples.size(); j++) {
-            if (m->control_pressed) { return shuffledSample; }
+            if (m->control_pressed) { return; }
             featureVectors[j] = samples[j][featureIndex];
         }
-        
         random_shuffle(featureVectors.begin(), featureVectors.end());
-
         for (int j = 0; j < samples.size(); j++) {
-            if (m->control_pressed) {return shuffledSample; }
+            if (m->control_pressed) { return; }
             shuffledSample[j][featureIndex] = featureVectors[j];
         }
-        
-        return shuffledSample;
     }
        catch(exception& e) {
-               m->errorOut(e, "DecisionTree", "randomlyShuffleAttribute");
+        m->errorOut(e, "DecisionTree", "randomlyShuffleAttribute");
                exit(1);
-       } 
+       }
+    
 }
+
 /***********************************************************************/
 
 int DecisionTree::purgeTreeNodesDataRecursively(RFTreeNode* treeNode) {
@@ -181,10 +213,12 @@ int DecisionTree::purgeTreeNodesDataRecursively(RFTreeNode* treeNode) {
 void DecisionTree::buildDecisionTree(){
     try {
     
-    int generation = 0;
-    rootNode = new RFTreeNode(bootstrappedTrainingSamples, globalDiscardedFeatureIndices, numFeatures, numSamples, numOutputClasses, generation);
-    
-    splitRecursively(rootNode);
+        int generation = 0;
+        rootNode = new RFTreeNode(bootstrappedTrainingSamples, globalDiscardedFeatureIndices, numFeatures, numSamples, numOutputClasses, generation, nodeIdCount, featureStandardDeviationThreshold);
+        nodeIdCount++;
+        
+        splitRecursively(rootNode);
+        
         }
        catch(exception& e) {
                m->errorOut(e, "DecisionTree", "buildDecisionTree");
@@ -210,24 +244,32 @@ int DecisionTree::splitRecursively(RFTreeNode* rootNode) {
             rootNode->setOutputClass(classifiedOutputClass);
             return 0;
         }
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         vector<int> featureSubsetIndices = selectFeatureSubsetRandomly(globalDiscardedFeatureIndices, rootNode->getLocalDiscardedFeatureIndices());
+        
+            // TODO: need to check if the value is actually copied correctly
         rootNode->setFeatureSubsetIndices(featureSubsetIndices);
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
       
         findAndUpdateBestFeatureToSplitOn(rootNode);
         
-        if (m->control_pressed) {return 0;}
+        // update rootNode outputClass, this is needed for pruning
+        // this is only for internal nodes
+        updateOutputClassOfNode(rootNode);
+        
+        if (m->control_pressed) { return 0; }
         
         vector< vector<int> > leftChildSamples;
         vector< vector<int> > rightChildSamples;
         getSplitPopulation(rootNode, leftChildSamples, rightChildSamples);
         
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         
         // TODO: need to write code to clear this memory
-        RFTreeNode* leftChildNode = new RFTreeNode(leftChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)leftChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1);
-        RFTreeNode* rightChildNode = new RFTreeNode(rightChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)rightChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1);
+        RFTreeNode* leftChildNode = new RFTreeNode(leftChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)leftChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1, nodeIdCount, featureStandardDeviationThreshold);
+        nodeIdCount++;
+        RFTreeNode* rightChildNode = new RFTreeNode(rightChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)rightChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1, nodeIdCount, featureStandardDeviationThreshold);
+        nodeIdCount++;
         
         rootNode->setLeftChildNode(leftChildNode);
         leftChildNode->setParentNode(rootNode);
@@ -237,7 +279,7 @@ int DecisionTree::splitRecursively(RFTreeNode* rootNode) {
         
         // TODO: This recursive split can be parrallelized later
         splitRecursively(leftChildNode);
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         
         splitRecursively(rightChildNode);
         return 0;
@@ -254,11 +296,11 @@ int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
     try {
 
         vector< vector<int> > bootstrappedFeatureVectors = node->getBootstrappedFeatureVectors();
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         vector<int> bootstrappedOutputVector = node->getBootstrappedOutputVector();
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         vector<int> featureSubsetIndices = node->getFeatureSubsetIndices();
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         
         vector<double> featureSubsetEntropies;
         vector<int> featureSubsetSplitValues;
@@ -266,7 +308,7 @@ int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
         vector<double> featureSubsetGainRatios;
         
         for (int i = 0; i < featureSubsetIndices.size(); i++) {
-            if (m->control_pressed) {return 0;}
+            if (m->control_pressed) { return 0; }
             
             int tryIndex = featureSubsetIndices[i];
                        
@@ -275,7 +317,7 @@ int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
             double featureIntrinsicValue;
             
             getMinEntropyOfFeature(bootstrappedFeatureVectors[tryIndex], bootstrappedOutputVector, featureMinEntropy, featureSplitValue, featureIntrinsicValue);
-            if (m->control_pressed) {return 0;}
+            if (m->control_pressed) { return 0; }
             
             featureSubsetEntropies.push_back(featureMinEntropy);
             featureSubsetSplitValues.push_back(featureSplitValue);
@@ -290,23 +332,33 @@ int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
         vector<double>::iterator minEntropyIterator = min_element(featureSubsetEntropies.begin(), featureSubsetEntropies.end());
         vector<double>::iterator maxGainRatioIterator = max_element(featureSubsetGainRatios.begin(), featureSubsetGainRatios.end());
         double featureMinEntropy = *minEntropyIterator;
-        //double featureMaxGainRatio = *maxGainRatioIterator;
+        
+        // TODO: kept the following line as future reference, can be useful
+        // double featureMaxGainRatio = *maxGainRatioIterator;
         
         double bestFeatureSplitEntropy = featureMinEntropy;
         int bestFeatureToSplitOnIndex = -1;
-        if (treeSplitCriterion == "gainRatio"){ 
+        if (treeSplitCriterion == "gainratio"){
             bestFeatureToSplitOnIndex = (int)(maxGainRatioIterator - featureSubsetGainRatios.begin());
             // if using 'gainRatio' measure, then featureMinEntropy must be re-updated, as the index
             // for 'featureMaxGainRatio' would be different
             bestFeatureSplitEntropy = featureSubsetEntropies[bestFeatureToSplitOnIndex];
+        } else  if ( treeSplitCriterion == "infogain"){
+            bestFeatureToSplitOnIndex = (int)(minEntropyIterator - featureSubsetEntropies.begin());
+        } else {
+                // TODO: we need an abort mechanism here
         }
-        else { bestFeatureToSplitOnIndex = (int)(minEntropyIterator - featureSubsetEntropies.begin()); }
+        
+            // TODO: is the following line needed? kept is as future reference
+        // splitInformationGain = node.ownEntropy - node.splitFeatureEntropy
         
         int bestFeatureSplitValue = featureSubsetSplitValues[bestFeatureToSplitOnIndex];
         
         node->setSplitFeatureIndex(featureSubsetIndices[bestFeatureToSplitOnIndex]);
         node->setSplitFeatureValue(bestFeatureSplitValue);
         node->setSplitFeatureEntropy(bestFeatureSplitEntropy);
+            // TODO: kept the following line as future reference
+        // node.splitInformationGain = splitInformationGain
         
         return 0;
     }
@@ -363,17 +415,17 @@ vector<int> DecisionTree::selectFeatureSubsetRandomly(vector<int> globalDiscarde
 int DecisionTree::printTree(RFTreeNode* treeNode, string caption){
     try { 
         string tabs = "";
-        for (int i = 0; i < treeNode->getGeneration(); i++) { tabs += "   "; }
+        for (int i = 0; i < treeNode->getGeneration(); i++) { tabs += "|--"; }
         //    for (int i = 0; i < treeNode->getGeneration() - 1; i++) { tabs += "|  "; }
         //    if (treeNode->getGeneration() != 0) { tabs += "|--"; }
         
         if (treeNode != NULL && treeNode->checkIsLeaf() == false){
-            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " ] ( " + toString(treeNode->getSplitFeatureValue()) + " < X" + toString(treeNode->getSplitFeatureIndex()) +" )\n");
+            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " , id: " + toString(treeNode->nodeId) + " ] ( " + toString(treeNode->getSplitFeatureValue()) + " < X" + toString(treeNode->getSplitFeatureIndex()) + " ) ( predicted: " + toString(treeNode->outputClass) + " , misclassified: " + toString(treeNode->testSampleMisclassificationCount) + " )\n");
             
-            printTree(treeNode->getLeftChildNode(), "leftChild");
-            printTree(treeNode->getRightChildNode(), "rightChild");
+            printTree(treeNode->getLeftChildNode(), "left ");
+            printTree(treeNode->getRightChildNode(), "right");
         }else {
-            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " ] ( classified to: " + toString(treeNode->getOutputClass()) + ", samples: " + toString(treeNode->getNumSamples()) + " )\n");
+            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " , id: " + toString(treeNode->nodeId) + " ] ( classified to: " + toString(treeNode->getOutputClass()) + ", samples: " + toString(treeNode->getNumSamples()) + " , misclassified: " + toString(treeNode->testSampleMisclassificationCount) + " )\n");
         }
         return 0;
     }
@@ -397,3 +449,87 @@ void DecisionTree::deleteTreeNodesRecursively(RFTreeNode* treeNode) {
 }
 /***********************************************************************/
 
+void DecisionTree::pruneTree(double pruneAggressiveness = 0.9) {
+    
+    // find out the number of misclassification by each of the nodes
+    for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
+        if (m->control_pressed) { return; }
+        
+        vector<int> testSample = bootstrappedTestSamples[i];
+        updateMisclassificationCountRecursively(rootNode, testSample);
+    }
+    
+    // do the actual pruning
+    pruneRecursively(rootNode, pruneAggressiveness);
+}
+/***********************************************************************/
+
+void DecisionTree::pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness){
+    
+    if (treeNode != NULL && treeNode->checkIsLeaf() == false) {
+        if (m->control_pressed) { return; }
+        
+        pruneRecursively(treeNode->leftChildNode, pruneAggressiveness);
+        pruneRecursively(treeNode->rightChildNode, pruneAggressiveness);
+        
+        int subTreeMisclassificationCount = treeNode->leftChildNode->getTestSampleMisclassificationCount() + treeNode->rightChildNode->getTestSampleMisclassificationCount();
+        int ownMisclassificationCount = treeNode->getTestSampleMisclassificationCount();
+        
+        if (subTreeMisclassificationCount * pruneAggressiveness > ownMisclassificationCount) {
+                // TODO: need to check the effect of these two delete calls
+            delete treeNode->leftChildNode;
+            treeNode->leftChildNode = NULL;
+            
+            delete treeNode->rightChildNode;
+            treeNode->rightChildNode = NULL;
+            
+            treeNode->isLeaf = true;
+        }
+        
+    }
+}
+/***********************************************************************/
+
+void DecisionTree::updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector<int> testSample) {
+    
+    int actualSampleOutputClass = testSample[numFeatures];
+    int nodePredictedOutputClass = treeNode->outputClass;
+    
+    if (actualSampleOutputClass != nodePredictedOutputClass) {
+        treeNode->testSampleMisclassificationCount++;
+        map<int, int>::iterator it = nodeMisclassificationCounts.find(treeNode->nodeId);
+        if (it == nodeMisclassificationCounts.end()) {  // NOT FOUND
+            nodeMisclassificationCounts[treeNode->nodeId] = 0;
+        }
+        nodeMisclassificationCounts[treeNode->nodeId]++;
+    }
+    
+    if (treeNode->checkIsLeaf() == false) { // NOT A LEAF
+        int sampleSplitFeatureValue = testSample[treeNode->splitFeatureIndex];
+        if (sampleSplitFeatureValue < treeNode->splitFeatureValue) {
+            updateMisclassificationCountRecursively(treeNode->leftChildNode, testSample);
+        } else {
+            updateMisclassificationCountRecursively(treeNode->rightChildNode, testSample);
+        }
+    }
+}
+
+/***********************************************************************/
+
+void DecisionTree::updateOutputClassOfNode(RFTreeNode* treeNode) {
+    vector<int> counts(numOutputClasses, 0);
+    for (int i = 0; i < treeNode->bootstrappedOutputVector.size(); i++) {
+        int bootstrappedOutput = treeNode->bootstrappedOutputVector[i];
+        counts[bootstrappedOutput]++;
+    }
+
+    vector<int>::iterator majorityVotedOutputClassCountIterator = max_element(counts.begin(), counts.end());
+    int majorityVotedOutputClassCount = *majorityVotedOutputClassCountIterator;
+    vector<int>::iterator it = find(counts.begin(), counts.end(), majorityVotedOutputClassCount);
+    int majorityVotedOutputClass = (int)(it - counts.begin());
+    treeNode->setOutputClass(majorityVotedOutputClass);
+
+}
+/***********************************************************************/
+
+
index d4441ed738a3049c1ec6fce73b67a91c7b1cdfb3..e890c3214c496f5a208b68744cb18d665cd1aabd 100755 (executable)
@@ -6,8 +6,8 @@
   //  Copyright (c) 2012 Schloss Lab. All rights reserved.
   //
 
-#ifndef rrf_fs_prototype_decisiontree_hpp
-#define rrf_fs_prototype_decisiontree_hpp
+#ifndef RF_DECISIONTREE_HPP
+#define RF_DECISIONTREE_HPP
 
 #include "macros.h"
 #include "rftreenode.hpp"
 /***********************************************************************/
 
 struct VariableRankDescendingSorter {
-  bool operator() (vector<int> first, vector<int> second){ return first[1] > second[1]; }
+  bool operator() (const pair<int, int>& firstPair, const pair<int, int>& secondPair){
+      return firstPair.second > secondPair.second;
+  }
 };
 struct VariableRankDescendingSorterDouble {
-    bool operator() (vector<double> first, vector<double> second){ return first[1] > second[1]; }
+    bool operator() (const pair<int, double>& firstPair, const pair<int, double>& secondPair){
+        return firstPair.second > secondPair.second;
+    }
 };
 /***********************************************************************/
 
@@ -29,19 +33,31 @@ class DecisionTree: public AbstractDecisionTree{
     
 public:
     
-    DecisionTree(vector< vector<int> > baseDataSet,
+    DecisionTree(vector< vector<int> >& baseDataSet,
                  vector<int> globalDiscardedFeatureIndices,
                  OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
-                 string treeSplitCriterion);
+                 string treeSplitCriterion,
+                 float featureStandardDeviationThreshold);
+    
     virtual ~DecisionTree(){ deleteTreeNodesRecursively(rootNode); }
     
-    int calcTreeVariableImportanceAndError();
+    int calcTreeVariableImportanceAndError(int& numCorrect, double& treeErrorRate);
     int evaluateSample(vector<int> testSample);
     int calcTreeErrorRate(int& numCorrect, double& treeErrorRate);
-    vector< vector<int> > randomlyShuffleAttribute(vector< vector<int> > samples, int featureIndex);  
+    
+    void randomlyShuffleAttribute(const vector< vector<int> >& samples,
+                                  const int featureIndex,
+                                  const int prevFeatureIndex,
+                                  vector< vector<int> >& shuffledSample);
+    
     void purgeDataSetsFromTree() { purgeTreeNodesDataRecursively(rootNode); }
     int purgeTreeNodesDataRecursively(RFTreeNode* treeNode);
     
+    void pruneTree(double pruneAggressiveness);
+    void pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness);
+    void updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector<int> testSample);
+    void updateOutputClassOfNode(RFTreeNode* treeNode);
+    
     
 private:
     
@@ -54,6 +70,8 @@ private:
     
     vector<int> variableImportanceList;
     map<int, int> outOfBagEstimates;
+  
+    float featureStandardDeviationThreshold;
 };
 
 #endif
index 179ecef19d26de2a61f68776445dfdfc72132f07..3cfb5b9054cf5d43fbbb35eb7b81a6a2ba0f5833 100644 (file)
 
 /***********************************************************************/
 Forest::Forest(const std::vector < std::vector<int> > dataSet,
-                                           const int numDecisionTrees,
-                                           const string treeSplitCriterion = "informationGain")
-: dataSet(dataSet),
-numDecisionTrees(numDecisionTrees),
-numSamples((int)dataSet.size()),
-numFeatures((int)(dataSet[0].size() - 1)),
-globalVariableImportanceList(numFeatures, 0),
-treeSplitCriterion(treeSplitCriterion) {
+               const int numDecisionTrees,
+               const string treeSplitCriterion = "gainratio",
+               const bool doPruning = false,
+               const float pruneAggressiveness = 0.9,
+               const bool discardHighErrorTrees = true,
+               const float highErrorTreeDiscardThreshold = 0.4,
+               const string optimumFeatureSubsetSelectionCriteria = "log2",
+               const float featureStandardDeviationThreshold = 0.0)
+      : dataSet(dataSet),
+        numDecisionTrees(numDecisionTrees),
+        numSamples((int)dataSet.size()),
+        numFeatures((int)(dataSet[0].size() - 1)),
+        globalVariableImportanceList(numFeatures, 0),
+        treeSplitCriterion(treeSplitCriterion),
+        doPruning(doPruning),
+        pruneAggressiveness(pruneAggressiveness),
+        discardHighErrorTrees(discardHighErrorTrees),
+        highErrorTreeDiscardThreshold(highErrorTreeDiscardThreshold),
+        optimumFeatureSubsetSelectionCriteria(optimumFeatureSubsetSelectionCriteria),
+        featureStandardDeviationThreshold(featureStandardDeviationThreshold)
+        {
+        
     m = MothurOut::getInstance();
     globalDiscardedFeatureIndices = getGlobalDiscardedFeatureIndices();
     // TODO: double check if the implemenatation of 'globalOutOfBagEstimates' is correct
@@ -40,7 +54,7 @@ vector<int> Forest::getGlobalDiscardedFeatureIndices() {
         for (int i = 0; i < featureVectors.size(); i++) {
             if (m->control_pressed) { return globalDiscardedFeatureIndices; }
             double standardDeviation = m->getStandardDeviation(featureVectors[i]);
-            if (standardDeviation <= 0){ globalDiscardedFeatureIndices.push_back(i); }
+            if (standardDeviation <= featureStandardDeviationThreshold){ globalDiscardedFeatureIndices.push_back(i); }
         }
         
         if (m->debug) {
index 78f61b3646c446bd9279eb09986d7e5d4bdc10df..9e9860de61f2ec10e9bf02192856a3224ca46193 100644 (file)
--- a/forest.h
+++ b/forest.h
@@ -21,8 +21,14 @@ class Forest{
 public:
     // intialization with vectors
     Forest(const std::vector < std::vector<int> > dataSet,
-                         const int numDecisionTrees,
-                         const string);
+           const int numDecisionTrees,
+           const string treeSplitCriterion,
+           const bool doPruning,
+           const float pruneAggressiveness,
+           const bool discardHighErrorTrees,
+           const float highErrorTreeDiscardThreshold,
+           const string optimumFeatureSubsetSelectionCriteria,
+           const float featureStandardDeviationThreshold);
     virtual ~Forest(){ }
     virtual int populateDecisionTrees() = 0;
     virtual int calcForrestErrorRate() = 0;
@@ -53,6 +59,14 @@ protected:
     vector<int> globalDiscardedFeatureIndices;
     vector<double> globalVariableImportanceList;
     string treeSplitCriterion;
+  
+    bool doPruning;
+    float pruneAggressiveness;
+    bool discardHighErrorTrees;
+    float highErrorTreeDiscardThreshold;
+    string optimumFeatureSubsetSelectionCriteria;
+    float featureStandardDeviationThreshold;
+  
     // This is a map of each feature to outcome count of each classes
     // e.g. 1 => [2 7] means feature 1 has 2 outcome of 0 and 7 outcome of 1
     map<int, vector<int> > globalOutOfBagEstimates;
index f95acbed9864506e08041765107339125e7fb273..a9fc627d6ae774e7e151effbc5d700ae72bbe6f6 100755 (executable)
--- a/macros.h
+++ b/macros.h
@@ -6,8 +6,8 @@
 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
 //
 
-#ifndef rrf_fs_prototype_macros_h
-#define rrf_fs_prototype_macros_h
+#ifndef RF_MACROS_H
+#define RF_MACROS_H
 
 #include "mothurout.h" 
 
index bd96cd2f7177633e3d21181c95e7ff2c07682eb2..2ae0eb595c1c8d5f989c44a5e561a0ff3dac2d2b 100644 (file)
 
 /***********************************************************************/
 
-RandomForest::RandomForest(const vector <vector<int> > dataSet,const int numDecisionTrees,
-             const string treeSplitCriterion = "informationGain") : Forest(dataSet, numDecisionTrees, treeSplitCriterion) {
+RandomForest::RandomForest(const vector <vector<int> > dataSet,
+                           const int numDecisionTrees,
+                           const string treeSplitCriterion = "gainratio",
+                           const bool doPruning = false,
+                           const float pruneAggressiveness = 0.9,
+                           const bool discardHighErrorTrees = true,
+                           const float highErrorTreeDiscardThreshold = 0.4,
+                           const string optimumFeatureSubsetSelectionCriteria = "log2",
+                           const float featureStandardDeviationThreshold = 0.0)
+            : Forest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold) {
     m = MothurOut::getInstance();
 }
 
@@ -48,45 +56,50 @@ int RandomForest::calcForrestErrorRate() {
 }
 
 /***********************************************************************/
-// DONE
 int RandomForest::calcForrestVariableImportance(string filename) {
     try {
     
-    // TODO: need to add try/catch operators to fix this
-    // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
+        // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
         //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all?
         //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree.
-    for (int i = 0; i < decisionTrees.size(); i++) {
-        if (m->control_pressed) { return 0; }
-        
-        DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
+        for (int i = 0; i < decisionTrees.size(); i++) {
+            if (m->control_pressed) { return 0; }
+            
+            DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
+            
+            for (int j = 0; j < numFeatures; j++) {
+                globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
+            }
+        }
         
-        for (int j = 0; j < numFeatures; j++) {
-            globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
+        for (int i = 0;  i < numFeatures; i++) {
+            globalVariableImportanceList[i] /= (double)numDecisionTrees;
         }
-    }
-    
-    for (int i = 0;  i < numFeatures; i++) {
-        cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
-        globalVariableImportanceList[i] /= (double)numDecisionTrees;
-    }
-    
-    vector< vector<double> > globalVariableRanks;
-    for (int i = 0; i < globalVariableImportanceList.size(); i++) {
-        if (globalVariableImportanceList[i] > 0) {
-            vector<double> globalVariableRank(2, 0);
-            globalVariableRank[0] = i; globalVariableRank[1] = globalVariableImportanceList[i];
-            globalVariableRanks.push_back(globalVariableRank);
+        
+        vector< pair<int, double> > globalVariableRanks;
+        for (int i = 0; i < globalVariableImportanceList.size(); i++) {
+            //cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
+            if (globalVariableImportanceList[i] > 0) {
+                pair<int, double> globalVariableRank(0, 0.0);
+                globalVariableRank.first = i;
+                globalVariableRank.second = globalVariableImportanceList[i];
+                globalVariableRanks.push_back(globalVariableRank);
+            }
         }
-    }
-    
-    VariableRankDescendingSorterDouble variableRankDescendingSorter;
-    sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
+        
+//        for (int i = 0; i < globalVariableRanks.size(); i++) {
+//            cout << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
+//        }
+
+        
+        VariableRankDescendingSorterDouble variableRankDescendingSorter;
+        sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
+        
         ofstream out;
         m->openOutputFile(filename, out);
         out <<"OTU\tRank\n";
         for (int i = 0; i < globalVariableRanks.size(); i++) {
-            out << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
+            out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
         }
         out.close();
         return 0;
@@ -97,27 +110,89 @@ int RandomForest::calcForrestVariableImportance(string filename) {
        }  
 }
 /***********************************************************************/
-// DONE
 int RandomForest::populateDecisionTrees() {
     try {
         
+        vector<double> errorRateImprovements;
+        
         for (int i = 0; i < numDecisionTrees; i++) {
+          
             if (m->control_pressed) { return 0; }
-            if (((i+1) % 10) == 0) {  m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n");  }
+            if (((i+1) % 100) == 0) {  m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n");  }
+          
             // TODO: need to first fix if we are going to use pointer based system or anything else
-            DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector("log2"), treeSplitCriterion);
-            decisionTree->calcTreeVariableImportanceAndError();
-            if (m->control_pressed) { return 0; }
-            updateGlobalOutOfBagEstimates(decisionTree);
-            if (m->control_pressed) { return 0; }
-            decisionTree->purgeDataSetsFromTree();
-            if (m->control_pressed) { return 0; }
-            decisionTrees.push_back(decisionTree);
+            DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold);
+          
+            if (m->debug && doPruning) {
+                m->mothurOut("Before pruning\n");
+                decisionTree->printTree(decisionTree->rootNode, "ROOT");
+            }
+            
+            int numCorrect;
+            double treeErrorRate;
+            
+            decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
+            double prePrunedErrorRate = treeErrorRate;
+            
+            if (m->debug) {
+                m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
+            }
+            
+            if (doPruning) {
+                decisionTree->pruneTree(pruneAggressiveness);
+                if (m->debug) {
+                    m->mothurOut("After pruning\n");
+                    decisionTree->printTree(decisionTree->rootNode, "ROOT");
+                }
+                decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
+            }
+            double postPrunedErrorRate = treeErrorRate;
+            
+          
+            decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate);
+            double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate;
+
+            if (m->debug) {
+                m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
+                if (doPruning) {
+                    m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n");
+                }
+            }
+            
+            
+            if (discardHighErrorTrees) {
+                if (treeErrorRate < highErrorTreeDiscardThreshold) {
+                    updateGlobalOutOfBagEstimates(decisionTree);
+                    decisionTree->purgeDataSetsFromTree();
+                    decisionTrees.push_back(decisionTree);
+                    if (doPruning) {
+                        errorRateImprovements.push_back(errorRateImprovement);
+                    }
+                } else {
+                    delete decisionTree;
+                }
+            } else {
+                updateGlobalOutOfBagEstimates(decisionTree);
+                decisionTree->purgeDataSetsFromTree();
+                decisionTrees.push_back(decisionTree);
+                if (doPruning) {
+                    errorRateImprovements.push_back(errorRateImprovement);
+                }
+            }          
+        }
+        
+        double avgErrorRateImprovement = -1.0;
+        if (errorRateImprovements.size() > 0) {
+            avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0);
+//            cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl;
+            avgErrorRateImprovement /= errorRateImprovements.size();
         }
         
-        if (m->debug) {
-            // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
+        if (m->debug && doPruning) {
+            m->mothurOut("avgErrorRateImprovement:" + toString(avgErrorRateImprovement) + "\n");
         }
+        // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
+
         
         return 0;
     }
index 30eb43842f8cb280e1e4e95919909885f2702d28..d0ac1ec063d46d4ff1367e87cf322870f7e49a9c 100644 (file)
@@ -6,8 +6,8 @@
 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
 //
 
-#ifndef rrf_fs_prototype_randomforest_hpp
-#define rrf_fs_prototype_randomforest_hpp
+#ifndef RF_RANDOMFOREST_HPP
+#define RF_RANDOMFOREST_HPP
 
 #include "macros.h"
 #include "forest.h"
@@ -17,20 +17,27 @@ class RandomForest: public Forest {
     
 public:
     
-    // DONE
-    RandomForest(const vector <vector<int> > dataSet,const int numDecisionTrees, const string);
+    RandomForest(const vector <vector<int> > dataSet,
+                 const int numDecisionTrees,
+                 const string treeSplitCriterion,
+                 const bool doPruning,
+                 const float pruneAggressiveness,
+                 const bool discardHighErrorTrees,
+                 const float highErrorTreeDiscardThreshold,
+                 const string optimumFeatureSubsetSelectionCriteria,
+                 const float featureStandardDeviationThreshold);
     
     
     //NOTE:: if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all?
     //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree.
-//    virtual ~RandomForest() {
-//        for (vector<AbstractDecisionTree*>::iterator it = decisionTrees.begin(); it != decisionTrees.end(); it++) {
-//            // we know that this is decision tree, so we can do a dynamic_case<DecisionTree*> here
-//            DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(*it);
-//            // calling the destructor by deleting
-//            delete decisionTree;
-//        }
-//    }
+    virtual ~RandomForest() {
+        for (vector<AbstractDecisionTree*>::iterator it = decisionTrees.begin(); it != decisionTrees.end(); it++) {
+            // we know that this is decision tree, so we can do a dynamic_case<DecisionTree*> here
+            DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(*it);
+            // calling the destructor by deleting
+            delete decisionTree;
+        }
+    }
     
     int calcForrestErrorRate();
     int calcForrestVariableImportance(string);
index 9fa19aba00ceb68353f32936c21962e14791a173..0a33f4e3ba6ed2a79a301c2db5c8254ae153784c 100644 (file)
@@ -8,8 +8,14 @@
 
 #include "regularizedrandomforest.h"
 
-RegularizedRandomForest::RegularizedRandomForest(const vector <vector<int> > dataSet,const int numDecisionTrees,
-                           const string treeSplitCriterion = "informationGain") : Forest(dataSet, numDecisionTrees, treeSplitCriterion) {
+RegularizedRandomForest::RegularizedRandomForest(const vector <vector<int> > dataSet,
+                                                 const int numDecisionTrees,
+                                                 const string treeSplitCriterion = "gainratio")
+                        // TODO: update ctor according to basic RandomForest Class
+                      : Forest(dataSet,
+                               numDecisionTrees,
+                               treeSplitCriterion,
+                               false, 0.9, true, 0.4, "log2", 0.0) {
     m = MothurOut::getInstance();
 }
 
index 170cfb16f10d5c87b98efb55969ec733eae7d826..acfae544aa7660215d71643fc741d9a65cb2dfd1 100644 (file)
 
 /***********************************************************************/
 RFTreeNode::RFTreeNode(vector< vector<int> > bootstrappedTrainingSamples,
-           vector<int> globalDiscardedFeatureIndices,
-           int numFeatures,
-           int numSamples,
-           int numOutputClasses,
-           int generation)
+                       vector<int> globalDiscardedFeatureIndices,
+                       int numFeatures,
+                       int numSamples,
+                       int numOutputClasses,
+                       int generation,
+                       int nodeId,
+                       float featureStandardDeviationThreshold)
 
-: bootstrappedTrainingSamples(bootstrappedTrainingSamples),
-globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
-numFeatures(numFeatures),
-numSamples(numSamples),
-numOutputClasses(numOutputClasses),
-generation(generation),
-isLeaf(false),
-outputClass(-1),
-splitFeatureIndex(-1),
-splitFeatureValue(-1),
-splitFeatureEntropy(-1.0),
-ownEntropy(-1.0),
-bootstrappedFeatureVectors(numFeatures, vector<int>(numSamples, 0)),
-bootstrappedOutputVector(numSamples, 0),
-leftChildNode(NULL),
-rightChildNode(NULL),
-parentNode(NULL) {
+            : bootstrappedTrainingSamples(bootstrappedTrainingSamples),
+            globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
+            numFeatures(numFeatures),
+            numSamples(numSamples),
+            numOutputClasses(numOutputClasses),
+            generation(generation),
+            isLeaf(false),
+            outputClass(-1),
+            nodeId(nodeId),
+            testSampleMisclassificationCount(0),
+            splitFeatureIndex(-1),
+            splitFeatureValue(-1),
+            splitFeatureEntropy(-1.0),
+            ownEntropy(-1.0),
+            featureStandardDeviationThreshold(featureStandardDeviationThreshold),
+            bootstrappedFeatureVectors(numFeatures, vector<int>(numSamples, 0)),
+            bootstrappedOutputVector(numSamples, 0),
+            leftChildNode(NULL),
+            rightChildNode(NULL),
+            parentNode(NULL) {
+                
     m = MothurOut::getInstance();
     
     for (int i = 0; i < numSamples; i++) {    // just doing a simple transpose of the matrix
@@ -40,7 +46,8 @@ parentNode(NULL) {
         for (int j = 0; j < numFeatures; j++) { bootstrappedFeatureVectors[j][i] = bootstrappedTrainingSamples[i][j]; }
     }
     
-    for (int i = 0; i < numSamples; i++) { if (m->control_pressed) { break; } bootstrappedOutputVector[i] = bootstrappedTrainingSamples[i][numFeatures]; }
+    for (int i = 0; i < numSamples; i++) { if (m->control_pressed) { break; }
+        bootstrappedOutputVector[i] = bootstrappedTrainingSamples[i][numFeatures]; }
     
     createLocalDiscardedFeatureList();
     updateNodeEntropy();
@@ -48,13 +55,14 @@ parentNode(NULL) {
 /***********************************************************************/
 int RFTreeNode::createLocalDiscardedFeatureList(){
     try {
-
+        
         for (int i = 0; i < numFeatures; i++) {
+                // TODO: need to check if bootstrappedFeatureVectors == numFeatures, in python code we are using bootstrappedFeatureVectors instead of numFeatures
             if (m->control_pressed) { return 0; } 
             vector<int>::iterator it = find(globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end(), i);
-            if (it == globalDiscardedFeatureIndices.end()){                           // NOT FOUND
+            if (it == globalDiscardedFeatureIndices.end()) {                           // NOT FOUND
                 double standardDeviation = m->getStandardDeviation(bootstrappedFeatureVectors[i]);  
-                if (standardDeviation <= 0){ localDiscardedFeatureIndices.push_back(i); }
+                if (standardDeviation <= featureStandardDeviationThreshold) { localDiscardedFeatureIndices.push_back(i); }
             }
         }
         
@@ -70,7 +78,9 @@ int RFTreeNode::updateNodeEntropy() {
     try {
         
         vector<int> classCounts(numOutputClasses, 0);
-        for (int i = 0; i < bootstrappedOutputVector.size(); i++) { classCounts[bootstrappedOutputVector[i]]++; }
+        for (int i = 0; i < bootstrappedOutputVector.size(); i++) {
+            classCounts[bootstrappedOutputVector[i]]++;
+        }
         int totalClassCounts = accumulate(classCounts.begin(), classCounts.end(), 0);
         double nodeEntropy = 0.0;
         for (int i = 0; i < classCounts.size(); i++) {
index 8987ebcec58b04d8eaaff87e86f514887a4bfd17..d53cc2b614db5b1e9e8d8626127cd58d14bb0c0e 100755 (executable)
@@ -6,8 +6,8 @@
 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
 //
 
-#ifndef rrf_fs_prototype_treenode_hpp
-#define rrf_fs_prototype_treenode_hpp
+#ifndef RF_RFTREENODE_HPP
+#define RF_RFTREENODE_HPP
 
 #include "mothurout.h"
 #include "macros.h"
@@ -16,7 +16,14 @@ class RFTreeNode{
     
 public:
     
-    RFTreeNode(vector< vector<int> > bootstrappedTrainingSamples, vector<int> globalDiscardedFeatureIndices, int numFeatures, int numSamples, int numOutputClasses, int generation);
+    RFTreeNode(vector< vector<int> > bootstrappedTrainingSamples,
+               vector<int> globalDiscardedFeatureIndices,
+               int numFeatures,
+               int numSamples,
+               int numOutputClasses,
+               int generation,
+               int nodeId,
+               float featureStandardDeviationThreshold = 0.0);
     
     virtual ~RFTreeNode(){}
     
@@ -41,6 +48,7 @@ public:
     const vector<int>& getBootstrappedOutputVector() { return bootstrappedOutputVector; }
     const vector<int>& getFeatureSubsetIndices() { return featureSubsetIndices; }
     const double getOwnEntropy() { return ownEntropy; }
+    const int getTestSampleMisclassificationCount() { return testSampleMisclassificationCount; }
     
     // setters
     void setIsLeaf(bool isLeaf) { this->isLeaf = isLeaf; }
@@ -77,6 +85,10 @@ private:
     double splitFeatureEntropy;
     double ownEntropy;
     
+    int nodeId;
+    float featureStandardDeviationThreshold;
+    int testSampleMisclassificationCount;
+    
     RFTreeNode* leftChildNode;
     RFTreeNode* rightChildNode;
     RFTreeNode* parentNode;