git.donarmstrong.com Git - mothur.git/blobdiff - rftreenode.cpp
fixes while testing 1.33.0
[mothur.git] / rftreenode.cpp
index 170cfb16f10d5c87b98efb55969ec733eae7d826..acfae544aa7660215d71643fc741d9a65cb2dfd1 100644 (file)
 
 /***********************************************************************/
 RFTreeNode::RFTreeNode(vector< vector<int> > bootstrappedTrainingSamples,
-           vector<int> globalDiscardedFeatureIndices,
-           int numFeatures,
-           int numSamples,
-           int numOutputClasses,
-           int generation)
+                       vector<int> globalDiscardedFeatureIndices,
+                       int numFeatures,
+                       int numSamples,
+                       int numOutputClasses,
+                       int generation,
+                       int nodeId,
+                       float featureStandardDeviationThreshold)
 
-: bootstrappedTrainingSamples(bootstrappedTrainingSamples),
-globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
-numFeatures(numFeatures),
-numSamples(numSamples),
-numOutputClasses(numOutputClasses),
-generation(generation),
-isLeaf(false),
-outputClass(-1),
-splitFeatureIndex(-1),
-splitFeatureValue(-1),
-splitFeatureEntropy(-1.0),
-ownEntropy(-1.0),
-bootstrappedFeatureVectors(numFeatures, vector<int>(numSamples, 0)),
-bootstrappedOutputVector(numSamples, 0),
-leftChildNode(NULL),
-rightChildNode(NULL),
-parentNode(NULL) {
+            : bootstrappedTrainingSamples(bootstrappedTrainingSamples),
+            globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
+            numFeatures(numFeatures),
+            numSamples(numSamples),
+            numOutputClasses(numOutputClasses),
+            generation(generation),
+            isLeaf(false),
+            outputClass(-1),
+            nodeId(nodeId),
+            testSampleMisclassificationCount(0),
+            splitFeatureIndex(-1),
+            splitFeatureValue(-1),
+            splitFeatureEntropy(-1.0),
+            ownEntropy(-1.0),
+            featureStandardDeviationThreshold(featureStandardDeviationThreshold),
+            bootstrappedFeatureVectors(numFeatures, vector<int>(numSamples, 0)),
+            bootstrappedOutputVector(numSamples, 0),
+            leftChildNode(NULL),
+            rightChildNode(NULL),
+            parentNode(NULL) {
+                
     m = MothurOut::getInstance();
     
     for (int i = 0; i < numSamples; i++) {    // just doing a simple transpose of the matrix
@@ -40,7 +46,8 @@ parentNode(NULL) {
         for (int j = 0; j < numFeatures; j++) { bootstrappedFeatureVectors[j][i] = bootstrappedTrainingSamples[i][j]; }
     }
     
-    for (int i = 0; i < numSamples; i++) { if (m->control_pressed) { break; } bootstrappedOutputVector[i] = bootstrappedTrainingSamples[i][numFeatures]; }
+    for (int i = 0; i < numSamples; i++) { if (m->control_pressed) { break; }
+        bootstrappedOutputVector[i] = bootstrappedTrainingSamples[i][numFeatures]; }
     
     createLocalDiscardedFeatureList();
     updateNodeEntropy();
@@ -48,13 +55,14 @@ parentNode(NULL) {
 /***********************************************************************/
 int RFTreeNode::createLocalDiscardedFeatureList(){
     try {
-
+        
         for (int i = 0; i < numFeatures; i++) {
+                // TODO: confirm that bootstrappedFeatureVectors.size() == numFeatures; the Python code loops over bootstrappedFeatureVectors instead of using numFeatures
             if (m->control_pressed) { return 0; } 
             vector<int>::iterator it = find(globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end(), i);
-            if (it == globalDiscardedFeatureIndices.end()){                           // NOT FOUND
+            if (it == globalDiscardedFeatureIndices.end()) {                           // NOT FOUND
                 double standardDeviation = m->getStandardDeviation(bootstrappedFeatureVectors[i]);  
-                if (standardDeviation <= 0){ localDiscardedFeatureIndices.push_back(i); }
+                if (standardDeviation <= featureStandardDeviationThreshold) { localDiscardedFeatureIndices.push_back(i); }
             }
         }
         
@@ -70,7 +78,9 @@ int RFTreeNode::updateNodeEntropy() {
     try {
         
         vector<int> classCounts(numOutputClasses, 0);
-        for (int i = 0; i < bootstrappedOutputVector.size(); i++) { classCounts[bootstrappedOutputVector[i]]++; }
+        for (int i = 0; i < bootstrappedOutputVector.size(); i++) {
+            classCounts[bootstrappedOutputVector[i]]++;
+        }
         int totalClassCounts = accumulate(classCounts.begin(), classCounts.end(), 0);
         double nodeEntropy = 0.0;
         for (int i = 0; i < classCounts.size(); i++) {