]> git.donarmstrong.com Git - mothur.git/blobdiff - decisiontree.cpp
working on pam
[mothur.git] / decisiontree.cpp
index 99853f303c3562e41c43909948276dd82e0f9c6d..1fd29240c156049f61abd9e9bc1f66132275d16c 100644 (file)
@@ -8,13 +8,18 @@
 
 #include "decisiontree.hpp"
 
-DecisionTree::DecisionTree(vector< vector<int> > baseDataSet,
-             vector<int> globalDiscardedFeatureIndices,
-             OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
-             string treeSplitCriterion) : AbstractDecisionTree(baseDataSet,
-                       globalDiscardedFeatureIndices,
-                       optimumFeatureSubsetSelector,
-                       treeSplitCriterion), variableImportanceList(numFeatures, 0){
+DecisionTree::DecisionTree(vector< vector<int> >& baseDataSet,
+                           vector<int> globalDiscardedFeatureIndices,
+                           OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
+                           string treeSplitCriterion,
+                           float featureStandardDeviationThreshold)
+            : AbstractDecisionTree(baseDataSet,
+                                   globalDiscardedFeatureIndices,
+                                   optimumFeatureSubsetSelector,
+                                   treeSplitCriterion),
+            variableImportanceList(numFeatures, 0),
+            featureStandardDeviationThreshold(featureStandardDeviationThreshold) {
+                
     try {
         m = MothurOut::getInstance();
         createBootStrappedSamples();
@@ -28,41 +33,57 @@ DecisionTree::DecisionTree(vector< vector<int> > baseDataSet,
 
 /***********************************************************************/
 
-int DecisionTree::calcTreeVariableImportanceAndError() {
+int DecisionTree::calcTreeVariableImportanceAndError(int& numCorrect, double& treeErrorRate) {
     try {
+        vector< vector<int> > randomlySampledTestData(bootstrappedTestSamples.size(), vector<int>(bootstrappedTestSamples[0].size(), 0));
         
-        int numCorrect;
-        double treeErrorRate;
-        calcTreeErrorRate(numCorrect, treeErrorRate);
+            // TODO: is is possible to further speed up the following O(N^2) by using std::copy?
+        for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
+            for (int j = 0; j < bootstrappedTestSamples[i].size(); j++) {
+                randomlySampledTestData[i][j] = bootstrappedTestSamples[i][j];
+            }
+        }
         
-        if (m->control_pressed) {return 0; }
-                
         for (int i = 0; i < numFeatures; i++) {
-            if (m->control_pressed) {return 0; }
-            // NOTE: only shuffle the features, never shuffle the output vector
-            // so i = 0 and i will be alwaays <= (numFeatures - 1) as the index at numFeatures will denote
-            // the feature vector
-            vector< vector<int> > randomlySampledTestData = randomlyShuffleAttribute(bootstrappedTestSamples, i);
+            if (m->control_pressed) { return 0; }
             
-            int numCorrectAfterShuffle = 0;
-            for (int j = 0; j < randomlySampledTestData.size(); j++) {
-                if (m->control_pressed) {return 0; }
-                vector<int> shuffledSample = randomlySampledTestData[j];
-                int actualSampleOutputClass = shuffledSample[numFeatures];
-                int predictedSampleOutputClass = evaluateSample(shuffledSample);
-                if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrectAfterShuffle++; }
+                // if the index is in globalDiscardedFeatureIndices (i.e, null feature) we don't want to shuffle them
+            vector<int>::iterator it = find(globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end(), i);
+            if (it == globalDiscardedFeatureIndices.end()) {        // NOT FOUND
+                // if the standard deviation is very low, we know it's not a good feature at all
+                // we can save some time here by discarding that feature
+                
+                vector<int> featureVector = testSampleFeatureVectors[i];
+                if (m->getStandardDeviation(featureVector) > featureStandardDeviationThreshold) {
+                    // NOTE: only shuffle the features, never shuffle the output vector
+                    // so i = 0 and i will be alwaays <= (numFeatures - 1) as the index at numFeatures will denote
+                    // the feature vector
+                    randomlyShuffleAttribute(bootstrappedTestSamples, i, i - 1, randomlySampledTestData);
+
+                    int numCorrectAfterShuffle = 0;
+                    for (int j = 0; j < randomlySampledTestData.size(); j++) {
+                        if (m->control_pressed) {return 0; }
+                        
+                        vector<int> shuffledSample = randomlySampledTestData[j];
+                        int actualSampleOutputClass = shuffledSample[numFeatures];
+                        int predictedSampleOutputClass = evaluateSample(shuffledSample);
+                        if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrectAfterShuffle++; }
+                    }
+                    variableImportanceList[i] += (numCorrect - numCorrectAfterShuffle);
+                }
             }
-            variableImportanceList[i] += (numCorrect - numCorrectAfterShuffle);
         }
         
         // TODO: do we need to save the variableRanks in the DecisionTree, do we need it later?
-        vector< vector<int> > variableRanks;
+        vector< pair<int, int> > variableRanks;
+        
         for (int i = 0; i < variableImportanceList.size(); i++) {
             if (m->control_pressed) {return 0; }
             if (variableImportanceList[i] > 0) {
                 // TODO: is there a way to optimize the follow line's code?
-                vector<int> variableRank(2, 0);
-                variableRank[0] = i; variableRank[1] = variableImportanceList[i];
+                pair<int, int> variableRank(0, 0);
+                variableRank.first = i;
+                variableRank.second = variableImportanceList[i];
                 variableRanks.push_back(variableRank);
             }
         }
@@ -84,8 +105,10 @@ int DecisionTree::evaluateSample(vector<int> testSample) {
     try {
         RFTreeNode *node = rootNode;
         while (true) {
-            if (m->control_pressed) {return 0; }
+            if (m->control_pressed) { return 0; }
+            
             if (node->checkIsLeaf()) { return node->getOutputClass(); }
+            
             int sampleSplitFeatureValue = testSample[node->getSplitFeatureIndex()];
             if (sampleSplitFeatureValue < node->getSplitFeatureValue()) { node = node->getLeftChildNode(); }
             else { node = node->getRightChildNode(); } 
@@ -101,8 +124,8 @@ int DecisionTree::evaluateSample(vector<int> testSample) {
 /***********************************************************************/
 
 int DecisionTree::calcTreeErrorRate(int& numCorrect, double& treeErrorRate){
+    numCorrect = 0;
     try {
-        numCorrect = 0;
         for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
              if (m->control_pressed) {return 0; }
             
@@ -128,35 +151,44 @@ int DecisionTree::calcTreeErrorRate(int& numCorrect, double& treeErrorRate){
 }
 
 /***********************************************************************/
-
 // TODO: optimize the algo, instead of transposing two time, we can extarct the feature,
 // shuffle it and then re-insert in the original place, thus iproving runnting time
 //This function randomize abundances for a given OTU/feature.
-vector< vector<int> > DecisionTree::randomlyShuffleAttribute(vector< vector<int> > samples, int featureIndex) {
+
+void DecisionTree::randomlyShuffleAttribute(const vector< vector<int> >& samples,
+                               const int featureIndex,
+                               const int prevFeatureIndex,
+                               vector< vector<int> >& shuffledSample) {
     try {
         // NOTE: we need (numFeatures + 1) featureVecotors, the last extra vector is actually outputVector
-        vector< vector<int> > shuffledSample = samples;
-        vector<int> featureVectors(samples.size(), 0);
         
+        // restore previously shuffled feature
+        if (prevFeatureIndex > -1) {
+            for (int j = 0; j < samples.size(); j++) {
+                if (m->control_pressed) { return; }
+                shuffledSample[j][prevFeatureIndex] = samples[j][prevFeatureIndex];
+            }
+        }
+        
+        // now do the shuffling
+        vector<int> featureVectors(samples.size(), 0);
         for (int j = 0; j < samples.size(); j++) {
-            if (m->control_pressed) { return shuffledSample; }
+            if (m->control_pressed) { return; }
             featureVectors[j] = samples[j][featureIndex];
         }
-        
         random_shuffle(featureVectors.begin(), featureVectors.end());
-
         for (int j = 0; j < samples.size(); j++) {
-            if (m->control_pressed) {return shuffledSample; }
+            if (m->control_pressed) { return; }
             shuffledSample[j][featureIndex] = featureVectors[j];
         }
-        
-        return shuffledSample;
     }
        catch(exception& e) {
-               m->errorOut(e, "DecisionTree", "randomlyShuffleAttribute");
+        m->errorOut(e, "DecisionTree", "randomlyShuffleAttribute");
                exit(1);
-       } 
+       }
+    
 }
+
 /***********************************************************************/
 
 int DecisionTree::purgeTreeNodesDataRecursively(RFTreeNode* treeNode) {
@@ -181,10 +213,12 @@ int DecisionTree::purgeTreeNodesDataRecursively(RFTreeNode* treeNode) {
 void DecisionTree::buildDecisionTree(){
     try {
     
-    int generation = 0;
-    rootNode = new RFTreeNode(bootstrappedTrainingSamples, globalDiscardedFeatureIndices, numFeatures, numSamples, numOutputClasses, generation);
-    
-    splitRecursively(rootNode);
+        int generation = 0;
+        rootNode = new RFTreeNode(bootstrappedTrainingSamples, globalDiscardedFeatureIndices, numFeatures, numSamples, numOutputClasses, generation, nodeIdCount, featureStandardDeviationThreshold);
+        nodeIdCount++;
+        
+        splitRecursively(rootNode);
+        
         }
        catch(exception& e) {
                m->errorOut(e, "DecisionTree", "buildDecisionTree");
@@ -210,24 +244,32 @@ int DecisionTree::splitRecursively(RFTreeNode* rootNode) {
             rootNode->setOutputClass(classifiedOutputClass);
             return 0;
         }
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         vector<int> featureSubsetIndices = selectFeatureSubsetRandomly(globalDiscardedFeatureIndices, rootNode->getLocalDiscardedFeatureIndices());
+        
+            // TODO: need to check if the value is actually copied correctly
         rootNode->setFeatureSubsetIndices(featureSubsetIndices);
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
       
         findAndUpdateBestFeatureToSplitOn(rootNode);
         
-        if (m->control_pressed) {return 0;}
+        // update rootNode outputClass, this is needed for pruning
+        // this is only for internal nodes
+        updateOutputClassOfNode(rootNode);
+        
+        if (m->control_pressed) { return 0; }
         
         vector< vector<int> > leftChildSamples;
         vector< vector<int> > rightChildSamples;
         getSplitPopulation(rootNode, leftChildSamples, rightChildSamples);
         
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         
         // TODO: need to write code to clear this memory
-        RFTreeNode* leftChildNode = new RFTreeNode(leftChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)leftChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1);
-        RFTreeNode* rightChildNode = new RFTreeNode(rightChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)rightChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1);
+        RFTreeNode* leftChildNode = new RFTreeNode(leftChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)leftChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1, nodeIdCount, featureStandardDeviationThreshold);
+        nodeIdCount++;
+        RFTreeNode* rightChildNode = new RFTreeNode(rightChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)rightChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1, nodeIdCount, featureStandardDeviationThreshold);
+        nodeIdCount++;
         
         rootNode->setLeftChildNode(leftChildNode);
         leftChildNode->setParentNode(rootNode);
@@ -237,7 +279,7 @@ int DecisionTree::splitRecursively(RFTreeNode* rootNode) {
         
         // TODO: This recursive split can be parrallelized later
         splitRecursively(leftChildNode);
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         
         splitRecursively(rightChildNode);
         return 0;
@@ -254,11 +296,11 @@ int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
     try {
 
         vector< vector<int> > bootstrappedFeatureVectors = node->getBootstrappedFeatureVectors();
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         vector<int> bootstrappedOutputVector = node->getBootstrappedOutputVector();
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         vector<int> featureSubsetIndices = node->getFeatureSubsetIndices();
-        if (m->control_pressed) {return 0;}
+        if (m->control_pressed) { return 0; }
         
         vector<double> featureSubsetEntropies;
         vector<int> featureSubsetSplitValues;
@@ -266,7 +308,7 @@ int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
         vector<double> featureSubsetGainRatios;
         
         for (int i = 0; i < featureSubsetIndices.size(); i++) {
-            if (m->control_pressed) {return 0;}
+            if (m->control_pressed) { return 0; }
             
             int tryIndex = featureSubsetIndices[i];
                        
@@ -275,7 +317,7 @@ int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
             double featureIntrinsicValue;
             
             getMinEntropyOfFeature(bootstrappedFeatureVectors[tryIndex], bootstrappedOutputVector, featureMinEntropy, featureSplitValue, featureIntrinsicValue);
-            if (m->control_pressed) {return 0;}
+            if (m->control_pressed) { return 0; }
             
             featureSubsetEntropies.push_back(featureMinEntropy);
             featureSubsetSplitValues.push_back(featureSplitValue);
@@ -290,23 +332,33 @@ int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
         vector<double>::iterator minEntropyIterator = min_element(featureSubsetEntropies.begin(), featureSubsetEntropies.end());
         vector<double>::iterator maxGainRatioIterator = max_element(featureSubsetGainRatios.begin(), featureSubsetGainRatios.end());
         double featureMinEntropy = *minEntropyIterator;
-        //double featureMaxGainRatio = *maxGainRatioIterator;
+        
+        // TODO: kept the following line as future reference, can be useful
+        // double featureMaxGainRatio = *maxGainRatioIterator;
         
         double bestFeatureSplitEntropy = featureMinEntropy;
         int bestFeatureToSplitOnIndex = -1;
-        if (treeSplitCriterion == "gainRatio"){ 
+        if (treeSplitCriterion == "gainratio"){
             bestFeatureToSplitOnIndex = (int)(maxGainRatioIterator - featureSubsetGainRatios.begin());
             // if using 'gainRatio' measure, then featureMinEntropy must be re-updated, as the index
             // for 'featureMaxGainRatio' would be different
             bestFeatureSplitEntropy = featureSubsetEntropies[bestFeatureToSplitOnIndex];
+        } else  if ( treeSplitCriterion == "infogain"){
+            bestFeatureToSplitOnIndex = (int)(minEntropyIterator - featureSubsetEntropies.begin());
+        } else {
+                // TODO: we need an abort mechanism here
         }
-        else { bestFeatureToSplitOnIndex = (int)(minEntropyIterator - featureSubsetEntropies.begin()); }
+        
+            // TODO: is the following line needed? kept is as future reference
+        // splitInformationGain = node.ownEntropy - node.splitFeatureEntropy
         
         int bestFeatureSplitValue = featureSubsetSplitValues[bestFeatureToSplitOnIndex];
         
         node->setSplitFeatureIndex(featureSubsetIndices[bestFeatureToSplitOnIndex]);
         node->setSplitFeatureValue(bestFeatureSplitValue);
         node->setSplitFeatureEntropy(bestFeatureSplitEntropy);
+            // TODO: kept the following line as future reference
+        // node.splitInformationGain = splitInformationGain
         
         return 0;
     }
@@ -363,17 +415,17 @@ vector<int> DecisionTree::selectFeatureSubsetRandomly(vector<int> globalDiscarde
 int DecisionTree::printTree(RFTreeNode* treeNode, string caption){
     try { 
         string tabs = "";
-        for (int i = 0; i < treeNode->getGeneration(); i++) { tabs += "   "; }
+        for (int i = 0; i < treeNode->getGeneration(); i++) { tabs += "|--"; }
         //    for (int i = 0; i < treeNode->getGeneration() - 1; i++) { tabs += "|  "; }
         //    if (treeNode->getGeneration() != 0) { tabs += "|--"; }
         
         if (treeNode != NULL && treeNode->checkIsLeaf() == false){
-            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " ] ( " + toString(treeNode->getSplitFeatureValue()) + " < X" + toString(treeNode->getSplitFeatureIndex()) +" )\n");
+            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " , id: " + toString(treeNode->nodeId) + " ] ( " + toString(treeNode->getSplitFeatureValue()) + " < X" + toString(treeNode->getSplitFeatureIndex()) + " ) ( predicted: " + toString(treeNode->outputClass) + " , misclassified: " + toString(treeNode->testSampleMisclassificationCount) + " )\n");
             
-            printTree(treeNode->getLeftChildNode(), "leftChild");
-            printTree(treeNode->getRightChildNode(), "rightChild");
+            printTree(treeNode->getLeftChildNode(), "left ");
+            printTree(treeNode->getRightChildNode(), "right");
         }else {
-            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " ] ( classified to: " + toString(treeNode->getOutputClass()) + ", samples: " + toString(treeNode->getNumSamples()) + " )\n");
+            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " , id: " + toString(treeNode->nodeId) + " ] ( classified to: " + toString(treeNode->getOutputClass()) + ", samples: " + toString(treeNode->getNumSamples()) + " , misclassified: " + toString(treeNode->testSampleMisclassificationCount) + " )\n");
         }
         return 0;
     }
@@ -388,7 +440,7 @@ void DecisionTree::deleteTreeNodesRecursively(RFTreeNode* treeNode) {
         if (treeNode == NULL) { return; }
         deleteTreeNodesRecursively(treeNode->leftChildNode);
         deleteTreeNodesRecursively(treeNode->rightChildNode);
-        delete treeNode;
+        delete treeNode; treeNode = NULL;
     }
        catch(exception& e) {
                m->errorOut(e, "DecisionTree", "deleteTreeNodesRecursively");
@@ -397,3 +449,87 @@ void DecisionTree::deleteTreeNodesRecursively(RFTreeNode* treeNode) {
 }
 /***********************************************************************/
 
+void DecisionTree::pruneTree(double pruneAggressiveness = 0.9) {
+    
+    // find out the number of misclassification by each of the nodes
+    for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
+        if (m->control_pressed) { return; }
+        
+        vector<int> testSample = bootstrappedTestSamples[i];
+        updateMisclassificationCountRecursively(rootNode, testSample);
+    }
+    
+    // do the actual pruning
+    pruneRecursively(rootNode, pruneAggressiveness);
+}
+/***********************************************************************/
+
+void DecisionTree::pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness){
+    
+    if (treeNode != NULL && treeNode->checkIsLeaf() == false) {
+        if (m->control_pressed) { return; }
+        
+        pruneRecursively(treeNode->leftChildNode, pruneAggressiveness);
+        pruneRecursively(treeNode->rightChildNode, pruneAggressiveness);
+        
+        int subTreeMisclassificationCount = treeNode->leftChildNode->getTestSampleMisclassificationCount() + treeNode->rightChildNode->getTestSampleMisclassificationCount();
+        int ownMisclassificationCount = treeNode->getTestSampleMisclassificationCount();
+        
+        if (subTreeMisclassificationCount * pruneAggressiveness > ownMisclassificationCount) {
+                // TODO: need to check the effect of these two delete calls
+            delete treeNode->leftChildNode;
+            treeNode->leftChildNode = NULL;
+            
+            delete treeNode->rightChildNode;
+            treeNode->rightChildNode = NULL;
+            
+            treeNode->isLeaf = true;
+        }
+        
+    }
+}
+/***********************************************************************/
+
+void DecisionTree::updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector<int> testSample) {
+    
+    int actualSampleOutputClass = testSample[numFeatures];
+    int nodePredictedOutputClass = treeNode->outputClass;
+    
+    if (actualSampleOutputClass != nodePredictedOutputClass) {
+        treeNode->testSampleMisclassificationCount++;
+        map<int, int>::iterator it = nodeMisclassificationCounts.find(treeNode->nodeId);
+        if (it == nodeMisclassificationCounts.end()) {  // NOT FOUND
+            nodeMisclassificationCounts[treeNode->nodeId] = 0;
+        }
+        nodeMisclassificationCounts[treeNode->nodeId]++;
+    }
+    
+    if (treeNode->checkIsLeaf() == false) { // NOT A LEAF
+        int sampleSplitFeatureValue = testSample[treeNode->splitFeatureIndex];
+        if (sampleSplitFeatureValue < treeNode->splitFeatureValue) {
+            updateMisclassificationCountRecursively(treeNode->leftChildNode, testSample);
+        } else {
+            updateMisclassificationCountRecursively(treeNode->rightChildNode, testSample);
+        }
+    }
+}
+
+/***********************************************************************/
+
+void DecisionTree::updateOutputClassOfNode(RFTreeNode* treeNode) {
+    vector<int> counts(numOutputClasses, 0);
+    for (int i = 0; i < treeNode->bootstrappedOutputVector.size(); i++) {
+        int bootstrappedOutput = treeNode->bootstrappedOutputVector[i];
+        counts[bootstrappedOutput]++;
+    }
+
+    vector<int>::iterator majorityVotedOutputClassCountIterator = max_element(counts.begin(), counts.end());
+    int majorityVotedOutputClassCount = *majorityVotedOutputClassCountIterator;
+    vector<int>::iterator it = find(counts.begin(), counts.end(), majorityVotedOutputClassCount);
+    int majorityVotedOutputClass = (int)(it - counts.begin());
+    treeNode->setOutputClass(majorityVotedOutputClass);
+
+}
+/***********************************************************************/
+
+