X-Git-Url: https://git.donarmstrong.com/?p=mothur.git;a=blobdiff_plain;f=randomforest.cpp;h=acf87dfebcd022d37cf6974f8331c7ad940a017a;hp=bd96cd2f7177633e3d21181c95e7ff2c07682eb2;hb=499f4ac6e321f9f03d4c3aa25c3b6880892c8b83;hpb=d263f4b3a4f96c672317d317061f6adb72656427 diff --git a/randomforest.cpp b/randomforest.cpp index bd96cd2..acf87df 100644 --- a/randomforest.cpp +++ b/randomforest.cpp @@ -10,8 +10,16 @@ /***********************************************************************/ -RandomForest::RandomForest(const vector > dataSet,const int numDecisionTrees, - const string treeSplitCriterion = "informationGain") : Forest(dataSet, numDecisionTrees, treeSplitCriterion) { +RandomForest::RandomForest(const vector > dataSet, + const int numDecisionTrees, + const string treeSplitCriterion = "gainratio", + const bool doPruning = false, + const float pruneAggressiveness = 0.9, + const bool discardHighErrorTrees = true, + const float highErrorTreeDiscardThreshold = 0.4, + const string optimumFeatureSubsetSelectionCriteria = "log2", + const float featureStandardDeviationThreshold = 0.0) + : Forest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold) { m = MothurOut::getInstance(); } @@ -29,7 +37,7 @@ int RandomForest::calcForrestErrorRate() { vector::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end()); int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin()); int realOutcome = dataSet[indexOfSample][numFeatures]; - + if (majorityVotedOutcome == realOutcome) { numCorrect++; } } @@ -38,7 +46,7 @@ int RandomForest::calcForrestErrorRate() { m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n"); m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n"); - + return 0; } catch(exception& e) { @@ -46,47 +54,133 @@ int RandomForest::calcForrestErrorRate() { exit(1); } } - /***********************************************************************/ -// DONE -int RandomForest::calcForrestVariableImportance(string filename) { + +int RandomForest::printConfusionMatrix(map intToTreatmentMap) { try { - - // TODO: need to add try/catch operators to fix this - // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast - //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all? - //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree. - for (int i = 0; i < decisionTrees.size(); i++) { - if (m->control_pressed) { return 0; } + int numGroups = intToTreatmentMap.size(); + vector > cm(numGroups, vector(numGroups, 0)); - DecisionTree* decisionTree = dynamic_cast(decisionTrees[i]); + for (map >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) { + + if (m->control_pressed) { return 0; } + + int indexOfSample = it->first; //key + vector predictedOutComes = it->second; //value, vector of all predicted classes + vector::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end()); + int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin()); + int realOutcome = dataSet[indexOfSample][numFeatures]; + cm[realOutcome][majorityVotedOutcome] = cm[realOutcome][majorityVotedOutcome] + 1; + } - for (int j = 0; j < numFeatures; j++) { - globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j]; + vector fw; + for (int w = 0; w mothurOut("confusion matrix:\n\t\t"); + for (int k = 0; k < numGroups; k++) { + //m->mothurOut(intToTreatmentMap[k] + "\t"); + cout << setw(fw[k]) << intToTreatmentMap[k] << "\t"; + } + for (int i = 0; i < numGroups; i++) { + cout << "\n" << setw(fw[i]) << intToTreatmentMap[i] << "\t"; + //m->mothurOut("\n" + intToTreatmentMap[i] + "\t"); + if (m->control_pressed) { return 0; } + for (int j = 0; j < numGroups; j++) { + //m->mothurOut(toString(cm[i][j]) + "\t"); + cout << setw(fw[i]) << cm[i][j] << "\t"; + } + } + //m->mothurOut("\n"); + cout << "\n"; + + return 0; } - for (int i = 0; i < numFeatures; i++) { - cout << "[" << i << ',' << globalVariableImportanceList[i] << "], "; - globalVariableImportanceList[i] /= (double)numDecisionTrees; - } - - vector< vector > globalVariableRanks; - for (int i = 0; i < globalVariableImportanceList.size(); i++) { - if (globalVariableImportanceList[i] > 0) { - vector globalVariableRank(2, 0); - globalVariableRank[0] = i; globalVariableRank[1] = globalVariableImportanceList[i]; - globalVariableRanks.push_back(globalVariableRank); + catch(exception& e) { + m->errorOut(e, "RandomForest", "printConfusionMatrix"); + exit(1); + } +} + +/***********************************************************************/ + +int RandomForest::getMissclassifications(string filename, map intToTreatmentMap, vector names) { + try { + ofstream out; + m->openOutputFile(filename, out); + out <<"Sample\tRF classification\tActual classification\n"; + for (map >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) { + + if (m->control_pressed) { return 0; } + + int indexOfSample = it->first; + vector predictedOutComes = it->second; + vector::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end()); + int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin()); + int realOutcome = dataSet[indexOfSample][numFeatures]; + + if (majorityVotedOutcome != realOutcome) { + out << names[indexOfSample] << "\t" << intToTreatmentMap[majorityVotedOutcome] << "\t" << intToTreatmentMap[realOutcome] << endl; + + } } + + out.close(); + return 0; } + catch(exception& e) { + m->errorOut(e, "RandomForest", "getMissclassifications"); + exit(1); + } +} + +/***********************************************************************/ +int RandomForest::calcForrestVariableImportance(string filename) { + try { - VariableRankDescendingSorterDouble variableRankDescendingSorter; - sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter); + // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast + //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all? + //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree. + for (int i = 0; i < decisionTrees.size(); i++) { + if (m->control_pressed) { return 0; } + + DecisionTree* decisionTree = dynamic_cast(decisionTrees[i]); + + for (int j = 0; j < numFeatures; j++) { + globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j]; + } + } + + for (int i = 0; i < numFeatures; i++) { + globalVariableImportanceList[i] /= (double)numDecisionTrees; + } + + vector< pair > globalVariableRanks; + for (int i = 0; i < globalVariableImportanceList.size(); i++) { + //cout << "[" << i << ',' << globalVariableImportanceList[i] << "], "; + if (globalVariableImportanceList[i] > 0) { + pair globalVariableRank(0, 0.0); + globalVariableRank.first = i; + globalVariableRank.second = globalVariableImportanceList[i]; + globalVariableRanks.push_back(globalVariableRank); + } + } + +// for (int i = 0; i < globalVariableRanks.size(); i++) { +// cout << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl; +// } + + + VariableRankDescendingSorterDouble variableRankDescendingSorter; + sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter); + ofstream out; m->openOutputFile(filename, out); - out <<"OTU\tRank\n"; + out <<"OTU\tMean decrease accuracy\n"; for (int i = 0; i < globalVariableRanks.size(); i++) { - out << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl; + out << m->currentSharedBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl; } out.close(); return 0; @@ -97,28 +191,90 @@ int RandomForest::calcForrestVariableImportance(string filename) { } } /***********************************************************************/ -// DONE int RandomForest::populateDecisionTrees() { try { + vector errorRateImprovements; + for (int i = 0; i < numDecisionTrees; i++) { + if (m->control_pressed) { return 0; } - if (((i+1) % 10) == 0) { m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n"); } + if (((i+1) % 100) == 0) { m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n"); } + // TODO: need to first fix if we are going to use pointer based system or anything else - DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector("log2"), treeSplitCriterion); - decisionTree->calcTreeVariableImportanceAndError(); - if (m->control_pressed) { return 0; } - updateGlobalOutOfBagEstimates(decisionTree); - if (m->control_pressed) { return 0; } - decisionTree->purgeDataSetsFromTree(); - if (m->control_pressed) { return 0; } - decisionTrees.push_back(decisionTree); + DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold); + + if (m->debug && doPruning) { + m->mothurOut("Before pruning\n"); + decisionTree->printTree(decisionTree->rootNode, "ROOT"); + } + + int numCorrect; + double treeErrorRate; + + decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate); + double prePrunedErrorRate = treeErrorRate; + + if (m->debug) { + m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n"); + } + + if (doPruning) { + decisionTree->pruneTree(pruneAggressiveness); + if (m->debug) { + m->mothurOut("After pruning\n"); + decisionTree->printTree(decisionTree->rootNode, "ROOT"); + } + decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate); + } + double postPrunedErrorRate = treeErrorRate; + + + decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate); + double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate; + + if (m->debug) { + m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n"); + if (doPruning) { + m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n"); + } + } + + + if (discardHighErrorTrees) { + if (treeErrorRate < highErrorTreeDiscardThreshold) { + updateGlobalOutOfBagEstimates(decisionTree); + decisionTree->purgeDataSetsFromTree(); + decisionTrees.push_back(decisionTree); + if (doPruning) { + errorRateImprovements.push_back(errorRateImprovement); + } + } else { + delete decisionTree; + } + } else { + updateGlobalOutOfBagEstimates(decisionTree); + decisionTree->purgeDataSetsFromTree(); + decisionTrees.push_back(decisionTree); + if (doPruning) { + errorRateImprovements.push_back(errorRateImprovement); + } + } } - if (m->debug) { - // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n"); + double avgErrorRateImprovement = -1.0; + if (errorRateImprovements.size() > 0) { + avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0); +// cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl; + avgErrorRateImprovement /= errorRateImprovements.size(); } + if (m->debug && doPruning) { + m->mothurOut("avgErrorRateImprovement:" + toString(avgErrorRateImprovement) + "\n"); + } + // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n"); + + return 0; } catch(exception& e) {