// Copyright (c) 2012 Schloss Lab. All rights reserved.
//
-#ifndef rrf_fs_prototype_decisiontree_hpp
-#define rrf_fs_prototype_decisiontree_hpp
+#ifndef RF_DECISIONTREE_HPP
+#define RF_DECISIONTREE_HPP
#include "macros.h"
#include "rftreenode.hpp"
/***********************************************************************/
struct VariableRankDescendingSorter {
- bool operator() (vector<int> first, vector<int> second){ return first[1] > second[1]; }
+ bool operator() (const pair<int, int>& firstPair, const pair<int, int>& secondPair){
+ return firstPair.second > secondPair.second;
+ }
};
struct VariableRankDescendingSorterDouble {
- bool operator() (vector<double> first, vector<double> second){ return first[1] > second[1]; }
+ bool operator() (const pair<int, double>& firstPair, const pair<int, double>& secondPair){
+ return firstPair.second > secondPair.second;
+ }
};
/***********************************************************************/
public:
- DecisionTree(vector< vector<int> > baseDataSet,
+ DecisionTree(vector< vector<int> >& baseDataSet,
vector<int> globalDiscardedFeatureIndices,
OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
- string treeSplitCriterion);
+ string treeSplitCriterion,
+ float featureStandardDeviationThreshold);
+
virtual ~DecisionTree(){ deleteTreeNodesRecursively(rootNode); }
- int calcTreeVariableImportanceAndError();
+ int calcTreeVariableImportanceAndError(int& numCorrect, double& treeErrorRate);
int evaluateSample(vector<int> testSample);
int calcTreeErrorRate(int& numCorrect, double& treeErrorRate);
- vector< vector<int> > randomlyShuffleAttribute(vector< vector<int> > samples, int featureIndex);
+
+ void randomlyShuffleAttribute(const vector< vector<int> >& samples,
+ const int featureIndex,
+ const int prevFeatureIndex,
+ vector< vector<int> >& shuffledSample);
+
void purgeDataSetsFromTree() { purgeTreeNodesDataRecursively(rootNode); }
int purgeTreeNodesDataRecursively(RFTreeNode* treeNode);
+ void pruneTree(double pruneAggressiveness);
+ void pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness);
+ void updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector<int> testSample);
+ void updateOutputClassOfNode(RFTreeNode* treeNode);
+
private:
vector<int> variableImportanceList;
map<int, int> outOfBagEstimates;
+
+ float featureStandardDeviationThreshold;
};
#endif