5 // Created by Sarah Westcott on 10/2/12.
6 // Copyright (c) 2012 Schloss Lab. All rights reserved.
9 #include "rftreenode.hpp"
11 /***********************************************************************/
// RFTreeNode constructor.
// Stores the bootstrapped training data and split bookkeeping, then:
//   1. transposes bootstrappedTrainingSamples (indexed [sample][feature]) into
//      bootstrappedFeatureVectors (indexed [feature][sample], numFeatures x numSamples),
//   2. copies column numFeatures of every sample row into bootstrappedOutputVector,
//   3. calls createLocalDiscardedFeatureList().
// @param bootstrappedTrainingSamples  sample-major matrix; must have at least
//        numSamples rows, each with at least numFeatures + 1 entries (not checked here).
// @param globalDiscardedFeatureIndices  feature indices that
//        createLocalDiscardedFeatureList() skips when checking standard deviations.
// @param featureStandardDeviationThreshold  features whose standard deviation is
//        <= this value are added to localDiscardedFeatureIndices.
// NOTE(review): the remaining parameters (numFeatures, numSamples, numOutputClasses,
// generation, ...) are declared on lines not visible in this listing.
12 RFTreeNode::RFTreeNode(vector< vector<int> > bootstrappedTrainingSamples,
13 vector<int> globalDiscardedFeatureIndices,
19 float featureStandardDeviationThreshold)
21 : bootstrappedTrainingSamples(bootstrappedTrainingSamples),
22 globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
23 numFeatures(numFeatures),
24 numSamples(numSamples),
25 numOutputClasses(numOutputClasses),
26 generation(generation),
30 testSampleMisclassificationCount(0),
// Split fields start at -1; presumably this means "no split chosen yet". Verify this against callers.
31 splitFeatureIndex(-1),
32 splitFeatureValue(-1),
33 splitFeatureEntropy(-1.0),
35 featureStandardDeviationThreshold(featureStandardDeviationThreshold),
// Inside a member-initializer expression, numFeatures/numSamples resolve to the
// same-named constructor parameters, so these sizes do not depend on member
// declaration order.
36 bootstrappedFeatureVectors(numFeatures, vector<int>(numSamples, 0)),
37 bootstrappedOutputVector(numSamples, 0),
42 m = MothurOut::getInstance();
// If m->control_pressed is set, the copy loops below stop early. Entries not yet
// copied stay 0.
44 for (int i = 0; i < numSamples; i++) { // transpose: [sample][feature] -> [feature][sample]
45 if (m->control_pressed) { break; }
46 for (int j = 0; j < numFeatures; j++) { bootstrappedFeatureVectors[j][i] = bootstrappedTrainingSamples[i][j]; }
// The entry at index numFeatures (one past the last feature) of each sample row is
// the per-sample output value. updateNodeEntropy() uses these values as class
// indices into a vector of size numOutputClasses.
49 for (int i = 0; i < numSamples; i++) { if (m->control_pressed) { break; }
50 bootstrappedOutputVector[i] = bootstrappedTrainingSamples[i][numFeatures]; }
// Called even after an abort; it returns immediately when m->control_pressed is set.
52 createLocalDiscardedFeatureList();
55 /***********************************************************************/
// Builds localDiscardedFeatureIndices for this node.
// For each feature index i in [0, numFeatures) that is NOT in
// globalDiscardedFeatureIndices, computes the standard deviation of
// bootstrappedFeatureVectors[i] via m->getStandardDeviation(). If that value is
// <= featureStandardDeviationThreshold, i is appended to localDiscardedFeatureIndices.
// Returns 0 immediately if m->control_pressed is set. Exceptions are reported via
// m->errorOut(..., "createLocalDiscardedFeatureList").
// NOTE(review): indices are appended with push_back, and no clear() is visible in
// this listing. Calling this more than once may produce duplicate entries; confirm.
56 int RFTreeNode::createLocalDiscardedFeatureList(){
59 for (int i = 0; i < numFeatures; i++) {
60 // TODO: confirm bootstrappedFeatureVectors.size() == numFeatures (the constructor sizes it with numFeatures rows); the Python reference implementation iterates over bootstrappedFeatureVectors instead of numFeatures
61 if (m->control_pressed) { return 0; }
// NOTE(review): this is a linear search per feature, so the loop costs
// O(numFeatures * globalDiscardedFeatureIndices.size()). A set or sorted-vector
// lookup would avoid that.
62 vector<int>::iterator it = find(globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end(), i);
63 if (it == globalDiscardedFeatureIndices.end()) { // feature i is not globally discarded
64 double standardDeviation = m->getStandardDeviation(bootstrappedFeatureVectors[i]);
// A feature with (near-)constant values across the bootstrapped samples is discarded locally.
65 if (standardDeviation <= featureStandardDeviationThreshold) { localDiscardedFeatureIndices.push_back(i); }
72 m->errorOut(e, "RFTreeNode", "createLocalDiscardedFeatureList");
76 /***********************************************************************/
// Recomputes ownEntropy: the Shannon entropy, in bits, of the class distribution
// in bootstrappedOutputVector.
//   H = -sum_c p_c * log2(p_c), where p_c = count_c / total,
// summed only over classes with a non-zero count (so 0 * log2(0) contributes 0).
// If bootstrappedOutputVector is empty, every count is 0 and ownEntropy becomes 0.0.
// Returns 0 early, WITHOUT updating ownEntropy, if m->control_pressed is set during
// the summation. Exceptions are reported via m->errorOut(..., "updateNodeEntropy").
77 int RFTreeNode::updateNodeEntropy() {
80 vector<int> classCounts(numOutputClasses, 0);
// NOTE(review): labels are used directly as indices without a bounds check. Every
// value in bootstrappedOutputVector must lie in [0, numOutputClasses).
81 for (int i = 0; i < bootstrappedOutputVector.size(); i++) {
82 classCounts[bootstrappedOutputVector[i]]++;
84 int totalClassCounts = accumulate(classCounts.begin(), classCounts.end(), 0);
85 double nodeEntropy = 0.0;
86 for (int i = 0; i < classCounts.size(); i++) {
87 if (m->control_pressed) { return 0; }
// Skipping empty classes also prevents a 0/0 division when totalClassCounts is 0.
88 if (classCounts[i] == 0) continue;
89 double probability = (double)classCounts[i] / (double)totalClassCounts;
90 nodeEntropy += -(probability * log2(probability));
92 ownEntropy = nodeEntropy;
97 m->errorOut(e, "RFTreeNode", "updateNodeEntropy");
102 /***********************************************************************/