5 // Created by Abu Zaher Faridee on 5/28/12.
6 // Copyright (c) 2012 Schloss Lab. All rights reserved.
9 #ifndef RF_DECISIONTREE_HPP
10 #define RF_DECISIONTREE_HPP
13 #include "rftreenode.hpp"
14 #include "abstractdecisiontree.hpp"
16 /***********************************************************************/
// Comparator that orders std::pair<int, int> values by their .second member,
// largest first. Only .second takes part in the comparison; .first (presumably
// a variable/feature index, judging by the name — confirm at call sites) is
// carried along untouched.
//
// It is a strict weak ordering, so it is valid for std::sort and friends.
// Pairs with equal .second compare equivalent, so their relative order after
// std::sort is unspecified (use std::stable_sort if that matters).
struct VariableRankDescendingSorter {
    // Returns true when firstPair should come before secondPair, i.e. when
    // firstPair.second is strictly greater than secondPair.second.
    // Const-qualified so the comparator can be invoked through const objects
    // (e.g. the Compare member of std::set / std::map / std::priority_queue).
    bool operator()(const std::pair<int, int>& firstPair,
                    const std::pair<int, int>& secondPair) const {
        return firstPair.second > secondPair.second;
    }
};
// Comparator that orders std::pair<int, double> values by their .second member,
// largest first. Only .second takes part in the comparison; .first (presumably
// a variable/feature index, judging by the name — confirm at call sites) is
// carried along untouched.
//
// It is a strict weak ordering for non-NaN values, so it is valid for std::sort
// and friends. If .second can be NaN the ordering requirements are violated;
// callers must keep NaN out of the sorted range. Pairs with equal .second
// compare equivalent, so their relative order after std::sort is unspecified.
struct VariableRankDescendingSorterDouble {
    // Returns true when firstPair should come before secondPair, i.e. when
    // firstPair.second is strictly greater than secondPair.second.
    // Const-qualified so the comparator can be invoked through const objects
    // (e.g. the Compare member of std::set / std::map / std::priority_queue).
    bool operator()(const std::pair<int, double>& firstPair,
                    const std::pair<int, double>& secondPair) const {
        return firstPair.second > secondPair.second;
    }
};
28 /***********************************************************************/
30 class DecisionTree: public AbstractDecisionTree{
32 friend class RandomForest;
36 DecisionTree(vector< vector<int> >& baseDataSet,
37 vector<int> globalDiscardedFeatureIndices,
38 OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
39 string treeSplitCriterion,
40 float featureStandardDeviationThreshold);
42 virtual ~DecisionTree(){ deleteTreeNodesRecursively(rootNode); }
44 int calcTreeVariableImportanceAndError(int& numCorrect, double& treeErrorRate);
45 int evaluateSample(vector<int> testSample);
46 int calcTreeErrorRate(int& numCorrect, double& treeErrorRate);
48 void randomlyShuffleAttribute(const vector< vector<int> >& samples,
49 const int featureIndex,
50 const int prevFeatureIndex,
51 vector< vector<int> >& shuffledSample);
53 void purgeDataSetsFromTree() { purgeTreeNodesDataRecursively(rootNode); }
54 int purgeTreeNodesDataRecursively(RFTreeNode* treeNode);
56 void pruneTree(double pruneAggressiveness);
57 void pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness);
58 void updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector<int> testSample);
59 void updateOutputClassOfNode(RFTreeNode* treeNode);
64 void buildDecisionTree();
65 int splitRecursively(RFTreeNode* rootNode);
66 int findAndUpdateBestFeatureToSplitOn(RFTreeNode* node);
67 vector<int> selectFeatureSubsetRandomly(vector<int> globalDiscardedFeatureIndices, vector<int> localDiscardedFeatureIndices);
68 int printTree(RFTreeNode* treeNode, string caption);
69 void deleteTreeNodesRecursively(RFTreeNode* treeNode);
71 vector<int> variableImportanceList;
72 map<int, int> outOfBagEstimates;
74 float featureStandardDeviationThreshold;