]> git.donarmstrong.com Git - mothur.git/blob - decisiontree.hpp
changed random forest output filename
[mothur.git] / decisiontree.hpp
1   //
2   //  decisiontree.hpp
3   //  rrf-fs-prototype
4   //
5   //  Created by Abu Zaher Faridee on 5/28/12.
6   //  Copyright (c) 2012 Schloss Lab. All rights reserved.
7   //
8
9 #ifndef RF_DECISIONTREE_HPP
10 #define RF_DECISIONTREE_HPP
11
12 #include "macros.h"
13 #include "rftreenode.hpp"
14 #include "abstractdecisiontree.hpp"
15
16 /***********************************************************************/
17
// Comparator that orders pair<int, int> values by descending .second
// (presumably a variable/feature index paired with its rank or count —
// confirm at call sites). Pairs with equal .second compare equivalent, so
// std::sort leaves ties in no particular order.
// operator() is const-qualified so the comparator can also be invoked
// through a const object, e.g. as the ordering of a std::set / std::map.
struct VariableRankDescendingSorter {
    bool operator() (const pair<int, int>& firstPair, const pair<int, int>& secondPair) const {
        return firstPair.second > secondPair.second;
    }
};
// Comparator that orders pair<int, double> values by descending .second
// (presumably a variable/feature index paired with a real-valued score —
// confirm at call sites). Pairs with equal .second compare equivalent.
// operator() is const-qualified so the comparator can also be invoked
// through a const object, e.g. as the ordering of a std::set / std::map.
struct VariableRankDescendingSorterDouble {
    bool operator() (const pair<int, double>& firstPair, const pair<int, double>& secondPair) const {
        return firstPair.second > secondPair.second;
    }
};
28 /***********************************************************************/
29
30 class DecisionTree: public AbstractDecisionTree{
31     
32     friend class RandomForest;
33     
34 public:
35     
36     DecisionTree(vector< vector<int> >& baseDataSet,
37                  vector<int> globalDiscardedFeatureIndices,
38                  OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
39                  string treeSplitCriterion,
40                  float featureStandardDeviationThreshold);
41     
42     virtual ~DecisionTree(){ deleteTreeNodesRecursively(rootNode); }
43     
44     int calcTreeVariableImportanceAndError(int& numCorrect, double& treeErrorRate);
45     int evaluateSample(vector<int> testSample);
46     int calcTreeErrorRate(int& numCorrect, double& treeErrorRate);
47     
48     void randomlyShuffleAttribute(const vector< vector<int> >& samples,
49                                   const int featureIndex,
50                                   const int prevFeatureIndex,
51                                   vector< vector<int> >& shuffledSample);
52     
53     void purgeDataSetsFromTree() { purgeTreeNodesDataRecursively(rootNode); }
54     int purgeTreeNodesDataRecursively(RFTreeNode* treeNode);
55     
56     void pruneTree(double pruneAggressiveness);
57     void pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness);
58     void updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector<int> testSample);
59     void updateOutputClassOfNode(RFTreeNode* treeNode);
60     
61     
62 private:
63     
64     void buildDecisionTree();
65     int splitRecursively(RFTreeNode* rootNode);
66     int findAndUpdateBestFeatureToSplitOn(RFTreeNode* node);
67     vector<int> selectFeatureSubsetRandomly(vector<int> globalDiscardedFeatureIndices, vector<int> localDiscardedFeatureIndices);
68     int printTree(RFTreeNode* treeNode, string caption);
69     void deleteTreeNodesRecursively(RFTreeNode* treeNode);
70     
71     vector<int> variableImportanceList;
72     map<int, int> outOfBagEstimates;
73   
74     float featureStandardDeviationThreshold;
75 };
76
77 #endif