]> git.donarmstrong.com Git - mothur.git/blob - classifyrfsharedcommand.cpp
changing command name classify.shared to classifyrf.shared
[mothur.git] / classifyrfsharedcommand.cpp
1 //
2 //  classifysharedcommand.cpp
3 //  Mothur
4 //
5 //  Created by Abu Zaher Md. Faridee on 8/13/12.
6 //  Copyright (c) 2012 Schloss Lab. All rights reserved.
7 //
8
9 #include "classifyrfsharedcommand.h"
10 #include "randomforest.hpp"
11 #include "decisiontree.hpp"
12 #include "rftreenode.hpp"
13
14 //**********************************************************************************************************************
15 vector<string> ClassifyRFSharedCommand::setParameters(){        
16         try {
17                 //CommandParameter pprocessors("processors", "Number", "", "1", "", "", "",false,false); parameters.push_back(pprocessors);        
18         CommandParameter pshared("shared", "InputTypes", "", "", "none", "none", "none","summary",false,true,true); parameters.push_back(pshared);              
19         CommandParameter pdesign("design", "InputTypes", "", "", "none", "none", "none","",false,true,true); parameters.push_back(pdesign);     
20         CommandParameter potupersplit("otupersplit", "Multiple", "log2-squareroot", "log2", "", "", "","",false,false); parameters.push_back(potupersplit);
21         CommandParameter psplitcriteria("splitcriteria", "Multiple", "gainratio-infogain", "gainratio", "", "", "","",false,false); parameters.push_back(psplitcriteria);
22                 CommandParameter pnumtrees("numtrees", "Number", "", "100", "", "", "","",false,false); parameters.push_back(pnumtrees);
23         
24             // parameters related to pruning
25         CommandParameter pdopruning("prune", "Boolean", "", "T", "", "", "", "", false, false); parameters.push_back(pdopruning);
26         CommandParameter ppruneaggrns("pruneaggressiveness", "Number", "", "0.9", "", "", "", "", false, false); parameters.push_back(ppruneaggrns);
27         CommandParameter pdiscardhetrees("discarderrortrees", "Boolean", "", "T", "", "", "", "", false, false); parameters.push_back(pdiscardhetrees);
28         CommandParameter phetdiscardthreshold("errorthreshold", "Number", "", "0.4", "", "", "", "", false, false); parameters.push_back(phetdiscardthreshold);
29         CommandParameter psdthreshold("stdthreshold", "Number", "", "0.0", "", "", "", "", false, false); parameters.push_back(psdthreshold);
30             // pruning params end
31
32         CommandParameter pgroups("groups", "String", "", "", "", "", "","",false,false); parameters.push_back(pgroups);
33                 CommandParameter plabel("label", "String", "", "", "", "", "","",false,false); parameters.push_back(plabel);
34                 CommandParameter pinputdir("inputdir", "String", "", "", "", "", "","",false,false); parameters.push_back(pinputdir);
35                 CommandParameter poutputdir("outputdir", "String", "", "", "", "", "","",false,false); parameters.push_back(poutputdir);
36                 
37                 vector<string> myArray;
38                 for (int i = 0; i < parameters.size(); i++) {   myArray.push_back(parameters[i].name);          }
39                 return myArray;
40         }
41         catch(exception& e) {
42                 m->errorOut(e, "ClassifySharedCommand", "setParameters");
43                 exit(1);
44         }
45 }
46 //**********************************************************************************************************************
47 string ClassifyRFSharedCommand::getHelpString(){        
48         try {
49                 string helpString = "";
50                 helpString += "The classify.shared command allows you to ....\n";
51                 helpString += "The classify.shared command parameters are: shared, design, label, groups, otupersplit.\n";
52         helpString += "The label parameter is used to analyze specific labels in your input.\n";
53                 helpString += "The groups parameter allows you to specify which of the groups in your designfile you would like analyzed.\n";
54                 helpString += "The classify.shared should be in the following format: \n";
55                 helpString += "classify.shared(shared=yourSharedFile, design=yourDesignFile)\n";
56                 return helpString;
57         }
58         catch(exception& e) {
59                 m->errorOut(e, "ClassifySharedCommand", "getHelpString");
60                 exit(1);
61         }
62 }
63 //**********************************************************************************************************************
64 string ClassifyRFSharedCommand::getOutputPattern(string type) {
65     try {
66         string pattern = "";
67         
68         if (type == "summary") {  pattern = "[filename],[distance],summary"; } //makes file like: amazon.0.03.fasta
69         else { m->mothurOut("[ERROR]: No definition for type " + type + " output pattern.\n"); m->control_pressed = true;  }
70         
71         return pattern;
72     }
73     catch(exception& e) {
74         m->errorOut(e, "ClassifySharedCommand", "getOutputPattern");
75         exit(1);
76     }
77 }
78 //**********************************************************************************************************************
79
80 ClassifyRFSharedCommand::ClassifyRFSharedCommand() {
81   try {
82     abort = true; calledHelp = true;
83     setParameters();
84     vector<string> tempOutNames;
85     outputTypes["summary"] = tempOutNames; 
86   }
87   catch(exception& e) {
88     m->errorOut(e, "ClassifySharedCommand", "ClassifySharedCommand");
89     exit(1);
90   }
91 }
92
93 //**********************************************************************************************************************
94 ClassifyRFSharedCommand::ClassifyRFSharedCommand(string option) {
95   try {
96     abort = false; calledHelp = false;
97     allLines = 1;
98       
99       //allow user to run help
100     if(option == "help") { help(); abort = true; calledHelp = true; }
101     else if(option == "citation") { citation(); abort = true; calledHelp = true;}
102     
103     else {
104         //valid paramters for this command
105       vector<string> myArray = setParameters();
106       
107       OptionParser parser(option);
108       map<string,string> parameters = parser.getParameters();
109       
110       ValidParameters validParameter;
111       map<string,string>::iterator it;
112         //check to make sure all parameters are valid for command
113       for (it = parameters.begin(); it != parameters.end(); it++) {
114         if (validParameter.isValidParameter(it->first, myArray, it->second) != true) {  abort = true;  }
115       }
116         vector<string> tempOutNames;
117         outputTypes["summary"] = tempOutNames;
118       
119         //if the user changes the input directory command factory will send this info to us in the output parameter
120       string inputDir = validParameter.validFile(parameters, "inputdir", false);
121       if (inputDir == "not found"){     inputDir = "";          }
122       else {
123         string path;
124         it = parameters.find("shared");
125           //user has given a shared file
126         if(it != parameters.end()){
127           path = m->hasPath(it->second);
128             //if the user has not given a path then, add inputdir. else leave path alone.
129           if (path == "") {     parameters["shared"] = inputDir + it->second;           }
130         }
131         
132         it = parameters.find("design");
133           //user has given a design file
134         if(it != parameters.end()){
135           path = m->hasPath(it->second);
136             //if the user has not given a path then, add inputdir. else leave path alone.
137           if (path == "") {     parameters["design"] = inputDir + it->second;           }
138         }
139         
140       }
141         //check for parameters
142         //get shared file, it is required
143       sharedfile = validParameter.validFile(parameters, "shared", true);
144       if (sharedfile == "not open") { sharedfile = ""; abort = true; }
145       else if (sharedfile == "not found") {
146           //if there is a current shared file, use it
147         sharedfile = m->getSharedFile();
148         if (sharedfile != "") { m->mothurOut("Using " + sharedfile + " as input file for the shared parameter."); m->mothurOutEndLine(); }
149         else {  m->mothurOut("You have no current sharedfile and the shared parameter is required."); m->mothurOutEndLine(); abort = true; }
150       }else { m->setSharedFile(sharedfile); }
151       
152         //get design file, it is required
153       designfile = validParameter.validFile(parameters, "design", true);
154       if (designfile == "not open") { sharedfile = ""; abort = true; }
155       else if (designfile == "not found") {
156           //if there is a current shared file, use it
157         designfile = m->getDesignFile();
158         if (designfile != "") { m->mothurOut("Using " + designfile + " as input file for the design parameter."); m->mothurOutEndLine(); }
159         else {  m->mothurOut("You have no current designfile and the design parameter is required."); m->mothurOutEndLine(); abort = true; }
160       }else { m->setDesignFile(designfile); }
161
162       
163         //if the user changes the output directory command factory will send this info to us in the output parameter
164       outputDir = validParameter.validFile(parameters, "outputdir", false);             if (outputDir == "not found"){
165         outputDir = m->hasPath(sharedfile); //if user entered a file with a path then preserve it
166       }
167       
168         // NEW CODE for OTU per split selection criteria
169         string temp = validParameter.validFile(parameters, "splitcriteria", false);
170         if (temp == "not found") { temp = "gainratio"; }
171         if ((temp == "gainratio") || (temp == "infogain")) {
172             treeSplitCriterion = temp;
173         } else { m->mothurOut("Not a valid tree splitting criterio. Valid tree splitting criteria are 'gainratio' and 'infogain'.");
174             m->mothurOutEndLine();
175             abort = true;
176         }
177         
178         temp = validParameter.validFile(parameters, "numtrees", false); if (temp == "not found"){       temp = "100";   }
179         m->mothurConvert(temp, numDecisionTrees);
180         
181             // parameters for pruning
182         temp = validParameter.validFile(parameters, "prune", false);
183         if (temp == "not found") { temp = "f"; }
184         doPruning = m->isTrue(temp);
185         
186         temp = validParameter.validFile(parameters, "pruneaggressiveness", false);
187         if (temp == "not found") { temp = "0.9"; }
188         m->mothurConvert(temp, pruneAggressiveness);
189         
190         temp = validParameter.validFile(parameters, "discarderrortrees", false);
191         if (temp == "not found") { temp = "f"; }
192         discardHighErrorTrees = m->isTrue(temp);
193         
194         temp = validParameter.validFile(parameters, "errorthreshold", false);
195         if (temp == "not found") { temp = "0.4"; }
196         m->mothurConvert(temp, highErrorTreeDiscardThreshold);
197         
198         temp = validParameter.validFile(parameters, "otupersplit", false);
199         if (temp == "not found") { temp = "log2"; }
200         if ((temp == "squareroot") || (temp == "log2")) {
201             optimumFeatureSubsetSelectionCriteria = temp;
202         } else { m->mothurOut("Not a valid OTU per split selection method. Valid OTU per split selection methods are 'log2' and 'squareroot'.");
203             m->mothurOutEndLine();
204             abort = true;
205         }
206         
207         temp = validParameter.validFile(parameters, "stdthreshold", false);
208         if (temp == "not found") { temp = "0.0"; }
209         m->mothurConvert(temp, featureStandardDeviationThreshold);
210                         
211             // end of pruning params
212         
213         //Groups must be checked later to make sure they are valid. SharedUtilities has functions of check the validity, just make to so m->setGroups() after the checks.  If you are using these with a shared file no need to check the SharedRAbundVector class will call SharedUtilites for you, kinda nice, huh?
214       string groups = validParameter.validFile(parameters, "groups", false);
215       if (groups == "not found") { groups = ""; }
216       else { m->splitAtDash(groups, Groups); }
217       m->setGroups(Groups);
218       
219         //Commonly used to process list, rabund, sabund, shared and relabund files.  Look at "smart distancing" examples below in the execute function.
220       string label = validParameter.validFile(parameters, "label", false);
221       if (label == "not found") { label = ""; }
222       else {
223         if(label != "all") {  m->splitAtDash(label, labels);  allLines = 0;  }
224         else { allLines = 1;  }
225       }
226     }
227     
228   }
229   catch(exception& e) {
230     m->errorOut(e, "ClassifySharedCommand", "ClassifySharedCommand");
231     exit(1);
232   }
233 }
234 //**********************************************************************************************************************
235 int ClassifyRFSharedCommand::execute() {
236   try {
237     
238     if (abort == true) { if (calledHelp) { return 0; }  return 2;       }
239     
240     InputData input(sharedfile, "sharedfile");
241     vector<SharedRAbundVector*> lookup = input.getSharedRAbundVectors();
242         
243     //read design file
244     designMap.readDesignMap(designfile);
245     
246     string lastLabel = lookup[0]->getLabel();
247     set<string> processedLabels;
248     set<string> userLabels = labels;
249     
250       //as long as you are not at the end of the file or done wih the lines you want
251     while((lookup[0] != NULL) && ((allLines == 1) || (userLabels.size() != 0))) {
252       
253       if (m->control_pressed) { for (int i = 0; i < lookup.size(); i++) {  delete lookup[i];  }  return 0; }
254       
255       if(allLines == 1 || labels.count(lookup[0]->getLabel()) == 1){
256         
257         m->mothurOut(lookup[0]->getLabel()); m->mothurOutEndLine();
258         
259         processSharedAndDesignData(lookup);  
260           
261         processedLabels.insert(lookup[0]->getLabel());
262         userLabels.erase(lookup[0]->getLabel());
263       }
264       
265       if ((m->anyLabelsToProcess(lookup[0]->getLabel(), userLabels, "") == true) && (processedLabels.count(lastLabel) != 1)) {
266         string saveLabel = lookup[0]->getLabel();
267         
268         for (int i = 0; i < lookup.size(); i++) {  delete lookup[i];  }
269         lookup = input.getSharedRAbundVectors(lastLabel);
270         m->mothurOut(lookup[0]->getLabel()); m->mothurOutEndLine();
271         processSharedAndDesignData(lookup);        
272         
273         processedLabels.insert(lookup[0]->getLabel());
274         userLabels.erase(lookup[0]->getLabel());
275         
276           //restore real lastlabel to save below
277         lookup[0]->setLabel(saveLabel);
278       }
279       
280       lastLabel = lookup[0]->getLabel();
281         //prevent memory leak
282       for (int i = 0; i < lookup.size(); i++) {  delete lookup[i]; lookup[i] = NULL; }
283       
284       if (m->control_pressed) { return 0; }
285       
286         //get next line to process
287       lookup = input.getSharedRAbundVectors();
288     }
289     
290     if (m->control_pressed) {  return 0; }
291     
292       //output error messages about any remaining user labels
293     set<string>::iterator it;
294     bool needToRun = false;
295     for (it = userLabels.begin(); it != userLabels.end(); it++) {
296       m->mothurOut("Your file does not include the label " + *it);
297       if (processedLabels.count(lastLabel) != 1) {
298         m->mothurOut(". I will use " + lastLabel + "."); m->mothurOutEndLine();
299         needToRun = true;
300       }else {
301         m->mothurOut(". Please refer to " + lastLabel + "."); m->mothurOutEndLine();
302       }
303     }
304     
305       //run last label if you need to
306     if (needToRun == true)  {
307       for (int i = 0; i < lookup.size(); i++) { if (lookup[i] != NULL) { delete lookup[i]; } }
308       lookup = input.getSharedRAbundVectors(lastLabel);
309       
310       m->mothurOut(lookup[0]->getLabel()); m->mothurOutEndLine();
311       
312       processSharedAndDesignData(lookup);  
313         
314       for (int i = 0; i < lookup.size(); i++) {  delete lookup[i];  }
315       
316     }
317
318       m->mothurOutEndLine();
319       m->mothurOut("Output File Names: "); m->mothurOutEndLine();
320       for (int i = 0; i < outputNames.size(); i++) {    m->mothurOut(outputNames[i]); m->mothurOutEndLine();    }
321       m->mothurOutEndLine();
322       
323     return 0;
324     
325   }
326   catch(exception& e) {
327     m->errorOut(e, "ClassifySharedCommand", "execute");
328     exit(1);
329   }
330 }
331 //**********************************************************************************************************************
332
333 void ClassifyRFSharedCommand::processSharedAndDesignData(vector<SharedRAbundVector*> lookup){  
334     try {
335 //    for (int i = 0; i < designMap->getNamesOfGroups().size(); i++) {
336 //      string groupName = designMap->getNamesOfGroups()[i];
337 //      cout << groupName << endl;
338 //    }
339
340 //    for (int i = 0; i < designMap->getNumSeqs(); i++) {
341 //      string sharedGroupName = designMap->getNamesSeqs()[i];
342 //      string treatmentName = designMap->getGroup(sharedGroupName);
343 //      cout << sharedGroupName << " : " << treatmentName <<  endl;
344 //    }
345   
346         map<string, int> treatmentToIntMap;
347         map<int, string> intToTreatmentMap;
348         for (int  i = 0; i < designMap.getNumGroups(); i++) {
349             string treatmentName = designMap.getNamesOfGroups()[i];
350             treatmentToIntMap[treatmentName] = i;
351             intToTreatmentMap[i] = treatmentName;
352         }
353         
354         int numSamples = lookup.size();
355         int numFeatures = lookup[0]->getNumBins();
356         
357         int numRows = numSamples;
358         int numColumns = numFeatures + 1;           // extra one space needed for the treatment/outcome
359         
360         vector< vector<int> > dataSet(numRows, vector<int>(numColumns, 0));
361         
362         vector<string> names;
363         
364         for (int i = 0; i < lookup.size(); i++) {
365             string sharedGroupName = lookup[i]->getGroup();
366             names.push_back(sharedGroupName);
367             string treatmentName = designMap.getGroup(sharedGroupName);
368             
369             int j = 0;
370             for (; j < lookup[i]->getNumBins(); j++) {
371                 int otuCount = lookup[i]->getAbundance(j);
372                 dataSet[i][j] = otuCount;
373             }
374             dataSet[i][j] = treatmentToIntMap[treatmentName];
375         }
376         
377         RandomForest randomForest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold);
378         
379         randomForest.populateDecisionTrees();
380         randomForest.calcForrestErrorRate();
381         randomForest.printConfusionMatrix(intToTreatmentMap);
382         
383         map<string, string> variables; 
384         variables["[filename]"] = outputDir + m->getRootName(m->getSimpleName(sharedfile)) + "RF.";
385         variables["[distance]"] = lookup[0]->getLabel();
386         string filename = getOutputFileName("summary", variables);
387         outputNames.push_back(filename); outputTypes["summary"].push_back(filename);
388         randomForest.calcForrestVariableImportance(filename);
389         
390         //
391         map<string, string> variable; 
392         variable["[filename]"] = outputDir + m->getRootName(m->getSimpleName(sharedfile)) + "misclassifications.";
393         variable["[distance]"] = lookup[0]->getLabel();
394         string mc_filename = getOutputFileName("summary", variable);
395         outputNames.push_back(mc_filename); outputTypes["summary"].push_back(mc_filename);
396         randomForest.getMissclassifications(mc_filename, intToTreatmentMap, names);
397         //
398         
399         m->mothurOutEndLine();
400     }
401     catch(exception& e) {
402         m->errorOut(e, "ClassifySharedCommand", "processSharedAndDesignData");
403         exit(1);
404     }
405 }
406 //**********************************************************************************************************************
407