5 // Created by Sarah Westcott on 10/2/12.
6 // Copyright (c) 2012 Schloss Lab. All rights reserved.
9 #include "rftreenode.hpp"
11 /***********************************************************************/
// RFTreeNode constructor.
// Stores this node's bootstrapped training rows and the globally discarded
// feature indices. It then derives two column-oriented views of the rows:
//   - bootstrappedFeatureVectors: numFeatures x numSamples, where
//     bootstrappedFeatureVectors[j][i] is feature j of sample i (a transpose);
//   - bootstrappedOutputVector: entry numFeatures of every row.
// Each row therefore needs at least numFeatures + 1 entries, and there must be
// at least numSamples rows: numFeatures feature values followed by the
// sample's output value. NOTE(review): the output value is presumably the
// class label; confirm against the caller.
// Finally calls createLocalDiscardedFeatureList() for this node.
12 RFTreeNode::RFTreeNode(vector< vector<int> > bootstrappedTrainingSamples,
13 vector<int> globalDiscardedFeatureIndices,
19 : bootstrappedTrainingSamples(bootstrappedTrainingSamples),
20 globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
21 numFeatures(numFeatures),
22 numSamples(numSamples),
23 numOutputClasses(numOutputClasses),
24 generation(generation),
// Split fields start at -1 / -1.0. NOTE(review): presumably a sentinel for
// "no split chosen yet"; confirm where the split is computed.
27 splitFeatureIndex(-1),
28 splitFeatureValue(-1),
29 splitFeatureEntropy(-1.0),
// Both derived views are zero-filled up front, so any entries not reached by
// the copy loops below (e.g. after an abort) remain 0.
31 bootstrappedFeatureVectors(numFeatures, vector<int>(numSamples, 0)),
32 bootstrappedOutputVector(numSamples, 0),
36 m = MothurOut::getInstance();
// Fill the feature-major view. Stops early if the user aborts
// (m->control_pressed), leaving the remaining samples' columns at 0.
38 for (int i = 0; i < numSamples; i++) { // just doing a simple transpose of the matrix
39 if (m->control_pressed) { break; }
40 for (int j = 0; j < numFeatures; j++) { bootstrappedFeatureVectors[j][i] = bootstrappedTrainingSamples[i][j]; }
// The value just past the last feature (index numFeatures) is the sample's
// output value; also stops early on abort.
43 for (int i = 0; i < numSamples; i++) { if (m->control_pressed) { break; } bootstrappedOutputVector[i] = bootstrappedTrainingSamples[i][numFeatures]; }
45 createLocalDiscardedFeatureList();
48 /***********************************************************************/
49 int RFTreeNode::createLocalDiscardedFeatureList(){
52 for (int i = 0; i < numFeatures; i++) {
53 if (m->control_pressed) { return 0; }
54 vector<int>::iterator it = find(globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end(), i);
55 if (it == globalDiscardedFeatureIndices.end()){ // NOT FOUND
56 double standardDeviation = m->getStandardDeviation(bootstrappedFeatureVectors[i]);
57 if (standardDeviation <= 0){ localDiscardedFeatureIndices.push_back(i); }
64 m->errorOut(e, "RFTreeNode", "createLocalDiscardedFeatureList");
68 /***********************************************************************/
69 int RFTreeNode::updateNodeEntropy() {
72 vector<int> classCounts(numOutputClasses, 0);
73 for (int i = 0; i < bootstrappedOutputVector.size(); i++) { classCounts[bootstrappedOutputVector[i]]++; }
74 int totalClassCounts = accumulate(classCounts.begin(), classCounts.end(), 0);
75 double nodeEntropy = 0.0;
76 for (int i = 0; i < classCounts.size(); i++) {
77 if (m->control_pressed) { return 0; }
78 if (classCounts[i] == 0) continue;
79 double probability = (double)classCounts[i] / (double)totalClassCounts;
80 nodeEntropy += -(probability * log2(probability));
82 ownEntropy = nodeEntropy;
87 m->errorOut(e, "RFTreeNode", "updateNodeEntropy");
92 /***********************************************************************/