git.donarmstrong.com Git - mothur.git/blobdiff - randomforest.cpp
adding labels to list file.
[mothur.git] / randomforest.cpp
index 852da372291d00b3e37caabefbea4e4fe3981f3f..acf87dfebcd022d37cf6974f8331c7ad940a017a 100644 (file)
@@ -70,22 +70,31 @@ int RandomForest::printConfusionMatrix(map<int, string> intToTreatmentMap) {
             vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
             int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
             int realOutcome = dataSet[indexOfSample][numFeatures];                       
-            cm[realOutcome][majorityVotedOutcome] = cm[realOutcome][majorityVotedOutcome] + 1;            
+            cm[realOutcome][majorityVotedOutcome] = cm[realOutcome][majorityVotedOutcome] + 1;
+        }
+        
+        vector<int> fw;
+        for (int w = 0; w <numGroups; w++) {
+            fw.push_back(intToTreatmentMap[w].length());
         }
         
         m->mothurOut("confusion matrix:\n\t\t");
-        for (int i = 0; i < numGroups; i++) {
-            m->mothurOut(intToTreatmentMap[i] + "\t");
+        for (int k = 0; k < numGroups; k++) {
+            //m->mothurOut(intToTreatmentMap[k] + "\t");
+            cout << setw(fw[k]) << intToTreatmentMap[k] << "\t";
         }
         for (int i = 0; i < numGroups; i++) {
+            cout << "\n" << setw(fw[i]) << intToTreatmentMap[i] << "\t";
             //m->mothurOut("\n" + intToTreatmentMap[i] + "\t");
             if (m->control_pressed) { return 0; }
             for (int j = 0; j < numGroups; j++) {
-                m->mothurOut(cm[i][j] + "\t");            
+                //m->mothurOut(toString(cm[i][j]) + "\t");
+                cout << setw(fw[i]) << cm[i][j] << "\t";
             }    
         }
-        m->mothurOut("\n");
-    
+        //m->mothurOut("\n");
+        cout << "\n";
+
         return 0;
     }
     
@@ -95,6 +104,38 @@ int RandomForest::printConfusionMatrix(map<int, string> intToTreatmentMap) {
        }
 }
 
+/***********************************************************************/
+// Writes to `filename` one tab-separated row per sample in globalOutOfBagEstimates whose predicted class differs from its actual class; returns 0.
+int RandomForest::getMissclassifications(string filename, map<int, string> intToTreatmentMap, vector<string> names) {
+    try {
+        ofstream out;
+        m->openOutputFile(filename, out);
+        out <<"Sample\tRF classification\tActual classification\n";
+        for (map<int, vector<int> >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) {
+            // Stop early, still returning 0, once m->control_pressed is set.
+            if (m->control_pressed) { return 0; }
+            // Predicted class = index of the largest entry in this sample's vector (presumably per-class vote counts - confirm); max_element keeps the first on ties.
+            int indexOfSample = it->first;
+            vector<int> predictedOutComes = it->second;
+            vector<int>::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end());
+            int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin());
+            int realOutcome = dataSet[indexOfSample][numFeatures]; // value at column numFeatures of the sample's row; presumably the class label - confirm against dataSet layout
+            // Only mismatches are written; both class ids are turned into names through the by-value copy of intToTreatmentMap.
+            if (majorityVotedOutcome != realOutcome) {             
+                out << names[indexOfSample] << "\t" << intToTreatmentMap[majorityVotedOutcome] << "\t" << intToTreatmentMap[realOutcome] << endl;
+                                
+            }
+        }
+        
+        out.close();    
+        return 0;
+    }
+       catch(exception& e) { // report the exception through m->errorOut, then terminate the process
+               m->errorOut(e, "RandomForest", "getMissclassifications");
+               exit(1);
+       } 
+}
+
 /***********************************************************************/
 int RandomForest::calcForrestVariableImportance(string filename) {
     try {
@@ -139,7 +180,7 @@ int RandomForest::calcForrestVariableImportance(string filename) {
         m->openOutputFile(filename, out);
         out <<"OTU\tMean decrease accuracy\n";
         for (int i = 0; i < globalVariableRanks.size(); i++) {
-            out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
+            out << m->currentSharedBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
         }
         out.close();
         return 0;