X-Git-Url: https://git.donarmstrong.com/?a=blobdiff_plain;f=randomforest.cpp;h=2ae0eb595c1c8d5f989c44a5e561a0ff3dac2d2b;hb=372fb21ea66ced432b109225851a1b80ef0491a3;hp=36a2c1a261f27514c394a370d0d97387ac9cbfac;hpb=035f86272c776e1cccaa47021e26782e49cd41e7;p=mothur.git diff --git a/randomforest.cpp b/randomforest.cpp index 36a2c1a..2ae0eb5 100644 --- a/randomforest.cpp +++ b/randomforest.cpp @@ -10,8 +10,16 @@ /***********************************************************************/ -RandomForest::RandomForest(const vector > dataSet,const int numDecisionTrees, - const string treeSplitCriterion = "informationGain") : AbstractRandomForest(dataSet, numDecisionTrees, treeSplitCriterion) { +RandomForest::RandomForest(const vector > dataSet, + const int numDecisionTrees, + const string treeSplitCriterion = "gainratio", + const bool doPruning = false, + const float pruneAggressiveness = 0.9, + const bool discardHighErrorTrees = true, + const float highErrorTreeDiscardThreshold = 0.4, + const string optimumFeatureSubsetSelectionCriteria = "log2", + const float featureStandardDeviationThreshold = 0.0) + : Forest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold) { m = MothurOut::getInstance(); } @@ -48,44 +56,50 @@ int RandomForest::calcForrestErrorRate() { } /***********************************************************************/ -// DONE int RandomForest::calcForrestVariableImportance(string filename) { try { - // TODO: need to add try/catch operators to fix this - // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast + // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all? //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree. - for (int i = 0; i < decisionTrees.size(); i++) { - if (m->control_pressed) { return 0; } - DecisionTree* decisionTree = dynamic_cast(decisionTrees[i]); + for (int i = 0; i < decisionTrees.size(); i++) { + if (m->control_pressed) { return 0; } + + DecisionTree* decisionTree = dynamic_cast(decisionTrees[i]); + + for (int j = 0; j < numFeatures; j++) { + globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j]; + } + } - for (int j = 0; j < numFeatures; j++) { - globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j]; + for (int i = 0; i < numFeatures; i++) { + globalVariableImportanceList[i] /= (double)numDecisionTrees; } - } - - for (int i = 0; i < numFeatures; i++) { - cout << "[" << i << ',' << globalVariableImportanceList[i] << "], "; - globalVariableImportanceList[i] /= (double)numDecisionTrees; - } - - vector< vector > globalVariableRanks; - for (int i = 0; i < globalVariableImportanceList.size(); i++) { - if (globalVariableImportanceList[i] > 0) { - vector globalVariableRank(2, 0); - globalVariableRank[0] = i; globalVariableRank[1] = globalVariableImportanceList[i]; - globalVariableRanks.push_back(globalVariableRank); + + vector< pair > globalVariableRanks; + for (int i = 0; i < globalVariableImportanceList.size(); i++) { + //cout << "[" << i << ',' << globalVariableImportanceList[i] << "], "; + if (globalVariableImportanceList[i] > 0) { + pair globalVariableRank(0, 0.0); + globalVariableRank.first = i; + globalVariableRank.second = globalVariableImportanceList[i]; + globalVariableRanks.push_back(globalVariableRank); + } } - } - - VariableRankDescendingSorterDouble variableRankDescendingSorter; - sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter); + +// for (int i = 0; i < globalVariableRanks.size(); i++) { +// cout << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl; +// } + + + VariableRankDescendingSorterDouble variableRankDescendingSorter; + sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter); + ofstream out; m->openOutputFile(filename, out); out <<"OTU\tRank\n"; for (int i = 0; i < globalVariableRanks.size(); i++) { - out << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl; + out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl; } out.close(); return 0; @@ -96,27 +110,89 @@ int RandomForest::calcForrestVariableImportance(string filename) { } } /***********************************************************************/ -// DONE int RandomForest::populateDecisionTrees() { try { + vector errorRateImprovements; + for (int i = 0; i < numDecisionTrees; i++) { + if (m->control_pressed) { return 0; } - if (((i+1) % 10) == 0) { m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n"); } + if (((i+1) % 100) == 0) { m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n"); } + // TODO: need to first fix if we are going to use pointer based system or anything else - DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector("log2"), treeSplitCriterion); - decisionTree->calcTreeVariableImportanceAndError(); - if (m->control_pressed) { return 0; } - updateGlobalOutOfBagEstimates(decisionTree); - if (m->control_pressed) { return 0; } - decisionTree->purgeDataSetsFromTree(); - if (m->control_pressed) { return 0; } - decisionTrees.push_back(decisionTree); + DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold); + + if (m->debug && doPruning) { + m->mothurOut("Before pruning\n"); + decisionTree->printTree(decisionTree->rootNode, "ROOT"); + } + + int numCorrect; + double treeErrorRate; + + decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate); + double prePrunedErrorRate = treeErrorRate; + + if (m->debug) { + m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n"); + } + + if (doPruning) { + decisionTree->pruneTree(pruneAggressiveness); + if (m->debug) { + m->mothurOut("After pruning\n"); + decisionTree->printTree(decisionTree->rootNode, "ROOT"); + } + decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate); + } + double postPrunedErrorRate = treeErrorRate; + + + decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate); + double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate; + + if (m->debug) { + m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n"); + if (doPruning) { + m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n"); + } + } + + + if (discardHighErrorTrees) { + if (treeErrorRate < highErrorTreeDiscardThreshold) { + updateGlobalOutOfBagEstimates(decisionTree); + decisionTree->purgeDataSetsFromTree(); + decisionTrees.push_back(decisionTree); + if (doPruning) { + errorRateImprovements.push_back(errorRateImprovement); + } + } else { + delete decisionTree; + } + } else { + updateGlobalOutOfBagEstimates(decisionTree); + decisionTree->purgeDataSetsFromTree(); + decisionTrees.push_back(decisionTree); + if (doPruning) { + errorRateImprovements.push_back(errorRateImprovement); + } + } + } + + double avgErrorRateImprovement = -1.0; + if (errorRateImprovements.size() > 0) { + avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0); +// cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl; + avgErrorRateImprovement /= errorRateImprovements.size(); } - if (m->debug) { - // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n"); + if (m->debug && doPruning) { + m->mothurOut("avgErrorRateImprovement:" + toString(avgErrorRateImprovement) + "\n"); } + // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n"); + return 0; } @@ -127,7 +203,7 @@ int RandomForest::populateDecisionTrees() { } /***********************************************************************/ // TODO: need to finalize bettween reference and pointer for DecisionTree [partially solved] -// TODO: make this pure virtual in superclass +// DONE: make this pure virtual in superclass // DONE int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) { try {