+// Prunes the tree using the bootstrapped test samples.
+//
+// Every entry of bootstrappedTestSamples is first routed down the tree with
+// updateMisclassificationCountRecursively() so that the nodes accumulate
+// their misclassification counts; pruneRecursively() is then run from the
+// root.
+//
+// pruneAggressiveness is forwarded unchanged to pruneRecursively().
+//
+// Returns early, without pruning, when there is no root node or when
+// m->control_pressed is observed during the counting pass.
+//
+// NOTE(review): the default argument lives on this out-of-class definition.
+// That only compiles if the in-class declaration does not specify one as
+// well, and the default is then visible only after this point in this
+// translation unit - confirm against the header.
+void DecisionTree::pruneTree(double pruneAggressiveness = 0.9) {
+
+    // Nothing to count or prune without a root.
+    if (rootNode == NULL) { return; }
+
+    // find out the number of misclassifications made by each of the nodes
+    for (size_t i = 0; i < bootstrappedTestSamples.size(); i++) {
+        if (m->control_pressed) { return; }
+
+        // Bind by reference: the callee takes its argument by value, so an
+        // extra local copy of the sample here would be wasted work.
+        const vector<int>& testSample = bootstrappedTestSamples[i];
+        updateMisclassificationCountRecursively(rootNode, testSample);
+    }
+
+    // do the actual pruning
+    pruneRecursively(rootNode, pruneAggressiveness);
+}
+/***********************************************************************/
+
+// Post-order pruning of the subtree rooted at treeNode.
+//
+// NULL nodes and leaves are left untouched, and nothing is done once
+// m->control_pressed is set. For an internal node both children are pruned
+// first. Afterwards the sum of the two children's test-sample
+// misclassification counts, multiplied by pruneAggressiveness, is compared
+// with the node's own count; when the scaled sum is strictly larger, both
+// children are deleted and the node is marked as a leaf.
+void DecisionTree::pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness){
+
+    if (treeNode == NULL || treeNode->checkIsLeaf()) { return; }
+    if (m->control_pressed) { return; }
+
+    // Children first, so the decision for this node is taken bottom-up.
+    pruneRecursively(treeNode->leftChildNode, pruneAggressiveness);
+    pruneRecursively(treeNode->rightChildNode, pruneAggressiveness);
+
+    const int leftErrors = treeNode->leftChildNode->getTestSampleMisclassificationCount();
+    const int rightErrors = treeNode->rightChildNode->getTestSampleMisclassificationCount();
+    const int childErrors = leftErrors + rightErrors;
+    const int nodeErrors = treeNode->getTestSampleMisclassificationCount();
+
+    if (childErrors * pruneAggressiveness > nodeErrors) {
+        // TODO: need to check the effect of these two delete calls
+        // NOTE(review): whether deleting a child also releases that child's
+        // own descendants depends on RFTreeNode's destructor - confirm.
+        delete treeNode->leftChildNode;
+        treeNode->leftChildNode = NULL;
+
+        delete treeNode->rightChildNode;
+        treeNode->rightChildNode = NULL;
+
+        treeNode->isLeaf = true;
+    }
+}
+/***********************************************************************/
+
+// Routes one test sample down the tree, starting at treeNode, and records at
+// every node visited whether that node's output class disagrees with the
+// sample's actual class.
+//
+// The actual class is read from testSample[numFeatures]. On a mismatch both
+// the node's testSampleMisclassificationCount and the
+// nodeMisclassificationCounts entry keyed by the node's nodeId are
+// incremented. At an internal node the walk continues into the left child
+// when the sample's value for the node's split feature is below
+// splitFeatureValue, and into the right child otherwise. It stops after a
+// leaf has been processed, or when a NULL node is reached.
+//
+// Written as a loop instead of recursion: testSample is taken by value, so
+// recursing would copy the whole vector once per tree level.
+void DecisionTree::updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector<int> testSample) {
+
+    const int actualSampleOutputClass = testSample[numFeatures];
+
+    while (treeNode != NULL) {
+        const int nodePredictedOutputClass = treeNode->outputClass;
+
+        if (actualSampleOutputClass != nodePredictedOutputClass) {
+            treeNode->testSampleMisclassificationCount++;
+            // operator[] value-initialises a missing entry to 0 before the
+            // increment, so no separate find/insert step is needed.
+            nodeMisclassificationCounts[treeNode->nodeId]++;
+        }
+
+        if (treeNode->checkIsLeaf()) { break; }
+
+        // NOT A LEAF: descend according to the split
+        const int sampleSplitFeatureValue = testSample[treeNode->splitFeatureIndex];
+        if (sampleSplitFeatureValue < treeNode->splitFeatureValue) {
+            treeNode = treeNode->leftChildNode;
+        } else {
+            treeNode = treeNode->rightChildNode;
+        }
+    }
+}
+
+/***********************************************************************/
+
+// Sets treeNode's output class to the majority class of its bootstrapped
+// output vector.
+//
+// Each entry of treeNode->bootstrappedOutputVector is tallied into one of
+// numOutputClasses buckets, and the index of the largest bucket is passed to
+// setOutputClass(). Ties go to the smallest class index, because max_element
+// returns the first of several equal maxima. With zero buckets the index
+// computed is 0.
+//
+// NOTE(review): entries are used directly as bucket indices - assumes every
+// entry lies in [0, numOutputClasses); confirm against the caller.
+void DecisionTree::updateOutputClassOfNode(RFTreeNode* treeNode) {
+    vector<int> counts(numOutputClasses, 0);
+    for (size_t i = 0; i < treeNode->bootstrappedOutputVector.size(); i++) {
+        int bootstrappedOutput = treeNode->bootstrappedOutputVector[i];
+        counts[bootstrappedOutput]++;
+    }
+
+    // The iterator itself identifies the winning class; its offset from
+    // begin() is the class index, so the maximum value never has to be read
+    // back and searched for again.
+    vector<int>::iterator majorityVotedOutputClassIterator = max_element(counts.begin(), counts.end());
+    int majorityVotedOutputClass = static_cast<int>(majorityVotedOutputClassIterator - counts.begin());
+    treeNode->setOutputClass(majorityVotedOutputClass);
+
+}
+/***********************************************************************/
+
+