- DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector("log2"), treeSplitCriterion);
- decisionTree->calcTreeVariableImportanceAndError();
- if (m->control_pressed) { return 0; }
- updateGlobalOutOfBagEstimates(decisionTree);
- if (m->control_pressed) { return 0; }
- decisionTree->purgeDataSetsFromTree();
- if (m->control_pressed) { return 0; }
- decisionTrees.push_back(decisionTree);
+ DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold);
+
+ if (m->debug && doPruning) {
+ m->mothurOut("Before pruning\n");
+ decisionTree->printTree(decisionTree->rootNode, "ROOT");
+ }
+
+ int numCorrect;
+ double treeErrorRate;
+
+ decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
+ double prePrunedErrorRate = treeErrorRate;
+
+ if (m->debug) {
+ m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
+ }
+
+ if (doPruning) {
+ decisionTree->pruneTree(pruneAggressiveness);
+ if (m->debug) {
+ m->mothurOut("After pruning\n");
+ decisionTree->printTree(decisionTree->rootNode, "ROOT");
+ }
+ decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
+ }
+ double postPrunedErrorRate = treeErrorRate;
+
+
+ decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate);
+ double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate;
+
+ if (m->debug) {
+ m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
+ if (doPruning) {
+ m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n");
+ }
+ }
+
+
+ if (discardHighErrorTrees) {
+ if (treeErrorRate < highErrorTreeDiscardThreshold) {
+ updateGlobalOutOfBagEstimates(decisionTree);
+ decisionTree->purgeDataSetsFromTree();
+ decisionTrees.push_back(decisionTree);
+ if (doPruning) {
+ errorRateImprovements.push_back(errorRateImprovement);
+ }
+ } else {
+ delete decisionTree;
+ }
+ } else {
+ updateGlobalOutOfBagEstimates(decisionTree);
+ decisionTree->purgeDataSetsFromTree();
+ decisionTrees.push_back(decisionTree);
+ if (doPruning) {
+ errorRateImprovements.push_back(errorRateImprovement);
+ }
+ }
+ }
+
+ double avgErrorRateImprovement = -1.0;
+ if (errorRateImprovements.size() > 0) {
+ avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0);
+// cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl;
+ avgErrorRateImprovement /= errorRateImprovements.size();