]> git.donarmstrong.com Git - mothur.git/blob - rftreenode.cpp
fixes while testing 1.33.0
[mothur.git] / rftreenode.cpp
1 //
2 //  rftreenode.cpp
3 //  Mothur
4 //
5 //  Created by Sarah Westcott on 10/2/12.
6 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
7 //
8
9 #include "rftreenode.hpp"
10
11 /***********************************************************************/
12 RFTreeNode::RFTreeNode(vector< vector<int> > bootstrappedTrainingSamples,
13                        vector<int> globalDiscardedFeatureIndices,
14                        int numFeatures,
15                        int numSamples,
16                        int numOutputClasses,
17                        int generation,
18                        int nodeId,
19                        float featureStandardDeviationThreshold)
20
21             : bootstrappedTrainingSamples(bootstrappedTrainingSamples),
22             globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
23             numFeatures(numFeatures),
24             numSamples(numSamples),
25             numOutputClasses(numOutputClasses),
26             generation(generation),
27             isLeaf(false),
28             outputClass(-1),
29             nodeId(nodeId),
30             testSampleMisclassificationCount(0),
31             splitFeatureIndex(-1),
32             splitFeatureValue(-1),
33             splitFeatureEntropy(-1.0),
34             ownEntropy(-1.0),
35             featureStandardDeviationThreshold(featureStandardDeviationThreshold),
36             bootstrappedFeatureVectors(numFeatures, vector<int>(numSamples, 0)),
37             bootstrappedOutputVector(numSamples, 0),
38             leftChildNode(NULL),
39             rightChildNode(NULL),
40             parentNode(NULL) {
41                 
42     m = MothurOut::getInstance();
43     
44     for (int i = 0; i < numSamples; i++) {    // just doing a simple transpose of the matrix
45         if (m->control_pressed) { break; }
46         for (int j = 0; j < numFeatures; j++) { bootstrappedFeatureVectors[j][i] = bootstrappedTrainingSamples[i][j]; }
47     }
48     
49     for (int i = 0; i < numSamples; i++) { if (m->control_pressed) { break; }
50         bootstrappedOutputVector[i] = bootstrappedTrainingSamples[i][numFeatures]; }
51     
52     createLocalDiscardedFeatureList();
53     updateNodeEntropy();
54 }
55 /***********************************************************************/
56 int RFTreeNode::createLocalDiscardedFeatureList(){
57     try {
58         
59         for (int i = 0; i < numFeatures; i++) {
60                 // TODO: need to check if bootstrappedFeatureVectors == numFeatures, in python code we are using bootstrappedFeatureVectors instead of numFeatures
61             if (m->control_pressed) { return 0; } 
62             vector<int>::iterator it = find(globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end(), i);
63             if (it == globalDiscardedFeatureIndices.end()) {                           // NOT FOUND
64                 double standardDeviation = m->getStandardDeviation(bootstrappedFeatureVectors[i]);  
65                 if (standardDeviation <= featureStandardDeviationThreshold) { localDiscardedFeatureIndices.push_back(i); }
66             }
67         }
68         
69         return 0;
70     }
71     catch(exception& e) {
72         m->errorOut(e, "RFTreeNode", "createLocalDiscardedFeatureList");
73         exit(1);
74     }  
75 }
76 /***********************************************************************/
77 int RFTreeNode::updateNodeEntropy() {
78     try {
79         
80         vector<int> classCounts(numOutputClasses, 0);
81         for (int i = 0; i < bootstrappedOutputVector.size(); i++) {
82             classCounts[bootstrappedOutputVector[i]]++;
83         }
84         int totalClassCounts = accumulate(classCounts.begin(), classCounts.end(), 0);
85         double nodeEntropy = 0.0;
86         for (int i = 0; i < classCounts.size(); i++) {
87             if (m->control_pressed) { return 0; }
88             if (classCounts[i] == 0) continue;
89             double probability = (double)classCounts[i] / (double)totalClassCounts;
90             nodeEntropy += -(probability * log2(probability));
91         }
92         ownEntropy = nodeEntropy;
93         
94         return 0;
95     }
96     catch(exception& e) {
97         m->errorOut(e, "RFTreeNode", "updateNodeEntropy");
98         exit(1);
99     } 
100 }
101
102 /***********************************************************************/