]> git.donarmstrong.com Git - mothur.git/commitdiff
adding code for confusion matrix
authorkdiverson <kd.iverson@gmail.com>
Wed, 11 Sep 2013 04:48:20 +0000 (00:48 -0400)
committerkdiverson <kd.iverson@gmail.com>
Wed, 11 Sep 2013 04:48:20 +0000 (00:48 -0400)
randomforest.cpp

index 2ae0eb595c1c8d5f989c44a5e561a0ff3dac2d2b..852da372291d00b3e37caabefbea4e4fe3981f3f 100644 (file)
@@ -37,7 +37,7 @@ int RandomForest::calcForrestErrorRate() {
             vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
             int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
             int realOutcome = dataSet[indexOfSample][numFeatures];
-            
+                                   
             if (majorityVotedOutcome == realOutcome) { numCorrect++; }
         }
         
@@ -46,7 +46,7 @@ int RandomForest::calcForrestErrorRate() {
         
         m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n");
         m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n");
-    
+            
         return 0;
     }
        catch(exception& e) {
@@ -54,6 +54,46 @@ int RandomForest::calcForrestErrorRate() {
                exit(1);
        } 
 }
+/***********************************************************************/
+
+int RandomForest::printConfusionMatrix(map<int, string> intToTreatmentMap) {
+    try {
+        int numGroups = intToTreatmentMap.size();
+        vector<vector<int> > cm(numGroups, vector<int>(numGroups, 0));
+        
+        for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
+            
+            if (m->control_pressed) { return 0; }
+            
+            int indexOfSample = it->first; //key
+            vector<int> predictedOutComes = it->second; //value, vector of all predicted classes
+            vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
+            int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
+            int realOutcome = dataSet[indexOfSample][numFeatures];                       
+            cm[realOutcome][majorityVotedOutcome] = cm[realOutcome][majorityVotedOutcome] + 1;            
+        }
+        
+        m->mothurOut("confusion matrix:\n\t\t");
+        for (int i = 0; i < numGroups; i++) {
+            m->mothurOut(intToTreatmentMap[i] + "\t");
+        }
+        for (int i = 0; i < numGroups; i++) {
+            //m->mothurOut("\n" + intToTreatmentMap[i] + "\t");
+            if (m->control_pressed) { return 0; }
+            for (int j = 0; j < numGroups; j++) {
+                m->mothurOut(cm[i][j] + "\t");            
+            }    
+        }
+        m->mothurOut("\n");
+    
+        return 0;
+    }
+    
+    catch(exception& e) {
+               m->errorOut(e, "RandomForest", "printConfusionMatrix");
+               exit(1);
+       }
+}
 
 /***********************************************************************/
 int RandomForest::calcForrestVariableImportance(string filename) {
@@ -97,7 +137,7 @@ int RandomForest::calcForrestVariableImportance(string filename) {
         
         ofstream out;
         m->openOutputFile(filename, out);
-        out <<"OTU\tRank\n";
+        out <<"OTU\tMean decrease accuracy\n";
         for (int i = 0; i < globalVariableRanks.size(); i++) {
             out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
         }