changing command name classify.shared to classifyrf.shared
[mothur.git] / classifysharedcommand.cpp
index f964937b4ded24bc8f9eb001f1a8557f2ffcfd77..c7eb6cd0daa18f5627824a94d297270a1ff32147 100755 (executable)
 vector<string> ClassifySharedCommand::setParameters(){ 
        try {
                //CommandParameter pprocessors("processors", "Number", "", "1", "", "", "",false,false); parameters.push_back(pprocessors);        
-        CommandParameter pshared("shared", "InputTypes", "", "", "none", "none", "none",false,true); parameters.push_back(pshared);            
-        CommandParameter pdesign("design", "InputTypes", "", "", "none", "none", "none",false,true); parameters.push_back(pdesign);    
-        CommandParameter potupersplit("otupersplit", "Multiple", "log2-squareroot", "log2", "", "", "",false,false); parameters.push_back(potupersplit);
-        CommandParameter psplitcriteria("splitcriteria", "Multiple", "gainratio-infogain", "gainratio", "", "", "",false,false); parameters.push_back(psplitcriteria);
-               CommandParameter pnumtrees("numtrees", "Number", "", "100", "", "", "",false,false); parameters.push_back(pnumtrees);
+        CommandParameter pshared("shared", "InputTypes", "", "", "none", "none", "none","summary",false,true,true); parameters.push_back(pshared);             
+        CommandParameter pdesign("design", "InputTypes", "", "", "none", "none", "none","",false,true,true); parameters.push_back(pdesign);    
+        CommandParameter potupersplit("otupersplit", "Multiple", "log2-squareroot", "log2", "", "", "","",false,false); parameters.push_back(potupersplit);
+        CommandParameter psplitcriteria("splitcriteria", "Multiple", "gainratio-infogain", "gainratio", "", "", "","",false,false); parameters.push_back(psplitcriteria);
+               CommandParameter pnumtrees("numtrees", "Number", "", "100", "", "", "","",false,false); parameters.push_back(pnumtrees);
+        
+            // parameters related to pruning
+        CommandParameter pdopruning("prune", "Boolean", "", "T", "", "", "", "", false, false); parameters.push_back(pdopruning);
+        CommandParameter ppruneaggrns("pruneaggressiveness", "Number", "", "0.9", "", "", "", "", false, false); parameters.push_back(ppruneaggrns);
+        CommandParameter pdiscardhetrees("discarderrortrees", "Boolean", "", "T", "", "", "", "", false, false); parameters.push_back(pdiscardhetrees);
+        CommandParameter phetdiscardthreshold("errorthreshold", "Number", "", "0.4", "", "", "", "", false, false); parameters.push_back(phetdiscardthreshold);
+        CommandParameter psdthreshold("stdthreshold", "Number", "", "0.0", "", "", "", "", false, false); parameters.push_back(psdthreshold);
+            // pruning params end
 
-        CommandParameter pgroups("groups", "String", "", "", "", "", "",false,false); parameters.push_back(pgroups);
-               CommandParameter plabel("label", "String", "", "", "", "", "",false,false); parameters.push_back(plabel);
-               CommandParameter pinputdir("inputdir", "String", "", "", "", "", "",false,false); parameters.push_back(pinputdir);
-               CommandParameter poutputdir("outputdir", "String", "", "", "", "", "",false,false); parameters.push_back(poutputdir);
+        CommandParameter pgroups("groups", "String", "", "", "", "", "","",false,false); parameters.push_back(pgroups);
+               CommandParameter plabel("label", "String", "", "", "", "", "","",false,false); parameters.push_back(plabel);
+               CommandParameter pinputdir("inputdir", "String", "", "", "", "", "","",false,false); parameters.push_back(pinputdir);
+               CommandParameter poutputdir("outputdir", "String", "", "", "", "", "","",false,false); parameters.push_back(poutputdir);
                
                vector<string> myArray;
                for (int i = 0; i < parameters.size(); i++) {   myArray.push_back(parameters[i].name);          }
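
The widened CommandParameter calls above carry an extra string argument and extra trailing booleans; the "summary" string on pshared presumably ties the shared file to the summary output type, and the added flags look like required/important style switches, though the exact field names live in the CommandParameter header. A minimal, self-contained sketch of the registration pattern using a hypothetical ParamSketch stand-in (not the real class):

    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical mirror of the widened constructor used above: the extra string is
    // assumed to name the output type the parameter is tied to, and the trailing
    // booleans are assumed to be required/important style flags.
    struct ParamSketch {
        std::string name, type, outputType;
        bool important, required;
        ParamSketch(std::string n, std::string t, std::string o, bool i, bool r)
            : name(n), type(t), outputType(o), important(i), required(r) {}
    };

    int main() {
        std::vector<ParamSketch> parameters;
        parameters.push_back(ParamSketch("shared", "InputTypes", "summary", false, true));
        parameters.push_back(ParamSketch("prune", "Boolean", "", false, false));
        // Collect the names, analogous to how setParameters() builds myArray.
        for (size_t i = 0; i < parameters.size(); i++) { std::cout << parameters[i].name << "\n"; }
        return 0;
    }
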
@@ -53,24 +61,19 @@ string ClassifySharedCommand::getHelpString(){
        }
 }
 //**********************************************************************************************************************
-string ClassifySharedCommand::getOutputFileNameTag(string type, string inputName=""){  
-       try {
-        string tag = "";
-               map<string, vector<string> >::iterator it;
+string ClassifySharedCommand::getOutputPattern(string type) {
+    try {
+        string pattern = "";
         
-        //is this a type this command creates
-        it = outputTypes.find(type);
-        if (it == outputTypes.end()) {  m->mothurOut("[ERROR]: this command doesn't create a " + type + " output file.\n"); }
-        else {
-            if (type == "summary") {  tag = "summary"; }
-            else { m->mothurOut("[ERROR]: No definition for type " + type + " output file tag.\n"); m->control_pressed = true;  }
-        }
-        return tag;
-       }
-       catch(exception& e) {
-               m->errorOut(e, "ClassifySharedCommand", "getOutputFileName");
-               exit(1);
-       }
+        if (type == "summary") {  pattern = "[filename],[distance],summary"; } //makes file like: amazon.RF.0.03.summary
+        else { m->mothurOut("[ERROR]: No definition for type " + type + " output pattern.\n"); m->control_pressed = true;  }
+        
+        return pattern;
+    }
+    catch(exception& e) {
+        m->errorOut(e, "ClassifySharedCommand", "getOutputPattern");
+        exit(1);
+    }
 }
 //**********************************************************************************************************************
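
The "[filename],[distance],summary" pattern is expanded by substituting the bracketed tokens from a variables map and dot-joining the pieces. A rough standalone sketch of that expansion (the real getOutputFileName helper lives in the Command base class and may differ in details):

    #include <iostream>
    #include <map>
    #include <sstream>
    #include <string>

    // Stand-in for Command::getOutputFileName: split the pattern on commas, replace
    // bracketed tokens from the variables map, and dot-join the remaining pieces.
    static std::string expandPattern(const std::string& pattern,
                                     const std::map<std::string, std::string>& variables) {
        std::string result, piece;
        std::istringstream iss(pattern);
        while (std::getline(iss, piece, ',')) {
            std::map<std::string, std::string>::const_iterator it = variables.find(piece);
            std::string value = (it != variables.end()) ? it->second : piece;
            if (!result.empty() && result[result.size() - 1] != '.') { result += "."; }
            result += value;
        }
        return result;
    }

    int main() {
        std::map<std::string, std::string> variables;
        variables["[filename]"] = "amazon.RF.";   // outputDir + root name of the shared file + "RF."
        variables["[distance]"] = "0.03";         // lookup[0]->getLabel()
        // Prints amazon.RF.0.03.summary, matching the "[filename],[distance],summary" pattern.
        std::cout << expandPattern("[filename],[distance],summary", variables) << std::endl;
        return 0;
    }
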
 
@@ -86,6 +89,7 @@ ClassifySharedCommand::ClassifySharedCommand() {
     exit(1);
   }
 }
+
 //**********************************************************************************************************************
 ClassifySharedCommand::ClassifySharedCommand(string option) {
   try {
@@ -109,7 +113,6 @@ ClassifySharedCommand::ClassifySharedCommand(string option) {
       for (it = parameters.begin(); it != parameters.end(); it++) {
         if (validParameter.isValidParameter(it->first, myArray, it->second) != true) {  abort = true;  }
       }
-        
         vector<string> tempOutNames;
         outputTypes["summary"] = tempOutNames;
       
@@ -135,7 +138,6 @@ ClassifySharedCommand::ClassifySharedCommand(string option) {
         }
         
       }
-       
         //check for parameters
         //get shared file, it is required
       sharedfile = validParameter.validFile(parameters, "shared", true);
@@ -163,25 +165,51 @@ ClassifySharedCommand::ClassifySharedCommand(string option) {
         outputDir = m->hasPath(sharedfile); //if user entered a file with a path then preserve it
       }
       
-    
          // random forest options: split criterion, number of trees, pruning, and OTUs per split
-      otupersplit = validParameter.validFile(parameters, "otupersplit", false);
-      if (otupersplit == "not found") { otupersplit = "log2"; }
-      if ((otupersplit == "squareroot") || (otupersplit == "log2")) {
-        optimumFeatureSubsetSelectionCriteria = otupersplit;
-      }else { m->mothurOut("Not a valid OTU per split selection method. Valid OTU per split selection methods are 'log2' and 'squareroot'."); m->mothurOutEndLine(); abort = true; }
-      
-        // splitcriteria
-      splitcriteria = validParameter.validFile(parameters, "splitcriteria", false);
-      if (splitcriteria == "not found") { splitcriteria = "gainratio"; }
-      if ((splitcriteria == "gainratio") || (splitcriteria == "infogain")) {
-        treeSplitCriterion = splitcriteria;
-      }else { m->mothurOut("Not a valid tree splitting criterio. Valid tree splitting criteria are 'gainratio' and 'infogain'."); m->mothurOutEndLine(); abort = true; }
-      
-      
-      string temp = validParameter.validFile(parameters, "numtrees", false); if (temp == "not found"){ temp = "100";   }
-      m->mothurConvert(temp, numDecisionTrees);
-
+        string temp = validParameter.validFile(parameters, "splitcriteria", false);
+        if (temp == "not found") { temp = "gainratio"; }
+        if ((temp == "gainratio") || (temp == "infogain")) {
+            treeSplitCriterion = temp;
+        } else { m->mothurOut("Not a valid tree splitting criterion. Valid tree splitting criteria are 'gainratio' and 'infogain'.");
+            m->mothurOutEndLine();
+            abort = true;
+        }
+        
+        temp = validParameter.validFile(parameters, "numtrees", false); if (temp == "not found"){      temp = "100";   }
+        m->mothurConvert(temp, numDecisionTrees);
+        
+            // parameters for pruning
+        temp = validParameter.validFile(parameters, "prune", false);
+        if (temp == "not found") { temp = "f"; }
+        doPruning = m->isTrue(temp);
+        
+        temp = validParameter.validFile(parameters, "pruneaggressiveness", false);
+        if (temp == "not found") { temp = "0.9"; }
+        m->mothurConvert(temp, pruneAggressiveness);
+        
+        temp = validParameter.validFile(parameters, "discarderrortrees", false);
+        if (temp == "not found") { temp = "f"; }
+        discardHighErrorTrees = m->isTrue(temp);
+        
+        temp = validParameter.validFile(parameters, "errorthreshold", false);
+        if (temp == "not found") { temp = "0.4"; }
+        m->mothurConvert(temp, highErrorTreeDiscardThreshold);
+        
+        temp = validParameter.validFile(parameters, "otupersplit", false);
+        if (temp == "not found") { temp = "log2"; }
+        if ((temp == "squareroot") || (temp == "log2")) {
+            optimumFeatureSubsetSelectionCriteria = temp;
+        } else { m->mothurOut("Not a valid OTU per split selection method. Valid OTU per split selection methods are 'log2' and 'squareroot'.");
+            m->mothurOutEndLine();
+            abort = true;
+        }
+        
+        temp = validParameter.validFile(parameters, "stdthreshold", false);
+        if (temp == "not found") { temp = "0.0"; }
+        m->mothurConvert(temp, featureStandardDeviationThreshold);
+                        
+            // end of pruning params
+        
         //Groups must be checked later to make sure they are valid. SharedUtilities has functions to check the validity, just make sure to call m->setGroups() after the checks.  If you are using these with a shared file there is no need to check; the SharedRAbundVector class will call SharedUtilities for you, kinda nice, huh?
       string groups = validParameter.validFile(parameters, "groups", false);
       if (groups == "not found") { groups = ""; }
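
Each pruning option above follows the same validate-and-convert pattern: look the name up in the parsed option map, fall back to a default string when it is absent, then coerce to bool or double. A self-contained sketch of that pattern, with simple stand-ins for validParameter.validFile, m->isTrue, and m->mothurConvert:

    #include <cstdlib>
    #include <iostream>
    #include <map>
    #include <string>

    // validFile-style lookup with a "not found" sentinel, and an isTrue stand-in
    // for Boolean options such as prune and discarderrortrees.
    static std::string lookupOption(const std::map<std::string, std::string>& parameters,
                                    const std::string& name) {
        std::map<std::string, std::string>::const_iterator it = parameters.find(name);
        return (it == parameters.end()) ? "not found" : it->second;
    }
    static bool isTrue(const std::string& s) { return (s == "T" || s == "t" || s == "true"); }

    int main() {
        std::map<std::string, std::string> parameters;        // as parsed from the option string
        parameters["prune"] = "T";
        parameters["pruneaggressiveness"] = "0.85";

        std::string temp = lookupOption(parameters, "prune");
        if (temp == "not found") { temp = "f"; }               // same fallback as the command
        bool doPruning = isTrue(temp);

        temp = lookupOption(parameters, "pruneaggressiveness");
        if (temp == "not found") { temp = "0.9"; }
        double pruneAggressiveness = std::atof(temp.c_str()); // mothurConvert stand-in

        std::cout << doPruning << " " << pruneAggressiveness << std::endl;  // prints 1 0.85
        return 0;
    }
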
@@ -240,7 +268,6 @@ int ClassifySharedCommand::execute() {
         for (int i = 0; i < lookup.size(); i++) {  delete lookup[i];  }
         lookup = input.getSharedRAbundVectors(lastLabel);
         m->mothurOut(lookup[0]->getLabel()); m->mothurOutEndLine();
-           
         processSharedAndDesignData(lookup);        
         
         processedLabels.insert(lookup[0]->getLabel());
@@ -332,8 +359,11 @@ void ClassifySharedCommand::processSharedAndDesignData(vector<SharedRAbundVector
         
         vector< vector<int> > dataSet(numRows, vector<int>(numColumns, 0));
         
+        vector<string> names;
+        
         for (int i = 0; i < lookup.size(); i++) {
             string sharedGroupName = lookup[i]->getGroup();
+            names.push_back(sharedGroupName);
             string treatmentName = designMap.getGroup(sharedGroupName);
             
             int j = 0;
@@ -344,15 +374,28 @@ void ClassifySharedCommand::processSharedAndDesignData(vector<SharedRAbundVector
             dataSet[i][j] = treatmentToIntMap[treatmentName];
         }
         
-        RandomForest randomForest(dataSet, numDecisionTrees, treeSplitCriterion);
+        RandomForest randomForest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold);
+        
         randomForest.populateDecisionTrees();
         randomForest.calcForrestErrorRate();
+        randomForest.printConfusionMatrix(intToTreatmentMap);
         
-        string filename = outputDir + m->getRootName(m->getSimpleName(sharedfile)) + lookup[0]->getLabel() + "." + getOutputFileNameTag("summary");
+        map<string, string> variables; 
+        variables["[filename]"] = outputDir + m->getRootName(m->getSimpleName(sharedfile)) + "RF.";
+        variables["[distance]"] = lookup[0]->getLabel();
+        string filename = getOutputFileName("summary", variables);
         outputNames.push_back(filename); outputTypes["summary"].push_back(filename);
-        
         randomForest.calcForrestVariableImportance(filename);
         
+        //write per-sample misclassifications to a separate summary file
+        map<string, string> variable; 
+        variable["[filename]"] = outputDir + m->getRootName(m->getSimpleName(sharedfile)) + "misclassifications.";
+        variable["[distance]"] = lookup[0]->getLabel();
+        string mc_filename = getOutputFileName("summary", variable);
+        outputNames.push_back(mc_filename); outputTypes["summary"].push_back(mc_filename);
+        randomForest.getMissclassifications(mc_filename, intToTreatmentMap, names);
+        
         m->mothurOutEndLine();
     }
     catch(exception& e) {