vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
int realOutcome = dataSet[indexOfSample][numFeatures];
-
+
if (majorityVotedOutcome == realOutcome) { numCorrect++; }
}
m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n");
m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n");
-
+
return 0;
}
catch(exception& e) {
exit(1);
}
}
+/***********************************************************************/
+
+int RandomForest::printConfusionMatrix(map<int, string> intToTreatmentMap) {
+ try {
+ int numGroups = intToTreatmentMap.size();
+ vector<vector<int> > cm(numGroups, vector<int>(numGroups, 0));
+
+ for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
+
+ if (m->control_pressed) { return 0; }
+
+ int indexOfSample = it->first; //key
+ vector<int> predictedOutComes = it->second; //value, vector of all predicted classes
+ vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
+ int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
+ int realOutcome = dataSet[indexOfSample][numFeatures];
+ cm[realOutcome][majorityVotedOutcome] = cm[realOutcome][majorityVotedOutcome] + 1;
+ }
+
+ m->mothurOut("confusion matrix:\n\t\t");
+ for (int i = 0; i < numGroups; i++) {
+ m->mothurOut(intToTreatmentMap[i] + "\t");
+ }
+ for (int i = 0; i < numGroups; i++) {
+ //m->mothurOut("\n" + intToTreatmentMap[i] + "\t");
+ if (m->control_pressed) { return 0; }
+ for (int j = 0; j < numGroups; j++) {
+ m->mothurOut(cm[i][j] + "\t");
+ }
+ }
+ m->mothurOut("\n");
+
+ return 0;
+ }
+
+ catch(exception& e) {
+ m->errorOut(e, "RandomForest", "printConfusionMatrix");
+ exit(1);
+ }
+}
/***********************************************************************/
int RandomForest::calcForrestVariableImportance(string filename) {
ofstream out;
m->openOutputFile(filename, out);
- out <<"OTU\tRank\n";
+ out <<"OTU\tMean decrease accuracy\n";
for (int i = 0; i < globalVariableRanks.size(); i++) {
out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
}