X-Git-Url: https://git.donarmstrong.com/?a=blobdiff_plain;f=decisiontree.hpp;h=e890c3214c496f5a208b68744cb18d665cd1aabd;hb=a0f1fca79d2ddfa7ad36b4485039c68b5704fe8d;hp=d4441ed738a3049c1ec6fce73b67a91c7b1cdfb3;hpb=90708fe9701e3827e477c82fb3652539c3bf2a0d;p=mothur.git diff --git a/decisiontree.hpp b/decisiontree.hpp index d4441ed..e890c32 100755 --- a/decisiontree.hpp +++ b/decisiontree.hpp @@ -6,8 +6,8 @@ // Copyright (c) 2012 Schloss Lab. All rights reserved. // -#ifndef rrf_fs_prototype_decisiontree_hpp -#define rrf_fs_prototype_decisiontree_hpp +#ifndef RF_DECISIONTREE_HPP +#define RF_DECISIONTREE_HPP #include "macros.h" #include "rftreenode.hpp" @@ -16,10 +16,14 @@ /***********************************************************************/ struct VariableRankDescendingSorter { - bool operator() (vector first, vector second){ return first[1] > second[1]; } + bool operator() (const pair& firstPair, const pair& secondPair){ + return firstPair.second > secondPair.second; + } }; struct VariableRankDescendingSorterDouble { - bool operator() (vector first, vector second){ return first[1] > second[1]; } + bool operator() (const pair& firstPair, const pair& secondPair){ + return firstPair.second > secondPair.second; + } }; /***********************************************************************/ @@ -29,19 +33,31 @@ class DecisionTree: public AbstractDecisionTree{ public: - DecisionTree(vector< vector > baseDataSet, + DecisionTree(vector< vector >& baseDataSet, vector globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, - string treeSplitCriterion); + string treeSplitCriterion, + float featureStandardDeviationThreshold); + virtual ~DecisionTree(){ deleteTreeNodesRecursively(rootNode); } - int calcTreeVariableImportanceAndError(); + int calcTreeVariableImportanceAndError(int& numCorrect, double& treeErrorRate); int evaluateSample(vector testSample); int calcTreeErrorRate(int& numCorrect, double& treeErrorRate); - vector< vector > randomlyShuffleAttribute(vector< vector > samples, int featureIndex); + + void randomlyShuffleAttribute(const vector< vector >& samples, + const int featureIndex, + const int prevFeatureIndex, + vector< vector >& shuffledSample); + void purgeDataSetsFromTree() { purgeTreeNodesDataRecursively(rootNode); } int purgeTreeNodesDataRecursively(RFTreeNode* treeNode); + void pruneTree(double pruneAggressiveness); + void pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness); + void updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector testSample); + void updateOutputClassOfNode(RFTreeNode* treeNode); + private: @@ -54,6 +70,8 @@ private: vector variableImportanceList; map outOfBagEstimates; + + float featureStandardDeviationThreshold; }; #endif