]> git.donarmstrong.com Git - mothur.git/blobdiff - abstractdecisiontree.cpp
working on pam
[mothur.git] / abstractdecisiontree.cpp
index 085cd316225df92402d0a9fe86c623b45acdd553..595f483dc5ba6cde01b825f8ac44d9ceea4ece16 100644 (file)
 
 /**************************************************************************************************/
 
-AbstractDecisionTree::AbstractDecisionTree(vector<vector<int> >baseDataSet, 
-                     vector<int> globalDiscardedFeatureIndices, 
-                     OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, 
-                     string treeSplitCriterion) : baseDataSet(baseDataSet),
-numSamples((int)baseDataSet.size()),
-numFeatures((int)(baseDataSet[0].size() - 1)),
-numOutputClasses(0),
-rootNode(NULL),
-globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
-optimumFeatureSubsetSize(optimumFeatureSubsetSelector.getOptimumFeatureSubsetSize(numFeatures)),
-treeSplitCriterion(treeSplitCriterion) {
+AbstractDecisionTree::AbstractDecisionTree(vector<vector<int> >& baseDataSet,
+                                         vector<int> globalDiscardedFeatureIndices,
+                                         OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, 
+                                         string treeSplitCriterion)
+
+                    : baseDataSet(baseDataSet),
+                    numSamples((int)baseDataSet.size()),
+                    numFeatures((int)(baseDataSet[0].size() - 1)),
+                    numOutputClasses(0),
+                    rootNode(NULL),
+                    nodeIdCount(0),
+                    globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
+                    optimumFeatureSubsetSize(optimumFeatureSubsetSelector.getOptimumFeatureSubsetSize(numFeatures)),
+                    treeSplitCriterion(treeSplitCriterion) {
 
     try {
-    // TODO: istead of calculating this for every DecisionTree
-    // clacualte this once in the RandomForest class and pass the values
-    m = MothurOut::getInstance();
-    for (int i = 0;  i < numSamples; i++) {
-        if (m->control_pressed) { break; }
-        int outcome = baseDataSet[i][numFeatures];
-        vector<int>::iterator it = find(outputClasses.begin(), outputClasses.end(), outcome);
-        if (it == outputClasses.end()){       // find() will return classes.end() if the element is not found
-            outputClasses.push_back(outcome);
-            numOutputClasses++;
+        // TODO: istead of calculating this for every DecisionTree
+        // clacualte this once in the RandomForest class and pass the values
+        m = MothurOut::getInstance();
+        for (int i = 0;  i < numSamples; i++) {
+            if (m->control_pressed) { break; }
+            int outcome = baseDataSet[i][numFeatures];
+            vector<int>::iterator it = find(outputClasses.begin(), outputClasses.end(), outcome);
+            if (it == outputClasses.end()){       // find() will return classes.end() if the element is not found
+                outputClasses.push_back(outcome);
+                numOutputClasses++;
+            }
+        }
+        
+        if (m->debug) {
+            //m->mothurOut("outputClasses = " + toStringVectorInt(outputClasses));
+            m->mothurOut("numOutputClasses = " + toString(numOutputClasses) + '\n');
         }
-    }
-    
-    if (m->debug) {
-        //m->mothurOut("outputClasses = " + toStringVectorInt(outputClasses));
-        m->mothurOut("numOutputClasses = " + toString(numOutputClasses) + '\n');
-    }
 
     }
        catch(exception& e) {
@@ -50,25 +53,38 @@ treeSplitCriterion(treeSplitCriterion) {
 /**************************************************************************************************/
 int AbstractDecisionTree::createBootStrappedSamples(){
     try {    
-    vector<bool> isInTrainingSamples(numSamples, false);
-    
-    for (int i = 0; i < numSamples; i++) {
-        if (m->control_pressed) { return 0; }
-        // TODO: optimize the rand() function call + double check if it's working properly
-        int randomIndex = rand() % numSamples;
-        bootstrappedTrainingSamples.push_back(baseDataSet[randomIndex]);
-        isInTrainingSamples[randomIndex] = true;
-    }
-    
-    for (int i = 0; i < numSamples; i++) {
-        if (m->control_pressed) { return 0; }
-        if (isInTrainingSamples[i]){ bootstrappedTrainingSampleIndices.push_back(i); }
-        else{
-            bootstrappedTestSamples.push_back(baseDataSet[i]);
-            bootstrappedTestSampleIndices.push_back(i);
+        vector<bool> isInTrainingSamples(numSamples, false);
+        
+        for (int i = 0; i < numSamples; i++) {
+            if (m->control_pressed) { return 0; }
+            // TODO: optimize the rand() function call + double check if it's working properly
+            int randomIndex = rand() % numSamples;
+            bootstrappedTrainingSamples.push_back(baseDataSet[randomIndex]);
+            isInTrainingSamples[randomIndex] = true;
         }
-    }
-    
+        
+        for (int i = 0; i < numSamples; i++) {
+            if (m->control_pressed) { return 0; }
+            if (isInTrainingSamples[i]){ bootstrappedTrainingSampleIndices.push_back(i); }
+            else{
+                bootstrappedTestSamples.push_back(baseDataSet[i]);
+                bootstrappedTestSampleIndices.push_back(i);
+            }
+        }
+        
+            // do the transpose of Test Samples
+        for (int i = 0; i < bootstrappedTestSamples[0].size(); i++) {
+            if (m->control_pressed) { return 0; }
+            
+            vector<int> tmpFeatureVector(bootstrappedTestSamples.size(), 0);
+            for (int j = 0; j < bootstrappedTestSamples.size(); j++) {
+                if (m->control_pressed) { return 0; }
+                
+                tmpFeatureVector[j] = bootstrappedTestSamples[j][i];
+            }
+            testSampleFeatureVectors.push_back(tmpFeatureVector);
+        }
+        
         return 0;
     }
        catch(exception& e) {
@@ -77,25 +93,34 @@ int AbstractDecisionTree::createBootStrappedSamples(){
        } 
 }
 /**************************************************************************************************/
-int AbstractDecisionTree::getMinEntropyOfFeature(vector<int> featureVector, vector<int> outputVector, double& minEntropy, int& featureSplitValue, double& intrinsicValue){
+int AbstractDecisionTree::getMinEntropyOfFeature(vector<int> featureVector,
+                                                 vector<int> outputVector,
+                                                 double& minEntropy,
+                                                 int& featureSplitValue,
+                                                 double& intrinsicValue){
     try {
 
-        vector< vector<int> > featureOutputPair(featureVector.size(), vector<int>(2, 0));
+        vector< pair<int, int> > featureOutputPair(featureVector.size(), pair<int, int>(0, 0));
+        
         for (int i = 0; i < featureVector.size(); i++) { 
             if (m->control_pressed) { return 0; }
-            featureOutputPair[i][0] = featureVector[i];
-            featureOutputPair[i][1] = outputVector[i];
+            
+            featureOutputPair[i].first = featureVector[i];
+            featureOutputPair[i].second = outputVector[i];
         }
-        // TODO: using default behavior to sort(), need to specify the comparator for added safety and compiler portability
-        sort(featureOutputPair.begin(), featureOutputPair.end());
+        // TODO: using default behavior to sort(), need to specify the comparator for added safety and compiler portability,
         
+        IntPairVectorSorter intPairVectorSorter;
+        sort(featureOutputPair.begin(), featureOutputPair.end(), intPairVectorSorter);
         
         vector<int> splitPoints;
-        vector<int> uniqueFeatureValues(1, featureOutputPair[0][0]);
+        vector<int> uniqueFeatureValues(1, featureOutputPair[0].first);
         
         for (int i = 0; i < featureOutputPair.size(); i++) {
+
             if (m->control_pressed) { return 0; }
-            int featureValue = featureOutputPair[i][0];
+            int featureValue = featureOutputPair[i].first;
+
             vector<int>::iterator it = find(uniqueFeatureValues.begin(), uniqueFeatureValues.end(), featureValue);
             if (it == uniqueFeatureValues.end()){                 // NOT FOUND
                 uniqueFeatureValues.push_back(featureValue);
@@ -115,7 +140,7 @@ int AbstractDecisionTree::getMinEntropyOfFeature(vector<int> featureVector, vect
             featureSplitValue = -1;                                                   // OUTPUT
         }else{
             getBestSplitAndMinEntropy(featureOutputPair, splitPoints, minEntropy, bestSplitIndex, intrinsicValue);  // OUTPUT
-            featureSplitValue = featureOutputPair[splitPoints[bestSplitIndex]][0];    // OUTPUT
+            featureSplitValue = featureOutputPair[splitPoints[bestSplitIndex]].first;    // OUTPUT
         }
         
         return 0;
@@ -146,8 +171,9 @@ double AbstractDecisionTree::calcIntrinsicValue(int numLessThanValueAtSplitPoint
        } 
 }
 /**************************************************************************************************/
-int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< vector<int> > featureOutputPairs, vector<int> splitPoints,
-                               double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue){
+
+int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< pair<int, int> > featureOutputPairs, vector<int> splitPoints,
+                                                    double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue){
     try {
         
         int numSamples = (int)featureOutputPairs.size();
@@ -155,16 +181,17 @@ int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< vector<int> > featur
         vector<double> intrinsicValues;
         
         for (int i = 0; i < splitPoints.size(); i++) {
-             if (m->control_pressed) { return 0; }
+            if (m->control_pressed) { return 0; }
             int index = splitPoints[i];
-            int valueAtSplitPoint = featureOutputPairs[index][0];
+            int valueAtSplitPoint = featureOutputPairs[index].first;
+
             int numLessThanValueAtSplitPoint = 0;
             int numGreaterThanValueAtSplitPoint = 0;
             
             for (int j = 0; j < featureOutputPairs.size(); j++) {
-                 if (m->control_pressed) { return 0; }
-                vector<int> record = featureOutputPairs[j];
-                if (record[0] < valueAtSplitPoint){ numLessThanValueAtSplitPoint++; }
+                if (m->control_pressed) { return 0; }
+                pair<int, int> record = featureOutputPairs[j];
+                if (record.first < valueAtSplitPoint){ numLessThanValueAtSplitPoint++; }
                 else{ numGreaterThanValueAtSplitPoint++; }
             }
             
@@ -193,19 +220,19 @@ int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< vector<int> > featur
 }
 /**************************************************************************************************/
 
-double AbstractDecisionTree::calcSplitEntropy(vector< vector<int> > featureOutputPairs, int splitIndex, int numOutputClasses, bool isUpperSplit = true) {
+double AbstractDecisionTree::calcSplitEntropy(vector< pair<int, int> > featureOutputPairs, int splitIndex, int numOutputClasses, bool isUpperSplit = true) {
     try {
         vector<int> classCounts(numOutputClasses, 0);
         
         if (isUpperSplit) { 
-            for (int i = 0; i < splitIndex; i++) { 
+            for (int i = 0; i < splitIndex; i++) {
                 if (m->control_pressed) { return 0; }
-                classCounts[featureOutputPairs[i][1]]++; 
+                classCounts[featureOutputPairs[i].second]++;
             }
         } else {
             for (int i = splitIndex; i < featureOutputPairs.size(); i++) { 
                 if (m->control_pressed) { return 0; }
-                classCounts[featureOutputPairs[i][1]]++; 
+                classCounts[featureOutputPairs[i].second]++;
             }
         }
         
@@ -245,8 +272,9 @@ int AbstractDecisionTree::getSplitPopulation(RFTreeNode* node, vector< vector<in
             if (m->control_pressed) { return 0; }
             vector<int> sample =  node->getBootstrappedTrainingSamples()[i];
             if (m->control_pressed) { return 0; }
-            if (sample[splitFeatureGlobalIndex] < node->getSplitFeatureValue()){ leftChildSamples.push_back(sample); }
-            else{ rightChildSamples.push_back(sample); }
+            
+            if (sample[splitFeatureGlobalIndex] < node->getSplitFeatureValue()) { leftChildSamples.push_back(sample); }
+            else { rightChildSamples.push_back(sample); }
         }
         
         return 0;