From: kdiverson
Date: Wed, 11 Sep 2013 04:48:20 +0000 (-0400)
Subject: adding code for confusion matrix
X-Git-Url: https://git.donarmstrong.com/?p=mothur.git;a=commitdiff_plain;h=b682d361a9d59e832a0bd9dcc76aee39769b89e7

adding code for confusion matrix
---

Review note: this copy differs from the commit named above inside
printConfusionMatrix only: matrix counts are formatted with toString(), each
matrix row starts on its own labelled line, and comments were added. The
blob hashes on the index line below still refer to the original commit.

diff --git a/randomforest.cpp b/randomforest.cpp
index 2ae0eb5..852da37 100644
--- a/randomforest.cpp
+++ b/randomforest.cpp
@@ -37,7 +37,7 @@ int RandomForest::calcForrestErrorRate() {
             vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
             int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
             int realOutcome = dataSet[indexOfSample][numFeatures];
-            
+
             if (majorityVotedOutcome == realOutcome) { numCorrect++; }
         }
 
@@ -46,7 +46,7 @@ int RandomForest::calcForrestErrorRate() {
 
         m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n");
         m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n");
-    
+
         return 0;
     }
     catch(exception& e) {
@@ -54,6 +54,53 @@ int RandomForest::calcForrestErrorRate() {
         exit(1);
     }
 }
+/***********************************************************************/
+
+// Builds and prints a confusion matrix from globalOutOfBagEstimates:
+// cm[real][predicted] counts the samples whose real outcome is `real` and whose
+// voted outcome (index of the largest entry in the sample's prediction vector)
+// is `predicted`. intToTreatmentMap supplies the label printed for each class
+// index. Returns 0, also when interrupted through m->control_pressed.
+int RandomForest::printConfusionMatrix(map<int, string> intToTreatmentMap) {
+    try {
+        int numGroups = intToTreatmentMap.size();
+        vector<vector<int> > cm(numGroups, vector<int>(numGroups, 0));
+
+        for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
+
+            if (m->control_pressed) { return 0; }
+
+            int indexOfSample = it->first; //key
+            vector<int> predictedOutComes = it->second; //value, vector of all predicted classes
+            vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
+            int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
+            int realOutcome = dataSet[indexOfSample][numFeatures];
+            cm[realOutcome][majorityVotedOutcome] = cm[realOutcome][majorityVotedOutcome] + 1;
+        }
+
+        m->mothurOut("confusion matrix:\n\t\t");
+        for (int i = 0; i < numGroups; i++) {
+            m->mothurOut(intToTreatmentMap[i] + "\t");
+        }
+        for (int i = 0; i < numGroups; i++) {
+            // start every matrix row on its own line, labelled with its group
+            m->mothurOut("\n" + intToTreatmentMap[i] + "\t");
+            if (m->control_pressed) { return 0; }
+            for (int j = 0; j < numGroups; j++) {
+                // toString is required: int + "\t" would be pointer arithmetic on the literal
+                m->mothurOut(toString(cm[i][j]) + "\t");
+            }
+        }
+        m->mothurOut("\n");
+
+        return 0;
+    }
+
+    catch(exception& e) {
+        m->errorOut(e, "RandomForest", "printConfusionMatrix");
+        exit(1);
+    }
+}
 /***********************************************************************/
 
 int RandomForest::calcForrestVariableImportance(string filename) {
@@ -97,7 +144,7 @@ int RandomForest::calcForrestVariableImportance(string filename) {
 
         ofstream out;
         m->openOutputFile(filename, out);
-        out <<"OTU\tRank\n";
+        out <<"OTU\tMean decrease accuracy\n";
         for (int i = 0; i < globalVariableRanks.size(); i++) {
             out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
         }