From 035f86272c776e1cccaa47021e26782e49cd41e7 Mon Sep 17 00:00:00 2001 From: Sarah Westcott Date: Tue, 2 Oct 2012 14:36:18 -0400 Subject: [PATCH] added classify.shared command and random forest files. added count file to pcr.seqs, remove.rare, screen.seqs, sort.seqs, split.abund, and subsample commands --- Mothur.xcodeproj/project.pbxproj | 46 +++ abstractdecisiontree.cpp | 285 +++++++++++++++++ abstractdecisiontree.hpp | 63 ++++ abstractrandomforest.cpp | 58 ++++ abstractrandomforest.hpp | 67 ++++ classifysharedcommand.cpp | 364 ++++++++++++++++++++++ classifysharedcommand.h | 54 ++++ commandfactory.cpp | 6 + countgroupscommand.cpp | 73 ++++- countgroupscommand.h | 2 +- counttable.cpp | 14 +- decisiontree.cpp | 399 ++++++++++++++++++++++++ decisiontree.hpp | 59 ++++ groupmap.cpp | 73 +++++ groupmap.h | 1 + macros.h | 32 ++ mothurout.cpp | 27 +- mothurout.h | 1 + parsefastaqcommand.cpp | 37 ++- parsefastaqcommand.h | 3 +- pcrseqscommand.h | 4 +- prcseqscommand.cpp | 107 ++++++- randomforest.cpp | 156 ++++++++++ randomforest.hpp | 45 +++ removerarecommand.cpp | 92 +++++- removerarecommand.h | 2 +- rftreenode.cpp | 92 ++++++ rftreenode.hpp | 91 ++++++ screenseqscommand.cpp | 2 +- sharedcommand.cpp | 6 +- sharedrabundvector.h | 1 - sortseqscommand.cpp | 136 ++++++++- sortseqscommand.h | 3 +- splitabundcommand.cpp | 388 +++++++++++++++++------ splitabundcommand.h | 13 +- subsample.cpp | 159 +++++++++- subsample.h | 10 + subsamplecommand.cpp | 510 ++++++++++++++++++------------- subsamplecommand.h | 4 +- trimflowscommand.cpp | 2 + 40 files changed, 3115 insertions(+), 372 deletions(-) create mode 100644 abstractdecisiontree.cpp create mode 100755 abstractdecisiontree.hpp create mode 100644 abstractrandomforest.cpp create mode 100755 abstractrandomforest.hpp create mode 100755 classifysharedcommand.cpp create mode 100755 classifysharedcommand.h create mode 100644 decisiontree.cpp create mode 100755 decisiontree.hpp create mode 100755 macros.h create mode 100644 randomforest.cpp create mode 100755 randomforest.hpp create mode 100644 rftreenode.cpp create mode 100755 rftreenode.hpp diff --git a/Mothur.xcodeproj/project.pbxproj b/Mothur.xcodeproj/project.pbxproj index 37932a4..9d6261b 100644 --- a/Mothur.xcodeproj/project.pbxproj +++ b/Mothur.xcodeproj/project.pbxproj @@ -21,6 +21,10 @@ A721765713BB9F7D0014DAAE /* referencedb.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A721765613BB9F7D0014DAAE /* referencedb.cpp */; }; A724D2B7153C8628000A826F /* makebiomcommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A724D2B6153C8628000A826F /* makebiomcommand.cpp */; }; A727864412E9E28C00F86ABA /* removerarecommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A727864312E9E28C00F86ABA /* removerarecommand.cpp */; }; + A7386C231619CCE600651424 /* classifysharedcommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A7386C211619CCE600651424 /* classifysharedcommand.cpp */; }; + A7386C251619E52300651424 /* abstractdecisiontree.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A7386C241619E52200651424 /* abstractdecisiontree.cpp */; }; + A7386C27161A0F9D00651424 /* abstractrandomforest.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A7386C26161A0F9C00651424 /* abstractrandomforest.cpp */; }; + A7386C29161A110800651424 /* decisiontree.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A7386C28161A110700651424 /* decisiontree.cpp */; }; A73901081588C40900ED2ED6 /* loadlogfilecommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A73901071588C40900ED2ED6 /* loadlogfilecommand.cpp */; }; A73DDBBA13C4A0D1006AAE38 /* clearmemorycommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A73DDBB913C4A0D1006AAE38 /* clearmemorycommand.cpp */; }; A73DDC3813C4BF64006AAE38 /* mothurmetastats.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A73DDC3713C4BF64006AAE38 /* mothurmetastats.cpp */; }; @@ -37,6 +41,8 @@ A77410F614697C300098E6AC /* seqnoise.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A77410F414697C300098E6AC /* seqnoise.cpp */; }; A778FE6B134CA6CA00C0BA33 /* getcommandinfocommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A778FE6A134CA6CA00C0BA33 /* getcommandinfocommand.cpp */; }; A77A221F139001B600B0BE70 /* deuniquetreecommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A77A221E139001B600B0BE70 /* deuniquetreecommand.cpp */; }; + A77E1938161B201E00DB1A2A /* randomforest.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A77E1937161B201E00DB1A2A /* randomforest.cpp */; }; + A77E193B161B289600DB1A2A /* rftreenode.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A77E193A161B289600DB1A2A /* rftreenode.cpp */; }; A77EBD2F1523709100ED407C /* createdatabasecommand.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A77EBD2E1523709100ED407C /* createdatabasecommand.cpp */; }; A7876A26152A017C00A0AE86 /* subsample.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A7876A25152A017C00A0AE86 /* subsample.cpp */; }; A79234D713C74BF6002B08E2 /* mothurfisher.cpp in Sources */ = {isa = PBXBuildFile; fileRef = A79234D613C74BF6002B08E2 /* mothurfisher.cpp */; }; @@ -393,6 +399,17 @@ A724D2B6153C8628000A826F /* makebiomcommand.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = makebiomcommand.cpp; sourceTree = ""; }; A727864212E9E28C00F86ABA /* removerarecommand.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = removerarecommand.h; sourceTree = ""; }; A727864312E9E28C00F86ABA /* removerarecommand.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = removerarecommand.cpp; sourceTree = ""; }; + A7386C1B1619CACB00651424 /* abstractdecisiontree.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = abstractdecisiontree.hpp; sourceTree = ""; }; + A7386C1C1619CACB00651424 /* abstractrandomforest.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = abstractrandomforest.hpp; sourceTree = ""; }; + A7386C1D1619CACB00651424 /* decisiontree.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = decisiontree.hpp; sourceTree = ""; }; + A7386C1E1619CACB00651424 /* macros.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = macros.h; sourceTree = ""; }; + A7386C1F1619CACB00651424 /* randomforest.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = randomforest.hpp; sourceTree = ""; }; + A7386C201619CACB00651424 /* rftreenode.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = rftreenode.hpp; sourceTree = ""; }; + A7386C211619CCE600651424 /* classifysharedcommand.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = classifysharedcommand.cpp; sourceTree = ""; }; + A7386C221619CCE600651424 /* classifysharedcommand.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = classifysharedcommand.h; sourceTree = ""; }; + A7386C241619E52200651424 /* abstractdecisiontree.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = abstractdecisiontree.cpp; sourceTree = ""; }; + A7386C26161A0F9C00651424 /* abstractrandomforest.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = abstractrandomforest.cpp; sourceTree = ""; }; + A7386C28161A110700651424 /* decisiontree.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = decisiontree.cpp; sourceTree = ""; }; A73901051588C3EF00ED2ED6 /* loadlogfilecommand.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = loadlogfilecommand.h; sourceTree = ""; }; A73901071588C40900ED2ED6 /* loadlogfilecommand.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = loadlogfilecommand.cpp; sourceTree = ""; }; A73DDBB813C4A0D1006AAE38 /* clearmemorycommand.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = clearmemorycommand.h; sourceTree = ""; }; @@ -425,6 +442,8 @@ A778FE6A134CA6CA00C0BA33 /* getcommandinfocommand.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = getcommandinfocommand.cpp; sourceTree = ""; }; A77A221D139001B600B0BE70 /* deuniquetreecommand.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = deuniquetreecommand.h; sourceTree = ""; }; A77A221E139001B600B0BE70 /* deuniquetreecommand.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = deuniquetreecommand.cpp; sourceTree = ""; }; + A77E1937161B201E00DB1A2A /* randomforest.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = randomforest.cpp; sourceTree = ""; }; + A77E193A161B289600DB1A2A /* rftreenode.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = rftreenode.cpp; sourceTree = ""; }; A77EBD2C1523707F00ED407C /* createdatabasecommand.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = createdatabasecommand.h; sourceTree = ""; }; A77EBD2E1523709100ED407C /* createdatabasecommand.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = createdatabasecommand.cpp; sourceTree = ""; }; A7876A25152A017C00A0AE86 /* subsample.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = subsample.cpp; sourceTree = ""; }; @@ -1132,6 +1151,7 @@ A7E9B79B12D37EC400DA6239 /* progress.cpp */, A7E9B79C12D37EC400DA6239 /* progress.hpp */, A7E9B7A512D37EC400DA6239 /* rarecalc.cpp */, + A7386C191619C9FB00651424 /* randomforest */, A7E9B7A612D37EC400DA6239 /* rarecalc.h */, A7E9B7A712D37EC400DA6239 /* raredisplay.cpp */, A7E9B7A812D37EC400DA6239 /* raredisplay.h */, @@ -1173,6 +1193,24 @@ name = Products; sourceTree = ""; }; + A7386C191619C9FB00651424 /* randomforest */ = { + isa = PBXGroup; + children = ( + A7386C1B1619CACB00651424 /* abstractdecisiontree.hpp */, + A7386C241619E52200651424 /* abstractdecisiontree.cpp */, + A7386C1C1619CACB00651424 /* abstractrandomforest.hpp */, + A7386C26161A0F9C00651424 /* abstractrandomforest.cpp */, + A7386C1D1619CACB00651424 /* decisiontree.hpp */, + A7386C28161A110700651424 /* decisiontree.cpp */, + A7386C1E1619CACB00651424 /* macros.h */, + A7386C1F1619CACB00651424 /* randomforest.hpp */, + A77E1937161B201E00DB1A2A /* randomforest.cpp */, + A7386C201619CACB00651424 /* rftreenode.hpp */, + A77E193A161B289600DB1A2A /* rftreenode.cpp */, + ); + name = randomforest; + sourceTree = ""; + }; A7D161E7149F7F50000523E8 /* fortran */ = { isa = PBXGroup; children = ( @@ -1229,6 +1267,8 @@ A7E9B69012D37EC400DA6239 /* classifyotucommand.cpp */, A7E9B69312D37EC400DA6239 /* classifyseqscommand.h */, A7E9B69212D37EC400DA6239 /* classifyseqscommand.cpp */, + A7386C221619CCE600651424 /* classifysharedcommand.h */, + A7386C211619CCE600651424 /* classifysharedcommand.cpp */, A7EEB0F714F29C1B00344B83 /* classifytreecommand.h */, A7EEB0F414F29BFD00344B83 /* classifytreecommand.cpp */, A7E9B69712D37EC400DA6239 /* clearcutcommand.h */, @@ -2204,6 +2244,12 @@ A7E0243D15B4520A00A5F046 /* sparsedistancematrix.cpp in Sources */, A741FAD215D1688E0067BCC5 /* sequencecountparser.cpp in Sources */, A7C7DAB915DA758B0059B0CF /* sffmultiplecommand.cpp in Sources */, + A7386C231619CCE600651424 /* classifysharedcommand.cpp in Sources */, + A7386C251619E52300651424 /* abstractdecisiontree.cpp in Sources */, + A7386C27161A0F9D00651424 /* abstractrandomforest.cpp in Sources */, + A7386C29161A110800651424 /* decisiontree.cpp in Sources */, + A77E1938161B201E00DB1A2A /* randomforest.cpp in Sources */, + A77E193B161B289600DB1A2A /* rftreenode.cpp in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; diff --git a/abstractdecisiontree.cpp b/abstractdecisiontree.cpp new file mode 100644 index 0000000..085cd31 --- /dev/null +++ b/abstractdecisiontree.cpp @@ -0,0 +1,285 @@ +// +// abstractdecisiontree.cpp +// Mothur +// +// Created by Sarah Westcott on 10/1/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#include "abstractdecisiontree.hpp" + +/**************************************************************************************************/ + +AbstractDecisionTree::AbstractDecisionTree(vector >baseDataSet, + vector globalDiscardedFeatureIndices, + OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, + string treeSplitCriterion) : baseDataSet(baseDataSet), +numSamples((int)baseDataSet.size()), +numFeatures((int)(baseDataSet[0].size() - 1)), +numOutputClasses(0), +rootNode(NULL), +globalDiscardedFeatureIndices(globalDiscardedFeatureIndices), +optimumFeatureSubsetSize(optimumFeatureSubsetSelector.getOptimumFeatureSubsetSize(numFeatures)), +treeSplitCriterion(treeSplitCriterion) { + + try { + // TODO: istead of calculating this for every DecisionTree + // clacualte this once in the RandomForest class and pass the values + m = MothurOut::getInstance(); + for (int i = 0; i < numSamples; i++) { + if (m->control_pressed) { break; } + int outcome = baseDataSet[i][numFeatures]; + vector::iterator it = find(outputClasses.begin(), outputClasses.end(), outcome); + if (it == outputClasses.end()){ // find() will return classes.end() if the element is not found + outputClasses.push_back(outcome); + numOutputClasses++; + } + } + + if (m->debug) { + //m->mothurOut("outputClasses = " + toStringVectorInt(outputClasses)); + m->mothurOut("numOutputClasses = " + toString(numOutputClasses) + '\n'); + } + + } + catch(exception& e) { + m->errorOut(e, "AbstractDecisionTree", "AbstractDecisionTree"); + exit(1); + } +} +/**************************************************************************************************/ +int AbstractDecisionTree::createBootStrappedSamples(){ + try { + vector isInTrainingSamples(numSamples, false); + + for (int i = 0; i < numSamples; i++) { + if (m->control_pressed) { return 0; } + // TODO: optimize the rand() function call + double check if it's working properly + int randomIndex = rand() % numSamples; + bootstrappedTrainingSamples.push_back(baseDataSet[randomIndex]); + isInTrainingSamples[randomIndex] = true; + } + + for (int i = 0; i < numSamples; i++) { + if (m->control_pressed) { return 0; } + if (isInTrainingSamples[i]){ bootstrappedTrainingSampleIndices.push_back(i); } + else{ + bootstrappedTestSamples.push_back(baseDataSet[i]); + bootstrappedTestSampleIndices.push_back(i); + } + } + + return 0; + } + catch(exception& e) { + m->errorOut(e, "AbstractDecisionTree", "createBootStrappedSamples"); + exit(1); + } +} +/**************************************************************************************************/ +int AbstractDecisionTree::getMinEntropyOfFeature(vector featureVector, vector outputVector, double& minEntropy, int& featureSplitValue, double& intrinsicValue){ + try { + + vector< vector > featureOutputPair(featureVector.size(), vector(2, 0)); + for (int i = 0; i < featureVector.size(); i++) { + if (m->control_pressed) { return 0; } + featureOutputPair[i][0] = featureVector[i]; + featureOutputPair[i][1] = outputVector[i]; + } + // TODO: using default behavior to sort(), need to specify the comparator for added safety and compiler portability + sort(featureOutputPair.begin(), featureOutputPair.end()); + + + vector splitPoints; + vector uniqueFeatureValues(1, featureOutputPair[0][0]); + + for (int i = 0; i < featureOutputPair.size(); i++) { + if (m->control_pressed) { return 0; } + int featureValue = featureOutputPair[i][0]; + vector::iterator it = find(uniqueFeatureValues.begin(), uniqueFeatureValues.end(), featureValue); + if (it == uniqueFeatureValues.end()){ // NOT FOUND + uniqueFeatureValues.push_back(featureValue); + splitPoints.push_back(i); + } + } + + + + int bestSplitIndex = -1; + if (splitPoints.size() == 0){ + // TODO: trying out C++'s infitinity, don't know if this will work properly + // TODO: check the caller function of this function, there check the value if minEntropy and comapre to inf + // so that no wrong calculation is done + minEntropy = numeric_limits::infinity(); // OUTPUT + intrinsicValue = numeric_limits::infinity(); // OUTPUT + featureSplitValue = -1; // OUTPUT + }else{ + getBestSplitAndMinEntropy(featureOutputPair, splitPoints, minEntropy, bestSplitIndex, intrinsicValue); // OUTPUT + featureSplitValue = featureOutputPair[splitPoints[bestSplitIndex]][0]; // OUTPUT + } + + return 0; + } + catch(exception& e) { + m->errorOut(e, "AbstractDecisionTree", "getMinEntropyOfFeature"); + exit(1); + } +} +/**************************************************************************************************/ +double AbstractDecisionTree::calcIntrinsicValue(int numLessThanValueAtSplitPoint, int numGreaterThanValueAtSplitPoint, int numSamples) { + try { + double upperSplitEntropy = 0.0, lowerSplitEntropy = 0.0; + if (numLessThanValueAtSplitPoint > 0) { + upperSplitEntropy = numLessThanValueAtSplitPoint * log2((double) numLessThanValueAtSplitPoint / (double) numSamples); + } + + if (numGreaterThanValueAtSplitPoint > 0) { + lowerSplitEntropy = numGreaterThanValueAtSplitPoint * log2((double) numGreaterThanValueAtSplitPoint / (double) numSamples); + } + + double intrinsicValue = - ((double)(upperSplitEntropy + lowerSplitEntropy) / (double)numSamples); + return intrinsicValue; + } + catch(exception& e) { + m->errorOut(e, "AbstractDecisionTree", "calcIntrinsicValue"); + exit(1); + } +} +/**************************************************************************************************/ +int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< vector > featureOutputPairs, vector splitPoints, + double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue){ + try { + + int numSamples = (int)featureOutputPairs.size(); + vector entropies; + vector intrinsicValues; + + for (int i = 0; i < splitPoints.size(); i++) { + if (m->control_pressed) { return 0; } + int index = splitPoints[i]; + int valueAtSplitPoint = featureOutputPairs[index][0]; + int numLessThanValueAtSplitPoint = 0; + int numGreaterThanValueAtSplitPoint = 0; + + for (int j = 0; j < featureOutputPairs.size(); j++) { + if (m->control_pressed) { return 0; } + vector record = featureOutputPairs[j]; + if (record[0] < valueAtSplitPoint){ numLessThanValueAtSplitPoint++; } + else{ numGreaterThanValueAtSplitPoint++; } + } + + double upperEntropyOfSplit = calcSplitEntropy(featureOutputPairs, index, numOutputClasses, true); + double lowerEntropyOfSplit = calcSplitEntropy(featureOutputPairs, index, numOutputClasses, false); + + double totalEntropy = (numLessThanValueAtSplitPoint * upperEntropyOfSplit + numGreaterThanValueAtSplitPoint * lowerEntropyOfSplit) / (double)numSamples; + double intrinsicValue = calcIntrinsicValue(numLessThanValueAtSplitPoint, numGreaterThanValueAtSplitPoint, numSamples); + entropies.push_back(totalEntropy); + intrinsicValues.push_back(intrinsicValue); + + } + + // set output values + vector::iterator it = min_element(entropies.begin(), entropies.end()); + minEntropy = *it; // OUTPUT + minEntropyIndex = (int)(it - entropies.begin()); // OUTPUT + relatedIntrinsicValue = intrinsicValues[minEntropyIndex]; // OUTPUT + + return 0; + } + catch(exception& e) { + m->errorOut(e, "AbstractDecisionTree", "getBestSplitAndMinEntropy"); + exit(1); + } +} +/**************************************************************************************************/ + +double AbstractDecisionTree::calcSplitEntropy(vector< vector > featureOutputPairs, int splitIndex, int numOutputClasses, bool isUpperSplit = true) { + try { + vector classCounts(numOutputClasses, 0); + + if (isUpperSplit) { + for (int i = 0; i < splitIndex; i++) { + if (m->control_pressed) { return 0; } + classCounts[featureOutputPairs[i][1]]++; + } + } else { + for (int i = splitIndex; i < featureOutputPairs.size(); i++) { + if (m->control_pressed) { return 0; } + classCounts[featureOutputPairs[i][1]]++; + } + } + + int totalClassCounts = accumulate(classCounts.begin(), classCounts.end(), 0); + + double splitEntropy = 0.0; + + for (int i = 0; i < classCounts.size(); i++) { + if (m->control_pressed) { return 0; } + if (classCounts[i] == 0) { continue; } + double probability = (double) classCounts[i] / (double) totalClassCounts; + splitEntropy += -(probability * log2(probability)); + } + + return splitEntropy; + } + catch(exception& e) { + m->errorOut(e, "AbstractDecisionTree", "calcSplitEntropy"); + exit(1); + } +} + +/**************************************************************************************************/ + +int AbstractDecisionTree::getSplitPopulation(RFTreeNode* node, vector< vector >& leftChildSamples, vector< vector >& rightChildSamples){ + try { + // TODO: there is a possibility of optimization if we can recycle the samples in each nodes + // we just need to pointers to the samples i.e. vector and use it everywhere and not create the sample + // sample over and over again + // we need to make this const so that it is not modified by all the function calling + // currently purgeTreeNodesDataRecursively() is used for the same purpose, but this can be avoided altogher + // if re-using the same data over the classes + + int splitFeatureGlobalIndex = node->getSplitFeatureIndex(); + + for (int i = 0; i < node->getBootstrappedTrainingSamples().size(); i++) { + if (m->control_pressed) { return 0; } + vector sample = node->getBootstrappedTrainingSamples()[i]; + if (m->control_pressed) { return 0; } + if (sample[splitFeatureGlobalIndex] < node->getSplitFeatureValue()){ leftChildSamples.push_back(sample); } + else{ rightChildSamples.push_back(sample); } + } + + return 0; + } + catch(exception& e) { + m->errorOut(e, "AbstractDecisionTree", "getSplitPopulation"); + exit(1); + } +} +/**************************************************************************************************/ +// TODO: checkIfAlreadyClassified() verify code +// TODO: use bootstrappedOutputVector for easier calculation instead of using getBootstrappedTrainingSamples() +bool AbstractDecisionTree::checkIfAlreadyClassified(RFTreeNode* treeNode, int& outputClass) { + try { + + vector tempOutputClasses; + for (int i = 0; i < treeNode->getBootstrappedTrainingSamples().size(); i++) { + if (m->control_pressed) { return 0; } + int sampleOutputClass = treeNode->getBootstrappedTrainingSamples()[i][numFeatures]; + vector::iterator it = find(tempOutputClasses.begin(), tempOutputClasses.end(), sampleOutputClass); + if (it == tempOutputClasses.end()) { // NOT FOUND + tempOutputClasses.push_back(sampleOutputClass); + } + } + + if (tempOutputClasses.size() < 2) { outputClass = tempOutputClasses[0]; return true; } + else { outputClass = -1; return false; } + + } + catch(exception& e) { + m->errorOut(e, "AbstractDecisionTree", "checkIfAlreadyClassified"); + exit(1); + } +} + +/**************************************************************************************************/ diff --git a/abstractdecisiontree.hpp b/abstractdecisiontree.hpp new file mode 100755 index 0000000..3445db4 --- /dev/null +++ b/abstractdecisiontree.hpp @@ -0,0 +1,63 @@ +// +// abstractdecisiontree.hpp +// rrf-fs-prototype +// +// Created by Abu Zaher Faridee on 7/22/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#ifndef rrf_fs_prototype_abstractdecisiontree_hpp +#define rrf_fs_prototype_abstractdecisiontree_hpp + +#include "mothurout.h" +#include "macros.h" +#include "rftreenode.hpp" + +#define DEBUG_MODE + +/**************************************************************************************************/ + +class AbstractDecisionTree{ + +public: + + AbstractDecisionTree(vector >baseDataSet, + vector globalDiscardedFeatureIndices, + OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, + string treeSplitCriterion); + virtual ~AbstractDecisionTree(){} + + +protected: + + virtual int createBootStrappedSamples(); + virtual int getMinEntropyOfFeature(vector featureVector, vector outputVector, double& minEntropy, int& featureSplitValue, double& intrinsicValue); + virtual int getBestSplitAndMinEntropy(vector< vector > featureOutputPairs, vector splitPoints, double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue); + virtual double calcIntrinsicValue(int numLessThanValueAtSplitPoint, int numGreaterThanValueAtSplitPoint, int numSamples); + virtual double calcSplitEntropy(vector< vector > featureOutputPairs, int splitIndex, int numOutputClasses, bool); + virtual int getSplitPopulation(RFTreeNode* node, vector< vector >& leftChildSamples, vector< vector >& rightChildSamples); + virtual bool checkIfAlreadyClassified(RFTreeNode* treeNode, int& outputClass); + + vector< vector > baseDataSet; + int numSamples; + int numFeatures; + int numOutputClasses; + vector outputClasses; + vector< vector > bootstrappedTrainingSamples; + vector bootstrappedTrainingSampleIndices; + vector< vector > bootstrappedTestSamples; + vector bootstrappedTestSampleIndices; + + RFTreeNode* rootNode; + vector globalDiscardedFeatureIndices; + int optimumFeatureSubsetSize; + string treeSplitCriterion; + MothurOut* m; + +private: + + +}; +/**************************************************************************************************/ + +#endif diff --git a/abstractrandomforest.cpp b/abstractrandomforest.cpp new file mode 100644 index 0000000..ae60b77 --- /dev/null +++ b/abstractrandomforest.cpp @@ -0,0 +1,58 @@ +// +// abstractrandomforest.cpp +// Mothur +// +// Created by Sarah Westcott on 10/1/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#include "abstractrandomforest.hpp" + +/***********************************************************************/ +AbstractRandomForest::AbstractRandomForest(const std::vector < std::vector > dataSet, + const int numDecisionTrees, + const string treeSplitCriterion = "informationGain") +: dataSet(dataSet), +numDecisionTrees(numDecisionTrees), +numSamples((int)dataSet.size()), +numFeatures((int)(dataSet[0].size() - 1)), +globalDiscardedFeatureIndices(getGlobalDiscardedFeatureIndices()), +globalVariableImportanceList(numFeatures, 0), +treeSplitCriterion(treeSplitCriterion) { + m = MothurOut::getInstance(); + // TODO: double check if the implemenatation of 'globalOutOfBagEstimates' is correct +} + +/***********************************************************************/ + +vector AbstractRandomForest::getGlobalDiscardedFeatureIndices() { + try { + vector globalDiscardedFeatureIndices; + + // calculate feature vectors + vector< vector > featureVectors(numFeatures, vector(numSamples, 0)); + for (int i = 0; i < numSamples; i++) { + if (m->control_pressed) { return globalDiscardedFeatureIndices; } + for (int j = 0; j < numFeatures; j++) { featureVectors[j][i] = dataSet[i][j]; } + } + + for (int i = 0; i < featureVectors.size(); i++) { + if (m->control_pressed) { return globalDiscardedFeatureIndices; } + double standardDeviation = m->getStandardDeviation(featureVectors[i]); + if (standardDeviation <= 0){ globalDiscardedFeatureIndices.push_back(i); } + } + + if (m->debug) { + m->mothurOut("number of global discarded features: " + toString(globalDiscardedFeatureIndices.size())+ "\n"); + m->mothurOut("total features: " + toString(featureVectors.size())+ "\n"); + } + + return globalDiscardedFeatureIndices; + } + catch(exception& e) { + m->errorOut(e, "AbstractRandomForest", "getGlobalDiscardedFeatureIndices"); + exit(1); + } +} + +/***********************************************************************/ \ No newline at end of file diff --git a/abstractrandomforest.hpp b/abstractrandomforest.hpp new file mode 100755 index 0000000..3be91b9 --- /dev/null +++ b/abstractrandomforest.hpp @@ -0,0 +1,67 @@ +// +// abstractrandomforest.hpp +// rrf-fs-prototype +// +// Created by Abu Zaher Faridee on 7/20/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#ifndef rrf_fs_prototype_abstractrandomforest_hpp +#define rrf_fs_prototype_abstractrandomforest_hpp + +#include "mothurout.h" +#include "macros.h" +#include "abstractdecisiontree.hpp" + +#define DEBUG_MODE + +/***********************************************************************/ + +class AbstractRandomForest{ +public: + // intialization with vectors + AbstractRandomForest(const std::vector < std::vector > dataSet, + const int numDecisionTrees, + const string); + virtual ~AbstractRandomForest(){ } + virtual int populateDecisionTrees() = 0; + virtual int calcForrestErrorRate() = 0; + virtual int calcForrestVariableImportance(string) = 0; + +/***********************************************************************/ + +protected: + + // TODO: create a better way of discarding feature + // currently we just set FEATURE_DISCARD_SD_THRESHOLD to 0 to solved this + // it can be tuned for better selection + // also, there might be other factors like Mean or other stuffs + // same would apply for createLocalDiscardedFeatureList in the TreeNode class + + // TODO: Another idea is getting an aggregated discarded feature indices after the run, from combining + // the local discarded feature indices + // this would penalize a feature, even if in global space the feature looks quite good + // the penalization would be averaged, so this woould unlikely to create a local optmina + + vector getGlobalDiscardedFeatureIndices(); + + int numDecisionTrees; + int numSamples; + int numFeatures; + vector< vector > dataSet; + vector globalDiscardedFeatureIndices; + vector globalVariableImportanceList; + string treeSplitCriterion; + // This is a map of each feature to outcome count of each classes + // e.g. 1 => [2 7] means feature 1 has 2 outcome of 0 and 7 outcome of 1 + map > globalOutOfBagEstimates; + + // TODO: fix this, do we use pointers? + vector decisionTrees; + + MothurOut* m; + +private: + +}; +#endif diff --git a/classifysharedcommand.cpp b/classifysharedcommand.cpp new file mode 100755 index 0000000..f964937 --- /dev/null +++ b/classifysharedcommand.cpp @@ -0,0 +1,364 @@ +// +// classifysharedcommand.cpp +// Mothur +// +// Created by Abu Zaher Md. Faridee on 8/13/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#include "classifysharedcommand.h" +#include "randomforest.hpp" +#include "decisiontree.hpp" +#include "rftreenode.hpp" + +//********************************************************************************************************************** +vector ClassifySharedCommand::setParameters(){ + try { + //CommandParameter pprocessors("processors", "Number", "", "1", "", "", "",false,false); parameters.push_back(pprocessors); + CommandParameter pshared("shared", "InputTypes", "", "", "none", "none", "none",false,true); parameters.push_back(pshared); + CommandParameter pdesign("design", "InputTypes", "", "", "none", "none", "none",false,true); parameters.push_back(pdesign); + CommandParameter potupersplit("otupersplit", "Multiple", "log2-squareroot", "log2", "", "", "",false,false); parameters.push_back(potupersplit); + CommandParameter psplitcriteria("splitcriteria", "Multiple", "gainratio-infogain", "gainratio", "", "", "",false,false); parameters.push_back(psplitcriteria); + CommandParameter pnumtrees("numtrees", "Number", "", "100", "", "", "",false,false); parameters.push_back(pnumtrees); + + CommandParameter pgroups("groups", "String", "", "", "", "", "",false,false); parameters.push_back(pgroups); + CommandParameter plabel("label", "String", "", "", "", "", "",false,false); parameters.push_back(plabel); + CommandParameter pinputdir("inputdir", "String", "", "", "", "", "",false,false); parameters.push_back(pinputdir); + CommandParameter poutputdir("outputdir", "String", "", "", "", "", "",false,false); parameters.push_back(poutputdir); + + vector myArray; + for (int i = 0; i < parameters.size(); i++) { myArray.push_back(parameters[i].name); } + return myArray; + } + catch(exception& e) { + m->errorOut(e, "ClassifySharedCommand", "setParameters"); + exit(1); + } +} +//********************************************************************************************************************** +string ClassifySharedCommand::getHelpString(){ + try { + string helpString = ""; + helpString += "The classify.shared command allows you to ....\n"; + helpString += "The classify.shared command parameters are: shared, design, label, groups, otupersplit.\n"; + helpString += "The label parameter is used to analyze specific labels in your input.\n"; + helpString += "The groups parameter allows you to specify which of the groups in your designfile you would like analyzed.\n"; + helpString += "The classify.shared should be in the following format: \n"; + helpString += "classify.shared(shared=yourSharedFile, design=yourDesignFile)\n"; + return helpString; + } + catch(exception& e) { + m->errorOut(e, "ClassifySharedCommand", "getHelpString"); + exit(1); + } +} +//********************************************************************************************************************** +string ClassifySharedCommand::getOutputFileNameTag(string type, string inputName=""){ + try { + string tag = ""; + map >::iterator it; + + //is this a type this command creates + it = outputTypes.find(type); + if (it == outputTypes.end()) { m->mothurOut("[ERROR]: this command doesn't create a " + type + " output file.\n"); } + else { + if (type == "summary") { tag = "summary"; } + else { m->mothurOut("[ERROR]: No definition for type " + type + " output file tag.\n"); m->control_pressed = true; } + } + return tag; + } + catch(exception& e) { + m->errorOut(e, "ClassifySharedCommand", "getOutputFileName"); + exit(1); + } +} +//********************************************************************************************************************** + +ClassifySharedCommand::ClassifySharedCommand() { + try { + abort = true; calledHelp = true; + setParameters(); + vector tempOutNames; + outputTypes["summary"] = tempOutNames; + } + catch(exception& e) { + m->errorOut(e, "ClassifySharedCommand", "ClassifySharedCommand"); + exit(1); + } +} +//********************************************************************************************************************** +ClassifySharedCommand::ClassifySharedCommand(string option) { + try { + abort = false; calledHelp = false; + allLines = 1; + + //allow user to run help + if(option == "help") { help(); abort = true; calledHelp = true; } + else if(option == "citation") { citation(); abort = true; calledHelp = true;} + + else { + //valid paramters for this command + vector myArray = setParameters(); + + OptionParser parser(option); + map parameters = parser.getParameters(); + + ValidParameters validParameter; + map::iterator it; + //check to make sure all parameters are valid for command + for (it = parameters.begin(); it != parameters.end(); it++) { + if (validParameter.isValidParameter(it->first, myArray, it->second) != true) { abort = true; } + } + + vector tempOutNames; + outputTypes["summary"] = tempOutNames; + + //if the user changes the input directory command factory will send this info to us in the output parameter + string inputDir = validParameter.validFile(parameters, "inputdir", false); + if (inputDir == "not found"){ inputDir = ""; } + else { + string path; + it = parameters.find("shared"); + //user has given a shared file + if(it != parameters.end()){ + path = m->hasPath(it->second); + //if the user has not given a path then, add inputdir. else leave path alone. + if (path == "") { parameters["shared"] = inputDir + it->second; } + } + + it = parameters.find("design"); + //user has given a design file + if(it != parameters.end()){ + path = m->hasPath(it->second); + //if the user has not given a path then, add inputdir. else leave path alone. + if (path == "") { parameters["design"] = inputDir + it->second; } + } + + } + + //check for parameters + //get shared file, it is required + sharedfile = validParameter.validFile(parameters, "shared", true); + if (sharedfile == "not open") { sharedfile = ""; abort = true; } + else if (sharedfile == "not found") { + //if there is a current shared file, use it + sharedfile = m->getSharedFile(); + if (sharedfile != "") { m->mothurOut("Using " + sharedfile + " as input file for the shared parameter."); m->mothurOutEndLine(); } + else { m->mothurOut("You have no current sharedfile and the shared parameter is required."); m->mothurOutEndLine(); abort = true; } + }else { m->setSharedFile(sharedfile); } + + //get design file, it is required + designfile = validParameter.validFile(parameters, "design", true); + if (designfile == "not open") { sharedfile = ""; abort = true; } + else if (designfile == "not found") { + //if there is a current shared file, use it + designfile = m->getDesignFile(); + if (designfile != "") { m->mothurOut("Using " + designfile + " as input file for the design parameter."); m->mothurOutEndLine(); } + else { m->mothurOut("You have no current designfile and the design parameter is required."); m->mothurOutEndLine(); abort = true; } + }else { m->setDesignFile(designfile); } + + + //if the user changes the output directory command factory will send this info to us in the output parameter + outputDir = validParameter.validFile(parameters, "outputdir", false); if (outputDir == "not found"){ + outputDir = m->hasPath(sharedfile); //if user entered a file with a path then preserve it + } + + + // NEW CODE for OTU per split selection criteria + otupersplit = validParameter.validFile(parameters, "otupersplit", false); + if (otupersplit == "not found") { otupersplit = "log2"; } + if ((otupersplit == "squareroot") || (otupersplit == "log2")) { + optimumFeatureSubsetSelectionCriteria = otupersplit; + }else { m->mothurOut("Not a valid OTU per split selection method. Valid OTU per split selection methods are 'log2' and 'squareroot'."); m->mothurOutEndLine(); abort = true; } + + // splitcriteria + splitcriteria = validParameter.validFile(parameters, "splitcriteria", false); + if (splitcriteria == "not found") { splitcriteria = "gainratio"; } + if ((splitcriteria == "gainratio") || (splitcriteria == "infogain")) { + treeSplitCriterion = splitcriteria; + }else { m->mothurOut("Not a valid tree splitting criterio. Valid tree splitting criteria are 'gainratio' and 'infogain'."); m->mothurOutEndLine(); abort = true; } + + + string temp = validParameter.validFile(parameters, "numtrees", false); if (temp == "not found"){ temp = "100"; } + m->mothurConvert(temp, numDecisionTrees); + + //Groups must be checked later to make sure they are valid. SharedUtilities has functions of check the validity, just make to so m->setGroups() after the checks. If you are using these with a shared file no need to check the SharedRAbundVector class will call SharedUtilites for you, kinda nice, huh? + string groups = validParameter.validFile(parameters, "groups", false); + if (groups == "not found") { groups = ""; } + else { m->splitAtDash(groups, Groups); } + m->setGroups(Groups); + + //Commonly used to process list, rabund, sabund, shared and relabund files. Look at "smart distancing" examples below in the execute function. + string label = validParameter.validFile(parameters, "label", false); + if (label == "not found") { label = ""; } + else { + if(label != "all") { m->splitAtDash(label, labels); allLines = 0; } + else { allLines = 1; } + } + } + + } + catch(exception& e) { + m->errorOut(e, "ClassifySharedCommand", "ClassifySharedCommand"); + exit(1); + } +} +//********************************************************************************************************************** +int ClassifySharedCommand::execute() { + try { + + if (abort == true) { if (calledHelp) { return 0; } return 2; } + + InputData input(sharedfile, "sharedfile"); + vector lookup = input.getSharedRAbundVectors(); + + //read design file + designMap.readDesignMap(designfile); + + string lastLabel = lookup[0]->getLabel(); + set processedLabels; + set userLabels = labels; + + //as long as you are not at the end of the file or done wih the lines you want + while((lookup[0] != NULL) && ((allLines == 1) || (userLabels.size() != 0))) { + + if (m->control_pressed) { for (int i = 0; i < lookup.size(); i++) { delete lookup[i]; } return 0; } + + if(allLines == 1 || labels.count(lookup[0]->getLabel()) == 1){ + + m->mothurOut(lookup[0]->getLabel()); m->mothurOutEndLine(); + + processSharedAndDesignData(lookup); + + processedLabels.insert(lookup[0]->getLabel()); + userLabels.erase(lookup[0]->getLabel()); + } + + if ((m->anyLabelsToProcess(lookup[0]->getLabel(), userLabels, "") == true) && (processedLabels.count(lastLabel) != 1)) { + string saveLabel = lookup[0]->getLabel(); + + for (int i = 0; i < lookup.size(); i++) { delete lookup[i]; } + lookup = input.getSharedRAbundVectors(lastLabel); + m->mothurOut(lookup[0]->getLabel()); m->mothurOutEndLine(); + + processSharedAndDesignData(lookup); + + processedLabels.insert(lookup[0]->getLabel()); + userLabels.erase(lookup[0]->getLabel()); + + //restore real lastlabel to save below + lookup[0]->setLabel(saveLabel); + } + + lastLabel = lookup[0]->getLabel(); + //prevent memory leak + for (int i = 0; i < lookup.size(); i++) { delete lookup[i]; lookup[i] = NULL; } + + if (m->control_pressed) { return 0; } + + //get next line to process + lookup = input.getSharedRAbundVectors(); + } + + if (m->control_pressed) { return 0; } + + //output error messages about any remaining user labels + set::iterator it; + bool needToRun = false; + for (it = userLabels.begin(); it != userLabels.end(); it++) { + m->mothurOut("Your file does not include the label " + *it); + if (processedLabels.count(lastLabel) != 1) { + m->mothurOut(". I will use " + lastLabel + "."); m->mothurOutEndLine(); + needToRun = true; + }else { + m->mothurOut(". Please refer to " + lastLabel + "."); m->mothurOutEndLine(); + } + } + + //run last label if you need to + if (needToRun == true) { + for (int i = 0; i < lookup.size(); i++) { if (lookup[i] != NULL) { delete lookup[i]; } } + lookup = input.getSharedRAbundVectors(lastLabel); + + m->mothurOut(lookup[0]->getLabel()); m->mothurOutEndLine(); + + processSharedAndDesignData(lookup); + + for (int i = 0; i < lookup.size(); i++) { delete lookup[i]; } + + } + + m->mothurOutEndLine(); + m->mothurOut("Output File Names: "); m->mothurOutEndLine(); + for (int i = 0; i < outputNames.size(); i++) { m->mothurOut(outputNames[i]); m->mothurOutEndLine(); } + m->mothurOutEndLine(); + + return 0; + + } + catch(exception& e) { + m->errorOut(e, "ClassifySharedCommand", "execute"); + exit(1); + } +} +//********************************************************************************************************************** + +void ClassifySharedCommand::processSharedAndDesignData(vector lookup){ + try { +// for (int i = 0; i < designMap->getNamesOfGroups().size(); i++) { +// string groupName = designMap->getNamesOfGroups()[i]; +// cout << groupName << endl; +// } + +// for (int i = 0; i < designMap->getNumSeqs(); i++) { +// string sharedGroupName = designMap->getNamesSeqs()[i]; +// string treatmentName = designMap->getGroup(sharedGroupName); +// cout << sharedGroupName << " : " << treatmentName << endl; +// } + + map treatmentToIntMap; + map intToTreatmentMap; + for (int i = 0; i < designMap.getNumGroups(); i++) { + string treatmentName = designMap.getNamesOfGroups()[i]; + treatmentToIntMap[treatmentName] = i; + intToTreatmentMap[i] = treatmentName; + } + + int numSamples = lookup.size(); + int numFeatures = lookup[0]->getNumBins(); + + int numRows = numSamples; + int numColumns = numFeatures + 1; // extra one space needed for the treatment/outcome + + vector< vector > dataSet(numRows, vector(numColumns, 0)); + + for (int i = 0; i < lookup.size(); i++) { + string sharedGroupName = lookup[i]->getGroup(); + string treatmentName = designMap.getGroup(sharedGroupName); + + int j = 0; + for (; j < lookup[i]->getNumBins(); j++) { + int otuCount = lookup[i]->getAbundance(j); + dataSet[i][j] = otuCount; + } + dataSet[i][j] = treatmentToIntMap[treatmentName]; + } + + RandomForest randomForest(dataSet, numDecisionTrees, treeSplitCriterion); + randomForest.populateDecisionTrees(); + randomForest.calcForrestErrorRate(); + + string filename = outputDir + m->getRootName(m->getSimpleName(sharedfile)) + lookup[0]->getLabel() + "." + getOutputFileNameTag("summary"); + outputNames.push_back(filename); outputTypes["summary"].push_back(filename); + + randomForest.calcForrestVariableImportance(filename); + + m->mothurOutEndLine(); + } + catch(exception& e) { + m->errorOut(e, "ClassifySharedCommand", "processSharedAndDesignData"); + exit(1); + } +} +//********************************************************************************************************************** + diff --git a/classifysharedcommand.h b/classifysharedcommand.h new file mode 100755 index 0000000..93c6286 --- /dev/null +++ b/classifysharedcommand.h @@ -0,0 +1,54 @@ +// +// classifysharedcommand.h +// Mothur +// +// Created by Abu Zaher Md. Faridee on 8/13/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#ifndef __Mothur__classifysharedcommand__ +#define __Mothur__classifysharedcommand__ + +#include "command.hpp" +#include "inputdata.h" + +class ClassifySharedCommand : public Command { +public: + ClassifySharedCommand(); + ClassifySharedCommand(string); + ~ClassifySharedCommand() {}; + + vector setParameters(); + string getCommandName() { return "classify.shared"; } + string getCommandCategory() { return "OTU-Based Approaches"; } + string getOutputFileNameTag(string, string); + string getHelpString(); + string getCitation() { return "http://www.mothur.org/wiki/Classify.shared\n"; } + string getDescription() { return "description"; } + int execute(); + + void help() { m->mothurOut(getHelpString()); } + +private: + bool abort; + string outputDir; + vector outputNames, Groups; + + string sharedfile, designfile, otupersplit, splitcriteria; + set labels; + bool allLines; + + int processors; + bool useTiming; + + GroupMap designMap; + + int numDecisionTrees; + string treeSplitCriterion, optimumFeatureSubsetSelectionCriteria; + bool doPruning, discardHighErrorTrees; + double pruneAggressiveness, highErrorTreeDiscardThreshold, featureStandardDeviationThreshold; + + void processSharedAndDesignData(vector lookup); +}; + +#endif /* defined(__Mothur__classifysharedcommand__) */ diff --git a/commandfactory.cpp b/commandfactory.cpp index 8653643..6d87a68 100644 --- a/commandfactory.cpp +++ b/commandfactory.cpp @@ -135,6 +135,7 @@ #include "makecontigscommand.h" #include "loadlogfilecommand.h" #include "sffmultiplecommand.h" +#include "classifysharedcommand.h" /*******************************************************/ @@ -293,6 +294,8 @@ CommandFactory::CommandFactory(){ commands["make.table"] = "make.table"; commands["sff.multiple"] = "sff.multiple"; commands["quit"] = "MPIEnabled"; + commands["classify.shared"] = "classify.shared"; + } /***********************************************************/ @@ -506,6 +509,7 @@ Command* CommandFactory::getCommand(string commandName, string optionString){ else if(commandName == "make.contigs") { command = new MakeContigsCommand(optionString); } else if(commandName == "load.logfile") { command = new LoadLogfileCommand(optionString); } else if(commandName == "sff.multiple") { command = new SffMultipleCommand(optionString); } + else if(commandName == "classify.shared") { command = new ClassifySharedCommand(optionString); } else { command = new NoCommand(optionString); } return command; @@ -661,6 +665,7 @@ Command* CommandFactory::getCommand(string commandName, string optionString, str else if(commandName == "make.contigs") { pipecommand = new MakeContigsCommand(optionString); } else if(commandName == "load.logfile") { pipecommand = new LoadLogfileCommand(optionString); } else if(commandName == "sff.multiple") { pipecommand = new SffMultipleCommand(optionString); } + else if(commandName == "classify.shared") { pipecommand = new ClassifySharedCommand(optionString); } else { pipecommand = new NoCommand(optionString); } return pipecommand; @@ -802,6 +807,7 @@ Command* CommandFactory::getCommand(string commandName){ else if(commandName == "make.contigs") { shellcommand = new MakeContigsCommand(); } else if(commandName == "load.logfile") { shellcommand = new LoadLogfileCommand(); } else if(commandName == "sff.multiple") { shellcommand = new SffMultipleCommand(); } + else if(commandName == "classify.shared") { shellcommand = new ClassifySharedCommand(); } else { shellcommand = new NoCommand(); } return shellcommand; diff --git a/countgroupscommand.cpp b/countgroupscommand.cpp index ccf8988..716dc90 100644 --- a/countgroupscommand.cpp +++ b/countgroupscommand.cpp @@ -16,6 +16,7 @@ vector CountGroupsCommand::setParameters(){ try { CommandParameter pshared("shared", "InputTypes", "", "", "sharedGroup", "sharedGroup", "none",false,false); parameters.push_back(pshared); CommandParameter pgroup("group", "InputTypes", "", "", "sharedGroup", "sharedGroup", "none",false,false); parameters.push_back(pgroup); + CommandParameter pcount("count", "InputTypes", "", "", "sharedGroup", "sharedGroup", "none",false,false); parameters.push_back(pcount); CommandParameter paccnos("accnos", "InputTypes", "", "", "none", "none", "none",false,false); parameters.push_back(paccnos); CommandParameter pgroups("groups", "String", "", "", "", "", "",false,false); parameters.push_back(pgroups); CommandParameter pinputdir("inputdir", "String", "", "", "", "", "",false,false); parameters.push_back(pinputdir); @@ -34,7 +35,7 @@ vector CountGroupsCommand::setParameters(){ string CountGroupsCommand::getHelpString(){ try { string helpString = ""; - helpString += "The count.groups command counts sequences from a specific group or set of groups from the following file types: group or shared file.\n"; + helpString += "The count.groups command counts sequences from a specific group or set of groups from the following file types: group, count or shared file.\n"; helpString += "The count.groups command parameters are accnos, group, shared and groups. You must provide a group or shared file.\n"; helpString += "The accnos parameter allows you to provide a file containing the list of groups.\n"; helpString += "The groups parameter allows you to specify which of the groups in your groupfile you would like. You can separate group names with dashes.\n"; @@ -114,6 +115,14 @@ CountGroupsCommand::CountGroupsCommand(string option) { //if the user has not given a path then, add inputdir. else leave path alone. if (path == "") { parameters["shared"] = inputDir + it->second; } } + + it = parameters.find("count"); + //user has given a template file + if(it != parameters.end()){ + path = m->hasPath(it->second); + //if the user has not given a path then, add inputdir. else leave path alone. + if (path == "") { parameters["count"] = inputDir + it->second; } + } } @@ -138,9 +147,23 @@ CountGroupsCommand::CountGroupsCommand(string option) { groupfile = validParameter.validFile(parameters, "group", true); if (groupfile == "not open") { groupfile = ""; abort = true; } else if (groupfile == "not found") { groupfile = ""; } - else { m->setGroupFile(groupfile); } + else { m->setGroupFile(groupfile); } + + countfile = validParameter.validFile(parameters, "count", true); + if (countfile == "not open") { countfile = ""; abort = true; } + else if (countfile == "not found") { countfile = ""; } + else { + m->setCountTableFile(countfile); + CountTable ct; + if (!ct.testGroups(countfile)) { m->mothurOut("[ERROR]: Your count file does not have any group information, aborting."); m->mothurOutEndLine(); abort=true; } + } + + if ((groupfile != "") && (countfile != "")) { + m->mothurOut("[ERROR]: you may only use one of the following: group or count."); m->mothurOutEndLine(); abort=true; + } + - if ((sharedfile == "") && (groupfile == "")) { + if ((sharedfile == "") && (groupfile == "") && (countfile == "")) { //give priority to shared, then group sharedfile = m->getSharedFile(); if (sharedfile != "") { m->mothurOut("Using " + sharedfile + " as input file for the shared parameter."); m->mothurOutEndLine(); } @@ -148,7 +171,11 @@ CountGroupsCommand::CountGroupsCommand(string option) { groupfile = m->getGroupFile(); if (groupfile != "") { m->mothurOut("Using " + groupfile + " as input file for the group parameter."); m->mothurOutEndLine(); } else { - m->mothurOut("You have no current groupfile or sharedfile and one is required."); m->mothurOutEndLine(); abort = true; + countfile = m->getCountTableFile(); + if (countfile != "") { m->mothurOut("Using " + countfile + " as input file for the count parameter."); m->mothurOutEndLine(); } + else { + m->mothurOut("You have no current groupfile, countfile or sharedfile and one is required."); m->mothurOutEndLine(); abort = true; + } } } } @@ -182,9 +209,36 @@ int CountGroupsCommand::execute(){ vector nameGroups = groupMap.getNamesOfGroups(); util.setGroups(Groups, nameGroups); + int total = 0; + for (int i = 0; i < Groups.size(); i++) { + int num = groupMap.getNumSeqs(Groups[i]); + total += num; + m->mothurOut(Groups[i] + " contains " + toString(num) + "."); m->mothurOutEndLine(); + } + + m->mothurOut("\nTotal seqs: " + toString(total) + "."); m->mothurOutEndLine(); + } + + if (m->control_pressed) { return 0; } + + if (countfile != "") { + CountTable ct; + ct.readTable(countfile); + + //make sure groups are valid + //takes care of user setting groupNames that are invalid or setting groups=all + SharedUtil util; + vector nameGroups = ct.getNamesOfGroups(); + util.setGroups(Groups, nameGroups); + + int total = 0; for (int i = 0; i < Groups.size(); i++) { - m->mothurOut(Groups[i] + " contains " + toString(groupMap.getNumSeqs(Groups[i])) + "."); m->mothurOutEndLine(); + int num = ct.getGroupCount(Groups[i]); + total += num; + m->mothurOut(Groups[i] + " contains " + toString(num) + "."); m->mothurOutEndLine(); } + + m->mothurOut("\nTotal seqs: " + toString(total) + "."); m->mothurOutEndLine(); } if (m->control_pressed) { return 0; } @@ -193,10 +247,15 @@ int CountGroupsCommand::execute(){ InputData input(sharedfile, "sharedfile"); vector lookup = input.getSharedRAbundVectors(); + int total = 0; for (int i = 0; i < lookup.size(); i++) { - m->mothurOut(lookup[i]->getGroup() + " contains " + toString(lookup[i]->getNumSeqs()) + "."); m->mothurOutEndLine(); + int num = lookup[i]->getNumSeqs(); + total += num; + m->mothurOut(lookup[i]->getGroup() + " contains " + toString(num) + "."); m->mothurOutEndLine(); delete lookup[i]; - } + } + + m->mothurOut("\nTotal seqs: " + toString(total) + "."); m->mothurOutEndLine(); } return 0; diff --git a/countgroupscommand.h b/countgroupscommand.h index dd0e0a2..d27a7f8 100644 --- a/countgroupscommand.h +++ b/countgroupscommand.h @@ -33,7 +33,7 @@ public: private: - string sharedfile, groupfile, outputDir, groups, accnosfile; + string sharedfile, groupfile, countfile, outputDir, groups, accnosfile; bool abort; vector Groups; }; diff --git a/counttable.cpp b/counttable.cpp index a79047d..2ab0e34 100644 --- a/counttable.cpp +++ b/counttable.cpp @@ -481,6 +481,10 @@ int CountTable::addGroup(string groupName) { int CountTable::removeGroup(string groupName) { try { if (hasGroups) { + //save for later in case removing a group means we need to remove a seq. + map reverse; + for (map::iterator it = indexNameMap.begin(); it !=indexNameMap.end(); it++) { reverse[it->second] = it->first; } + map::iterator it = indexGroupMap.find(groupName); if (it == indexGroupMap.end()) { m->mothurOut("[ERROR]: " + groupName + " is not in your count table. Please correct.\n"); m->control_pressed = true; @@ -491,13 +495,15 @@ int CountTable::removeGroup(string groupName) { for (int i = 0; i < groups.size(); i++) { if (groups[i] != groupName) { newGroups.push_back(groups[i]); - indexGroupMap[groups[i]] = i; + indexGroupMap[groups[i]] = newGroups.size()-1; } } indexGroupMap.erase(groupName); groups = newGroups; totalGroups.erase(totalGroups.begin()+indexOfGroupToRemove); - + + int thisIndex = 0; + map newIndexNameMap; for (int i = 0; i < counts.size(); i++) { int num = counts[i][indexOfGroupToRemove]; counts[i].erase(counts[i].begin()+indexOfGroupToRemove); @@ -509,7 +515,11 @@ int CountTable::removeGroup(string groupName) { uniques--; i--; } + newIndexNameMap[reverse[thisIndex]] = i; + thisIndex++; } + indexNameMap = newIndexNameMap; + if (groups.size() == 0) { hasGroups = false; } } }else { m->mothurOut("[ERROR]: your count table does not contain group information, can not remove group " + groupName + ".\n"); m->control_pressed = true; } diff --git a/decisiontree.cpp b/decisiontree.cpp new file mode 100644 index 0000000..99853f3 --- /dev/null +++ b/decisiontree.cpp @@ -0,0 +1,399 @@ +// +// decisiontree.cpp +// Mothur +// +// Created by Sarah Westcott on 10/1/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#include "decisiontree.hpp" + +DecisionTree::DecisionTree(vector< vector > baseDataSet, + vector globalDiscardedFeatureIndices, + OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, + string treeSplitCriterion) : AbstractDecisionTree(baseDataSet, + globalDiscardedFeatureIndices, + optimumFeatureSubsetSelector, + treeSplitCriterion), variableImportanceList(numFeatures, 0){ + try { + m = MothurOut::getInstance(); + createBootStrappedSamples(); + buildDecisionTree(); + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "DecisionTree"); + exit(1); + } +} + +/***********************************************************************/ + +int DecisionTree::calcTreeVariableImportanceAndError() { + try { + + int numCorrect; + double treeErrorRate; + calcTreeErrorRate(numCorrect, treeErrorRate); + + if (m->control_pressed) {return 0; } + + for (int i = 0; i < numFeatures; i++) { + if (m->control_pressed) {return 0; } + // NOTE: only shuffle the features, never shuffle the output vector + // so i = 0 and i will be alwaays <= (numFeatures - 1) as the index at numFeatures will denote + // the feature vector + vector< vector > randomlySampledTestData = randomlyShuffleAttribute(bootstrappedTestSamples, i); + + int numCorrectAfterShuffle = 0; + for (int j = 0; j < randomlySampledTestData.size(); j++) { + if (m->control_pressed) {return 0; } + vector shuffledSample = randomlySampledTestData[j]; + int actualSampleOutputClass = shuffledSample[numFeatures]; + int predictedSampleOutputClass = evaluateSample(shuffledSample); + if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrectAfterShuffle++; } + } + variableImportanceList[i] += (numCorrect - numCorrectAfterShuffle); + } + + // TODO: do we need to save the variableRanks in the DecisionTree, do we need it later? + vector< vector > variableRanks; + for (int i = 0; i < variableImportanceList.size(); i++) { + if (m->control_pressed) {return 0; } + if (variableImportanceList[i] > 0) { + // TODO: is there a way to optimize the follow line's code? + vector variableRank(2, 0); + variableRank[0] = i; variableRank[1] = variableImportanceList[i]; + variableRanks.push_back(variableRank); + } + } + VariableRankDescendingSorter variableRankDescendingSorter; + sort(variableRanks.begin(), variableRanks.end(), variableRankDescendingSorter); + + return 0; + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "calcTreeVariableImportanceAndError"); + exit(1); + } + +} +/***********************************************************************/ + +// TODO: there must be a way to optimize this function +int DecisionTree::evaluateSample(vector testSample) { + try { + RFTreeNode *node = rootNode; + while (true) { + if (m->control_pressed) {return 0; } + if (node->checkIsLeaf()) { return node->getOutputClass(); } + int sampleSplitFeatureValue = testSample[node->getSplitFeatureIndex()]; + if (sampleSplitFeatureValue < node->getSplitFeatureValue()) { node = node->getLeftChildNode(); } + else { node = node->getRightChildNode(); } + } + return 0; + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "evaluateSample"); + exit(1); + } + +} +/***********************************************************************/ + +int DecisionTree::calcTreeErrorRate(int& numCorrect, double& treeErrorRate){ + try { + numCorrect = 0; + for (int i = 0; i < bootstrappedTestSamples.size(); i++) { + if (m->control_pressed) {return 0; } + + vector testSample = bootstrappedTestSamples[i]; + int testSampleIndex = bootstrappedTestSampleIndices[i]; + + int actualSampleOutputClass = testSample[numFeatures]; + int predictedSampleOutputClass = evaluateSample(testSample); + + if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrect++; } + + outOfBagEstimates[testSampleIndex] = predictedSampleOutputClass; + } + + treeErrorRate = 1 - ((double)numCorrect / (double)bootstrappedTestSamples.size()); + + return 0; + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "calcTreeErrorRate"); + exit(1); + } +} + +/***********************************************************************/ + +// TODO: optimize the algo, instead of transposing two time, we can extarct the feature, +// shuffle it and then re-insert in the original place, thus iproving runnting time +//This function randomize abundances for a given OTU/feature. +vector< vector > DecisionTree::randomlyShuffleAttribute(vector< vector > samples, int featureIndex) { + try { + // NOTE: we need (numFeatures + 1) featureVecotors, the last extra vector is actually outputVector + vector< vector > shuffledSample = samples; + vector featureVectors(samples.size(), 0); + + for (int j = 0; j < samples.size(); j++) { + if (m->control_pressed) { return shuffledSample; } + featureVectors[j] = samples[j][featureIndex]; + } + + random_shuffle(featureVectors.begin(), featureVectors.end()); + + for (int j = 0; j < samples.size(); j++) { + if (m->control_pressed) {return shuffledSample; } + shuffledSample[j][featureIndex] = featureVectors[j]; + } + + return shuffledSample; + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "randomlyShuffleAttribute"); + exit(1); + } +} +/***********************************************************************/ + +int DecisionTree::purgeTreeNodesDataRecursively(RFTreeNode* treeNode) { + try { + treeNode->bootstrappedTrainingSamples.clear(); + treeNode->bootstrappedFeatureVectors.clear(); + treeNode->bootstrappedOutputVector.clear(); + treeNode->localDiscardedFeatureIndices.clear(); + treeNode->globalDiscardedFeatureIndices.clear(); + + if (treeNode->leftChildNode != NULL) { purgeTreeNodesDataRecursively(treeNode->leftChildNode); } + if (treeNode->rightChildNode != NULL) { purgeTreeNodesDataRecursively(treeNode->rightChildNode); } + return 0; + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "purgeTreeNodesDataRecursively"); + exit(1); + } +} +/***********************************************************************/ + +void DecisionTree::buildDecisionTree(){ + try { + + int generation = 0; + rootNode = new RFTreeNode(bootstrappedTrainingSamples, globalDiscardedFeatureIndices, numFeatures, numSamples, numOutputClasses, generation); + + splitRecursively(rootNode); + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "buildDecisionTree"); + exit(1); + } +} + +/***********************************************************************/ + +int DecisionTree::splitRecursively(RFTreeNode* rootNode) { + try { + + if (rootNode->getNumSamples() < 2){ + rootNode->setIsLeaf(true); + rootNode->setOutputClass(rootNode->getBootstrappedTrainingSamples()[0][rootNode->getNumFeatures()]); + return 0; + } + + int classifiedOutputClass; + bool isAlreadyClassified = checkIfAlreadyClassified(rootNode, classifiedOutputClass); + if (isAlreadyClassified == true){ + rootNode->setIsLeaf(true); + rootNode->setOutputClass(classifiedOutputClass); + return 0; + } + if (m->control_pressed) {return 0;} + vector featureSubsetIndices = selectFeatureSubsetRandomly(globalDiscardedFeatureIndices, rootNode->getLocalDiscardedFeatureIndices()); + rootNode->setFeatureSubsetIndices(featureSubsetIndices); + if (m->control_pressed) {return 0;} + + findAndUpdateBestFeatureToSplitOn(rootNode); + + if (m->control_pressed) {return 0;} + + vector< vector > leftChildSamples; + vector< vector > rightChildSamples; + getSplitPopulation(rootNode, leftChildSamples, rightChildSamples); + + if (m->control_pressed) {return 0;} + + // TODO: need to write code to clear this memory + RFTreeNode* leftChildNode = new RFTreeNode(leftChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)leftChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1); + RFTreeNode* rightChildNode = new RFTreeNode(rightChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)rightChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1); + + rootNode->setLeftChildNode(leftChildNode); + leftChildNode->setParentNode(rootNode); + + rootNode->setRightChildNode(rightChildNode); + rightChildNode->setParentNode(rootNode); + + // TODO: This recursive split can be parrallelized later + splitRecursively(leftChildNode); + if (m->control_pressed) {return 0;} + + splitRecursively(rightChildNode); + return 0; + + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "splitRecursively"); + exit(1); + } +} +/***********************************************************************/ + +int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){ + try { + + vector< vector > bootstrappedFeatureVectors = node->getBootstrappedFeatureVectors(); + if (m->control_pressed) {return 0;} + vector bootstrappedOutputVector = node->getBootstrappedOutputVector(); + if (m->control_pressed) {return 0;} + vector featureSubsetIndices = node->getFeatureSubsetIndices(); + if (m->control_pressed) {return 0;} + + vector featureSubsetEntropies; + vector featureSubsetSplitValues; + vector featureSubsetIntrinsicValues; + vector featureSubsetGainRatios; + + for (int i = 0; i < featureSubsetIndices.size(); i++) { + if (m->control_pressed) {return 0;} + + int tryIndex = featureSubsetIndices[i]; + + double featureMinEntropy; + int featureSplitValue; + double featureIntrinsicValue; + + getMinEntropyOfFeature(bootstrappedFeatureVectors[tryIndex], bootstrappedOutputVector, featureMinEntropy, featureSplitValue, featureIntrinsicValue); + if (m->control_pressed) {return 0;} + + featureSubsetEntropies.push_back(featureMinEntropy); + featureSubsetSplitValues.push_back(featureSplitValue); + featureSubsetIntrinsicValues.push_back(featureIntrinsicValue); + + double featureInformationGain = node->getOwnEntropy() - featureMinEntropy; + double featureGainRatio = (double)featureInformationGain / (double)featureIntrinsicValue; + featureSubsetGainRatios.push_back(featureGainRatio); + + } + + vector::iterator minEntropyIterator = min_element(featureSubsetEntropies.begin(), featureSubsetEntropies.end()); + vector::iterator maxGainRatioIterator = max_element(featureSubsetGainRatios.begin(), featureSubsetGainRatios.end()); + double featureMinEntropy = *minEntropyIterator; + //double featureMaxGainRatio = *maxGainRatioIterator; + + double bestFeatureSplitEntropy = featureMinEntropy; + int bestFeatureToSplitOnIndex = -1; + if (treeSplitCriterion == "gainRatio"){ + bestFeatureToSplitOnIndex = (int)(maxGainRatioIterator - featureSubsetGainRatios.begin()); + // if using 'gainRatio' measure, then featureMinEntropy must be re-updated, as the index + // for 'featureMaxGainRatio' would be different + bestFeatureSplitEntropy = featureSubsetEntropies[bestFeatureToSplitOnIndex]; + } + else { bestFeatureToSplitOnIndex = (int)(minEntropyIterator - featureSubsetEntropies.begin()); } + + int bestFeatureSplitValue = featureSubsetSplitValues[bestFeatureToSplitOnIndex]; + + node->setSplitFeatureIndex(featureSubsetIndices[bestFeatureToSplitOnIndex]); + node->setSplitFeatureValue(bestFeatureSplitValue); + node->setSplitFeatureEntropy(bestFeatureSplitEntropy); + + return 0; + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "findAndUpdateBestFeatureToSplitOn"); + exit(1); + } +} +/***********************************************************************/ +vector DecisionTree::selectFeatureSubsetRandomly(vector globalDiscardedFeatureIndices, vector localDiscardedFeatureIndices){ + try { + + vector featureSubsetIndices; + + vector combinedDiscardedFeatureIndices; + combinedDiscardedFeatureIndices.insert(combinedDiscardedFeatureIndices.end(), globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end()); + combinedDiscardedFeatureIndices.insert(combinedDiscardedFeatureIndices.end(), localDiscardedFeatureIndices.begin(), localDiscardedFeatureIndices.end()); + + sort(combinedDiscardedFeatureIndices.begin(), combinedDiscardedFeatureIndices.end()); + + int numberOfRemainingSuitableFeatures = (int)(numFeatures - combinedDiscardedFeatureIndices.size()); + int currentFeatureSubsetSize = numberOfRemainingSuitableFeatures < optimumFeatureSubsetSize ? numberOfRemainingSuitableFeatures : optimumFeatureSubsetSize; + + while (featureSubsetIndices.size() < currentFeatureSubsetSize) { + + if (m->control_pressed) { return featureSubsetIndices; } + + // TODO: optimize rand() call here + int randomIndex = rand() % numFeatures; + vector::iterator it = find(featureSubsetIndices.begin(), featureSubsetIndices.end(), randomIndex); + if (it == featureSubsetIndices.end()){ // NOT FOUND + vector::iterator it2 = find(combinedDiscardedFeatureIndices.begin(), combinedDiscardedFeatureIndices.end(), randomIndex); + if (it2 == combinedDiscardedFeatureIndices.end()){ // NOT FOUND AGAIN + featureSubsetIndices.push_back(randomIndex); + } + } + } + sort(featureSubsetIndices.begin(), featureSubsetIndices.end()); + + //#ifdef DEBUG_LEVEL_3 + // PRINT_VAR(featureSubsetIndices); + //#endif + + return featureSubsetIndices; + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "selectFeatureSubsetRandomly"); + exit(1); + } +} +/***********************************************************************/ + +// TODO: printTree() needs a check if correct +int DecisionTree::printTree(RFTreeNode* treeNode, string caption){ + try { + string tabs = ""; + for (int i = 0; i < treeNode->getGeneration(); i++) { tabs += " "; } + // for (int i = 0; i < treeNode->getGeneration() - 1; i++) { tabs += "| "; } + // if (treeNode->getGeneration() != 0) { tabs += "|--"; } + + if (treeNode != NULL && treeNode->checkIsLeaf() == false){ + m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " ] ( " + toString(treeNode->getSplitFeatureValue()) + " < X" + toString(treeNode->getSplitFeatureIndex()) +" )\n"); + + printTree(treeNode->getLeftChildNode(), "leftChild"); + printTree(treeNode->getRightChildNode(), "rightChild"); + }else { + m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " ] ( classified to: " + toString(treeNode->getOutputClass()) + ", samples: " + toString(treeNode->getNumSamples()) + " )\n"); + } + return 0; + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "printTree"); + exit(1); + } +} +/***********************************************************************/ +void DecisionTree::deleteTreeNodesRecursively(RFTreeNode* treeNode) { + try { + if (treeNode == NULL) { return; } + deleteTreeNodesRecursively(treeNode->leftChildNode); + deleteTreeNodesRecursively(treeNode->rightChildNode); + delete treeNode; + } + catch(exception& e) { + m->errorOut(e, "DecisionTree", "deleteTreeNodesRecursively"); + exit(1); + } +} +/***********************************************************************/ + diff --git a/decisiontree.hpp b/decisiontree.hpp new file mode 100755 index 0000000..d4441ed --- /dev/null +++ b/decisiontree.hpp @@ -0,0 +1,59 @@ + // + // decisiontree.hpp + // rrf-fs-prototype + // + // Created by Abu Zaher Faridee on 5/28/12. + // Copyright (c) 2012 Schloss Lab. All rights reserved. + // + +#ifndef rrf_fs_prototype_decisiontree_hpp +#define rrf_fs_prototype_decisiontree_hpp + +#include "macros.h" +#include "rftreenode.hpp" +#include "abstractdecisiontree.hpp" + +/***********************************************************************/ + +struct VariableRankDescendingSorter { + bool operator() (vector first, vector second){ return first[1] > second[1]; } +}; +struct VariableRankDescendingSorterDouble { + bool operator() (vector first, vector second){ return first[1] > second[1]; } +}; +/***********************************************************************/ + +class DecisionTree: public AbstractDecisionTree{ + + friend class RandomForest; + +public: + + DecisionTree(vector< vector > baseDataSet, + vector globalDiscardedFeatureIndices, + OptimumFeatureSubsetSelector optimumFeatureSubsetSelector, + string treeSplitCriterion); + virtual ~DecisionTree(){ deleteTreeNodesRecursively(rootNode); } + + int calcTreeVariableImportanceAndError(); + int evaluateSample(vector testSample); + int calcTreeErrorRate(int& numCorrect, double& treeErrorRate); + vector< vector > randomlyShuffleAttribute(vector< vector > samples, int featureIndex); + void purgeDataSetsFromTree() { purgeTreeNodesDataRecursively(rootNode); } + int purgeTreeNodesDataRecursively(RFTreeNode* treeNode); + + +private: + + void buildDecisionTree(); + int splitRecursively(RFTreeNode* rootNode); + int findAndUpdateBestFeatureToSplitOn(RFTreeNode* node); + vector selectFeatureSubsetRandomly(vector globalDiscardedFeatureIndices, vector localDiscardedFeatureIndices); + int printTree(RFTreeNode* treeNode, string caption); + void deleteTreeNodesRecursively(RFTreeNode* treeNode); + + vector variableImportanceList; + map outOfBagEstimates; +}; + +#endif diff --git a/groupmap.cpp b/groupmap.cpp index 7ce9073..fb2495c 100644 --- a/groupmap.cpp +++ b/groupmap.cpp @@ -162,6 +162,79 @@ int GroupMap::readDesignMap() { } } /************************************************************/ +int GroupMap::readMap(string filename) { + try { + groupFileName = filename; + m->openInputFile(filename, fileHandle); + index = 0; + string seqName, seqGroup; + int error = 0; + string rest = ""; + char buffer[4096]; + bool pairDone = false; + bool columnOne = true; + + while (!fileHandle.eof()) { + if (m->control_pressed) { fileHandle.close(); return 1; } + + fileHandle.read(buffer, 4096); + vector pieces = m->splitWhiteSpace(rest, buffer, fileHandle.gcount()); + + for (int i = 0; i < pieces.size(); i++) { + if (columnOne) { seqName = pieces[i]; columnOne=false; } + else { seqGroup = pieces[i]; pairDone = true; columnOne=true; } + + if (pairDone) { + setNamesOfGroups(seqGroup); + + if (m->debug) { m->mothurOut("[DEBUG]: name = '" + seqName + "', group = '" + seqGroup + "'\n"); } + + it = groupmap.find(seqName); + + if (it != groupmap.end()) { error = 1; m->mothurOut("Your group file contains more than 1 sequence named " + seqName + ", sequence names must be unique. Please correct."); m->mothurOutEndLine(); } + else { + groupmap[seqName] = seqGroup; //store data in map + seqsPerGroup[seqGroup]++; //increment number of seqs in that group + } + pairDone = false; + } + } + } + fileHandle.close(); + + if (rest != "") { + vector pieces = m->splitWhiteSpace(rest); + + for (int i = 0; i < pieces.size(); i++) { + if (columnOne) { seqName = pieces[i]; columnOne=false; } + else { seqGroup = pieces[i]; pairDone = true; columnOne=true; } + + if (pairDone) { + setNamesOfGroups(seqGroup); + + if (m->debug) { m->mothurOut("[DEBUG]: name = '" + seqName + "', group = '" + seqGroup + "'\n"); } + + it = groupmap.find(seqName); + + if (it != groupmap.end()) { error = 1; m->mothurOut("Your group file contains more than 1 sequence named " + seqName + ", sequence names must be unique. Please correct."); m->mothurOutEndLine(); } + else { + groupmap[seqName] = seqGroup; //store data in map + seqsPerGroup[seqGroup]++; //increment number of seqs in that group + } + pairDone = false; + } + } + } + + m->setAllGroups(namesOfGroups); + return error; + } + catch(exception& e) { + m->errorOut(e, "GroupMap", "readMap"); + exit(1); + } +} +/************************************************************/ int GroupMap::readDesignMap(string filename) { try { groupFileName = filename; diff --git a/groupmap.h b/groupmap.h index 567165d..d698495 100644 --- a/groupmap.h +++ b/groupmap.h @@ -21,6 +21,7 @@ public: GroupMap(string); ~GroupMap(); int readMap(); + int readMap(string); int readDesignMap(); int readDesignMap(string); int getNumGroups(); diff --git a/macros.h b/macros.h new file mode 100755 index 0000000..f95acbe --- /dev/null +++ b/macros.h @@ -0,0 +1,32 @@ +// +// macros.h +// rrf-fs-prototype +// +// Created by Abu Zaher Faridee on 5/28/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#ifndef rrf_fs_prototype_macros_h +#define rrf_fs_prototype_macros_h + +#include "mothurout.h" + +/***********************************************************************/ +class OptimumFeatureSubsetSelector{ +public: + OptimumFeatureSubsetSelector(string selectionType = "log2"): selectionType(selectionType){ + } + + int getOptimumFeatureSubsetSize(int numFeatures){ + + if (selectionType == "log2"){ return (int)ceil(log2(numFeatures)); } + else if (selectionType == "squareRoot"){ return (int)ceil(sqrt(numFeatures)); } + return -1; + } +private: + string selectionType; +}; + +/***********************************************************************/ + +#endif diff --git a/mothurout.cpp b/mothurout.cpp index c5eb0dc..7d40e80 100644 --- a/mothurout.cpp +++ b/mothurout.cpp @@ -2912,8 +2912,31 @@ string MothurOut::removeQuotes(string tax) { } } /**************************************************************************************************/ - - +// function for calculating standard deviation +double MothurOut::getStandardDeviation(vector& featureVector){ + try { + //finds sum + double average = 0; + for (int i = 0; i < featureVector.size(); i++) { average += featureVector[i]; } + average /= (double) featureVector.size(); + + //find standard deviation + double stdDev = 0; + for (int i = 0; i < featureVector.size(); i++) { //compute the difference of each dist from the mean, and square the result of each + stdDev += ((featureVector[i] - average) * (featureVector[i] - average)); + } + + stdDev /= (double) featureVector.size(); + stdDev = sqrt(stdDev); + + return stdDev; + } + catch(exception& e) { + errorOut(e, "MothurOut", "getStandardDeviation"); + exit(1); + } +} +/**************************************************************************************************/ diff --git a/mothurout.h b/mothurout.h index 0c2e448..53d4250 100644 --- a/mothurout.h +++ b/mothurout.h @@ -151,6 +151,7 @@ class MothurOut { float roundDist(float, int); unsigned int fromBase36(string); int getRandomIndex(int); //highest + double getStandardDeviation(vector&); int control_pressed; bool executing, runParse, jumble, gui, mothurCalling, debug; diff --git a/parsefastaqcommand.cpp b/parsefastaqcommand.cpp index 1331b7f..816bdb5 100644 --- a/parsefastaqcommand.cpp +++ b/parsefastaqcommand.cpp @@ -16,7 +16,8 @@ vector ParseFastaQCommand::setParameters(){ CommandParameter pfastq("fastq", "InputTypes", "", "", "none", "none", "none",false,true); parameters.push_back(pfastq); CommandParameter pfasta("fasta", "Bool", "", "T", "", "", "",false,false); parameters.push_back(pfasta); CommandParameter pqual("qfile", "Bool", "", "T", "", "", "",false,false); parameters.push_back(pqual); - CommandParameter pinputdir("inputdir", "String", "", "", "", "", "",false,false); parameters.push_back(pinputdir); + CommandParameter pformat("format", "Multiple", "sanger-illumina-solexa", "sanger", "", "", "",false,false); parameters.push_back(pformat); + CommandParameter pinputdir("inputdir", "String", "", "", "", "", "",false,false); parameters.push_back(pinputdir); CommandParameter poutputdir("outputdir", "String", "", "", "", "", "",false,false); parameters.push_back(poutputdir); vector myArray; @@ -33,8 +34,9 @@ string ParseFastaQCommand::getHelpString(){ try { string helpString = ""; helpString += "The fastq.info command reads a fastq file and creates a fasta and quality file.\n"; - helpString += "The fastq.info command parameters are fastq, fasta and qfile; fastq is required.\n"; - helpString += "The fastq.info command should be in the following format: fastq.info(fastaq=yourFastaQFile).\n"; + helpString += "The fastq.info command parameters are fastq, fasta, qfile and format; fastq is required.\n"; + helpString += "The fastq.info command should be in the following format: fastq.info(fastaq=yourFastaQFile).\n"; + helpString += "The format parameter is used to indicate whether your sequences are sanger, solexa or illumina, default=sanger.\n"; helpString += "The fasta parameter allows you to indicate whether you want a fasta file generated. Default=T.\n"; helpString += "The qfile parameter allows you to indicate whether you want a quality file generated. Default=T.\n"; helpString += "Example fastq.info(fastaq=test.fastaq).\n"; @@ -138,6 +140,13 @@ ParseFastaQCommand::ParseFastaQCommand(string option){ temp = validParameter.validFile(parameters, "qfile", false); if(temp == "not found"){ temp = "T"; } qual = m->isTrue(temp); + format = validParameter.validFile(parameters, "format", false); if (format == "not found"){ format = "sanger"; } + + if ((format != "sanger") && (format != "illumina") && (format != "solexa")) { + m->mothurOut(format + " is not a valid format. Your format choices are sanger, solexa and illumina, aborting." ); m->mothurOutEndLine(); + abort=true; + } + if ((!fasta) && (!qual)) { m->mothurOut("[ERROR]: no outputs selected. Aborting."); m->mothurOutEndLine(); abort=true; } } @@ -163,6 +172,12 @@ int ParseFastaQCommand::execute(){ ifstream in; m->openInputFile(fastaQFile, in); + + //fill convert table - goes from solexa to sanger. Used fq_all2std.pl as a reference. + for (int i = -64; i < 65; i++) { + char temp = (char) ((int)(33 + 10*log(1+pow(10,(i/10.0)))/log(10)+0.499)); + convertTable.push_back(temp); + } while (!in.eof()) { @@ -238,12 +253,18 @@ vector ParseFastaQCommand::convertQual(string qual) { try { vector qualScores; - int controlChar = int('@'); - for (int i = 0; i < qual.length(); i++) { - int temp = int(qual[i]); - temp -= controlChar; - + + int temp = 0; + temp = int(qual[i]); + if (format == "illumina") { + temp -= 64; //char '@' + }else if (format == "solexa") { + temp = int(convertTable[temp]); //convert to sanger + temp -= 33; //char '!' + }else { + temp -= 33; //char '!' + } qualScores.push_back(temp); } diff --git a/parsefastaqcommand.h b/parsefastaqcommand.h index 4481b98..96fcb7d 100644 --- a/parsefastaqcommand.h +++ b/parsefastaqcommand.h @@ -34,10 +34,11 @@ public: private: vector outputNames; - string outputDir, fastaQFile; + string outputDir, fastaQFile, format; bool abort, fasta, qual; vector convertQual(string); + vector convertTable; }; #endif diff --git a/pcrseqscommand.h b/pcrseqscommand.h index baeca4e..d35850c 100644 --- a/pcrseqscommand.h +++ b/pcrseqscommand.h @@ -15,6 +15,7 @@ #include "trimoligos.h" #include "alignment.hpp" #include "needlemanoverlap.hpp" +#include "counttable.h" class PcrSeqsCommand : public Command { public: @@ -45,7 +46,7 @@ private: vector lines; bool getOligos(vector >&, vector >&, vector >&); bool abort, keepprimer, keepdots; - string fastafile, oligosfile, taxfile, groupfile, namefile, ecolifile, outputDir, nomatch; + string fastafile, oligosfile, taxfile, groupfile, namefile, countfile, ecolifile, outputDir, nomatch; int start, end, processors, length; vector revPrimer, outputNames; @@ -55,6 +56,7 @@ private: int readName(set&); int readGroup(set); int readTax(set); + int readCount(set); bool readOligos(); bool readEcoli(); int driverPcr(string, string, string, set&, linePair); diff --git a/prcseqscommand.cpp b/prcseqscommand.cpp index 6b73d44..de2cb20 100644 --- a/prcseqscommand.cpp +++ b/prcseqscommand.cpp @@ -13,8 +13,9 @@ vector PcrSeqsCommand::setParameters(){ try { CommandParameter pfasta("fasta", "InputTypes", "", "", "none", "none", "none",false,true); parameters.push_back(pfasta); CommandParameter poligos("oligos", "InputTypes", "", "", "ecolioligos", "none", "none",false,false); parameters.push_back(poligos); - CommandParameter pname("name", "InputTypes", "", "", "none", "none", "none",false,false); parameters.push_back(pname); - CommandParameter pgroup("group", "InputTypes", "", "", "none", "none", "none",false,false); parameters.push_back(pgroup); + CommandParameter pname("name", "InputTypes", "", "", "NameCount", "none", "none",false,false); parameters.push_back(pname); + CommandParameter pcount("count", "InputTypes", "", "", "NameCount-CountGroup", "none", "none",false,false); parameters.push_back(pcount); + CommandParameter pgroup("group", "InputTypes", "", "", "CountGroup", "none", "none",false,false); parameters.push_back(pgroup); CommandParameter ptax("taxonomy", "InputTypes", "", "", "none", "none", "none",false,false); parameters.push_back(ptax); CommandParameter pecoli("ecoli", "InputTypes", "", "", "ecolioligos", "none", "none",false,false); parameters.push_back(pecoli); CommandParameter pstart("start", "Number", "", "-1", "", "", "",false,false); parameters.push_back(pstart); @@ -40,7 +41,7 @@ string PcrSeqsCommand::getHelpString(){ try { string helpString = ""; helpString += "The pcr.seqs command reads a fasta file.\n"; - helpString += "The pcr.seqs command parameters are fasta, oligos, name, group, taxonomy, ecoli, start, end, nomatch, processors, keepprimer and keepdots.\n"; + helpString += "The pcr.seqs command parameters are fasta, oligos, name, group, count, taxonomy, ecoli, start, end, nomatch, processors, keepprimer and keepdots.\n"; helpString += "The ecoli parameter is used to provide a fasta file containing a single reference sequence (e.g. for e. coli) this must be aligned. Mothur will trim to the start and end positions of the reference sequence.\n"; helpString += "The start parameter allows you to provide a starting position to trim to.\n"; helpString += "The end parameter allows you to provide a ending position to trim from.\n"; @@ -72,6 +73,7 @@ string PcrSeqsCommand::getOutputFileNameTag(string type, string inputName=""){ else if (type == "taxonomy") { outputFileName = "pcr" + m->getExtension(inputName); } else if (type == "group") { outputFileName = "pcr" + m->getExtension(inputName); } else if (type == "name") { outputFileName = "pcr" + m->getExtension(inputName); } + else if (type == "count") { outputFileName = "pcr" + m->getExtension(inputName); } else if (type == "accnos") { outputFileName = "bad.accnos"; } else { m->mothurOut("[ERROR]: No definition for type " + type + " output file tag.\n"); m->control_pressed = true; } } @@ -93,6 +95,7 @@ PcrSeqsCommand::PcrSeqsCommand(){ outputTypes["taxonomy"] = tempOutNames; outputTypes["group"] = tempOutNames; outputTypes["name"] = tempOutNames; + outputTypes["count"] = tempOutNames; outputTypes["accnos"] = tempOutNames; } catch(exception& e) { @@ -132,6 +135,7 @@ PcrSeqsCommand::PcrSeqsCommand(string option) { outputTypes["group"] = tempOutNames; outputTypes["name"] = tempOutNames; outputTypes["accnos"] = tempOutNames; + outputTypes["count"] = tempOutNames; //if the user changes the input directory command factory will send this info to us in the output parameter string inputDir = validParameter.validFile(parameters, "inputdir", false); @@ -185,6 +189,14 @@ PcrSeqsCommand::PcrSeqsCommand(string option) { //if the user has not given a path then, add inputdir. else leave path alone. if (path == "") { parameters["group"] = inputDir + it->second; } } + + it = parameters.find("count"); + //user has given a template file + if(it != parameters.end()){ + path = m->hasPath(it->second); + //if the user has not given a path then, add inputdir. else leave path alone. + if (path == "") { parameters["count"] = inputDir + it->second; } + } } @@ -229,6 +241,19 @@ PcrSeqsCommand::PcrSeqsCommand(string option) { else if(groupfile == "not open"){ groupfile = ""; abort = true; } else { m->setGroupFile(groupfile); } + countfile = validParameter.validFile(parameters, "count", true); + if (countfile == "not open") { countfile = ""; abort = true; } + else if (countfile == "not found") { countfile = ""; } + else { m->setCountTableFile(countfile); } + + if ((namefile != "") && (countfile != "")) { + m->mothurOut("[ERROR]: you may only use one of the following: name or count."); m->mothurOutEndLine(); abort = true; + } + + if ((groupfile != "") && (countfile != "")) { + m->mothurOut("[ERROR]: you may only use one of the following: group or count."); m->mothurOutEndLine(); abort=true; + } + taxfile = validParameter.validFile(parameters, "taxonomy", true); if (taxfile == "not found"){ taxfile = ""; } else if(taxfile == "not open"){ taxfile = ""; abort = true; } @@ -265,10 +290,12 @@ PcrSeqsCommand::PcrSeqsCommand(string option) { } //check to make sure you didn't forget the name file by mistake - if (namefile == "") { - vector files; files.push_back(fastafile); - parser.getNameFile(files); - } + if (countfile == "") { + if (namefile == "") { + vector files; files.push_back(fastafile); + parser.getNameFile(files); + } + } } } @@ -339,7 +366,9 @@ int PcrSeqsCommand::execute(){ if (m->control_pressed) { for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } if (taxfile != "") { readTax(badNames); } if (m->control_pressed) { for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } - + if (countfile != "") { readCount(badNames); } + if (m->control_pressed) { for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } + m->mothurOutEndLine(); m->mothurOut("Output File Names: "); m->mothurOutEndLine(); for (int i = 0; i < outputNames.size(); i++) { m->mothurOut(outputNames[i]); m->mothurOutEndLine(); } @@ -373,6 +402,11 @@ int PcrSeqsCommand::execute(){ if ((itTypes->second).size() != 0) { current = (itTypes->second)[0]; m->setTaxonomyFile(current); } } + itTypes = outputTypes.find("count"); + if (itTypes != outputTypes.end()) { + if ((itTypes->second).size() != 0) { current = (itTypes->second)[0]; m->setCountTableFile(current); } + } + m->mothurOut("It took " + toString(time(NULL) - start) + " secs to screen " + toString(numFastaSeqs) + " sequences."); m->mothurOutEndLine(); @@ -1087,6 +1121,63 @@ int PcrSeqsCommand::readTax(set names){ exit(1); } } +//*************************************************************************************************************** +int PcrSeqsCommand::readCount(set badSeqNames){ + try { + ifstream in; + m->openInputFile(countfile, in); + set::iterator it; + + string goodCountFile = outputDir + m->getRootName(m->getSimpleName(countfile)) + getOutputFileNameTag("count", countfile); + outputNames.push_back(goodCountFile); outputTypes["count"].push_back(goodCountFile); + ofstream goodCountOut; m->openOutputFile(goodCountFile, goodCountOut); + + string headers = m->getline(in); m->gobble(in); + goodCountOut << headers << endl; + + string name, rest; int thisTotal, removedCount; removedCount = 0; + bool wroteSomething = false; + while (!in.eof()) { + + if (m->control_pressed) { goodCountOut.close(); in.close(); m->mothurRemove(goodCountFile); return 0; } + + in >> name; m->gobble(in); + in >> thisTotal; m->gobble(in); + rest = m->getline(in); m->gobble(in); + + if (badSeqNames.count(name) != 0) { removedCount+=thisTotal; } + else{ + wroteSomething = true; + goodCountOut << name << '\t' << thisTotal << '\t' << rest << endl; + } + } + in.close(); + goodCountOut.close(); + + if (m->control_pressed) { m->mothurRemove(goodCountFile); } + + if (wroteSomething == false) { m->mothurOut("Your count file contains only sequences from the .accnos file."); m->mothurOutEndLine(); } + + //check for groups that have been eliminated + CountTable ct; + if (ct.testGroups(goodCountFile)) { + ct.readTable(goodCountFile); + ct.printTable(goodCountFile); + } + + if (m->control_pressed) { m->mothurRemove(goodCountFile); } + + m->mothurOut("Removed " + toString(removedCount) + " sequences from your count file."); m->mothurOutEndLine(); + + + return 0; + + } + catch(exception& e) { + m->errorOut(e, "PcrSeqsCommand", "readCOunt"); + exit(1); + } +} /**************************************************************************************/ diff --git a/randomforest.cpp b/randomforest.cpp new file mode 100644 index 0000000..36a2c1a --- /dev/null +++ b/randomforest.cpp @@ -0,0 +1,156 @@ +// +// randomforest.cpp +// Mothur +// +// Created by Sarah Westcott on 10/2/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#include "randomforest.hpp" + +/***********************************************************************/ + +RandomForest::RandomForest(const vector > dataSet,const int numDecisionTrees, + const string treeSplitCriterion = "informationGain") : AbstractRandomForest(dataSet, numDecisionTrees, treeSplitCriterion) { + m = MothurOut::getInstance(); +} + +/***********************************************************************/ +// DONE +int RandomForest::calcForrestErrorRate() { + try { + int numCorrect = 0; + for (map >::iterator it = globalOutOfBagEstimates.begin(); it != globalOutOfBagEstimates.end(); it++) { + + if (m->control_pressed) { return 0; } + + int indexOfSample = it->first; + vector predictedOutComes = it->second; + vector::iterator maxPredictedOutComeIterator = max_element(predictedOutComes.begin(), predictedOutComes.end()); + int majorityVotedOutcome = (int)(maxPredictedOutComeIterator - predictedOutComes.begin()); + int realOutcome = dataSet[indexOfSample][numFeatures]; + + if (majorityVotedOutcome == realOutcome) { numCorrect++; } + } + + // TODO: save or return forrestErrorRate for future use; + double forrestErrorRate = 1 - ((double)numCorrect / (double)globalOutOfBagEstimates.size()); + + m->mothurOut("numCorrect = " + toString(numCorrect)+ "\n"); + m->mothurOut("forrestErrorRate = " + toString(forrestErrorRate)+ "\n"); + + return 0; + } + catch(exception& e) { + m->errorOut(e, "RandomForest", "calcForrestErrorRate"); + exit(1); + } +} + +/***********************************************************************/ +// DONE +int RandomForest::calcForrestVariableImportance(string filename) { + try { + + // TODO: need to add try/catch operators to fix this + // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast + //if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all? + //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree. + for (int i = 0; i < decisionTrees.size(); i++) { + if (m->control_pressed) { return 0; } + DecisionTree* decisionTree = dynamic_cast(decisionTrees[i]); + + for (int j = 0; j < numFeatures; j++) { + globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j]; + } + } + + for (int i = 0; i < numFeatures; i++) { + cout << "[" << i << ',' << globalVariableImportanceList[i] << "], "; + globalVariableImportanceList[i] /= (double)numDecisionTrees; + } + + vector< vector > globalVariableRanks; + for (int i = 0; i < globalVariableImportanceList.size(); i++) { + if (globalVariableImportanceList[i] > 0) { + vector globalVariableRank(2, 0); + globalVariableRank[0] = i; globalVariableRank[1] = globalVariableImportanceList[i]; + globalVariableRanks.push_back(globalVariableRank); + } + } + + VariableRankDescendingSorterDouble variableRankDescendingSorter; + sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter); + ofstream out; + m->openOutputFile(filename, out); + out <<"OTU\tRank\n"; + for (int i = 0; i < globalVariableRanks.size(); i++) { + out << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl; + } + out.close(); + return 0; + } + catch(exception& e) { + m->errorOut(e, "RandomForest", "calcForrestVariableImportance"); + exit(1); + } +} +/***********************************************************************/ +// DONE +int RandomForest::populateDecisionTrees() { + try { + + for (int i = 0; i < numDecisionTrees; i++) { + if (m->control_pressed) { return 0; } + if (((i+1) % 10) == 0) { m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n"); } + // TODO: need to first fix if we are going to use pointer based system or anything else + DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector("log2"), treeSplitCriterion); + decisionTree->calcTreeVariableImportanceAndError(); + if (m->control_pressed) { return 0; } + updateGlobalOutOfBagEstimates(decisionTree); + if (m->control_pressed) { return 0; } + decisionTree->purgeDataSetsFromTree(); + if (m->control_pressed) { return 0; } + decisionTrees.push_back(decisionTree); + } + + if (m->debug) { + // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n"); + } + + return 0; + } + catch(exception& e) { + m->errorOut(e, "RandomForest", "populateDecisionTrees"); + exit(1); + } +} +/***********************************************************************/ +// TODO: need to finalize bettween reference and pointer for DecisionTree [partially solved] +// TODO: make this pure virtual in superclass +// DONE +int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) { + try { + for (map::iterator it = decisionTree->outOfBagEstimates.begin(); it != decisionTree->outOfBagEstimates.end(); it++) { + + if (m->control_pressed) { return 0; } + + int indexOfSample = it->first; + int predictedOutcomeOfSample = it->second; + + if (globalOutOfBagEstimates.count(indexOfSample) == 0) { + globalOutOfBagEstimates[indexOfSample] = vector(decisionTree->numOutputClasses, 0); + }; + + globalOutOfBagEstimates[indexOfSample][predictedOutcomeOfSample] += 1; + } + return 0; + } + catch(exception& e) { + m->errorOut(e, "RandomForest", "updateGlobalOutOfBagEstimates"); + exit(1); + } +} +/***********************************************************************/ + + diff --git a/randomforest.hpp b/randomforest.hpp new file mode 100755 index 0000000..716d1a1 --- /dev/null +++ b/randomforest.hpp @@ -0,0 +1,45 @@ +// +// randomforest.hpp +// rrf-fs-prototype +// +// Created by Abu Zaher Faridee on 7/20/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#ifndef rrf_fs_prototype_randomforest_hpp +#define rrf_fs_prototype_randomforest_hpp + +#include "macros.h" +#include "abstractrandomforest.hpp" +#include "decisiontree.hpp" + +class RandomForest: public AbstractRandomForest { + +public: + + // DONE + RandomForest(const vector > dataSet,const int numDecisionTrees, const string); + + + //NOTE:: if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all? + //could cause maintenance issues later if other types of Abstract decison trees are created that cannot be cast as a decision tree. + virtual ~RandomForest() { + for (vector::iterator it = decisionTrees.begin(); it != decisionTrees.end(); it++) { + // we know that this is decision tree, so we can do a dynamic_case here + DecisionTree* decisionTree = dynamic_cast(*it); + // calling the destructor by deleting + delete decisionTree; + } + } + + int calcForrestErrorRate(); + int calcForrestVariableImportance(string); + int populateDecisionTrees(); + int updateGlobalOutOfBagEstimates(DecisionTree* decisionTree); + +private: + MothurOut* m; + +}; + +#endif diff --git a/removerarecommand.cpp b/removerarecommand.cpp index 923ca72..ded26bb 100644 --- a/removerarecommand.cpp +++ b/removerarecommand.cpp @@ -20,7 +20,8 @@ vector RemoveRareCommand::setParameters(){ CommandParameter prabund("rabund", "InputTypes", "", "", "none", "none", "none",false,false); parameters.push_back(prabund); CommandParameter psabund("sabund", "InputTypes", "", "", "none", "none", "none",false,false); parameters.push_back(psabund); CommandParameter pshared("shared", "InputTypes", "", "", "none", "none", "none",false,false); parameters.push_back(pshared); - CommandParameter pgroup("group", "InputTypes", "", "", "none", "none", "none",false,true); parameters.push_back(pgroup); + CommandParameter pcount("count", "InputTypes", "", "", "CountGroup", "none", "none",false,false); parameters.push_back(pcount); + CommandParameter pgroup("group", "InputTypes", "", "", "CountGroup", "none", "none",false,false); parameters.push_back(pgroup); CommandParameter pgroups("groups", "String", "", "", "", "", "",false,false); parameters.push_back(pgroups); CommandParameter plabel("label", "String", "", "", "", "", "",false,false); parameters.push_back(plabel); CommandParameter pnseqs("nseqs", "Number", "", "0", "", "", "",false,true); parameters.push_back(pnseqs); @@ -41,7 +42,7 @@ vector RemoveRareCommand::setParameters(){ string RemoveRareCommand::getHelpString(){ try { string helpString = ""; - helpString += "The remove.rare command parameters are list, rabund, sabund, shared, group, label, groups, bygroup and nseqs.\n"; + helpString += "The remove.rare command parameters are list, rabund, sabund, shared, group, count, label, groups, bygroup and nseqs.\n"; helpString += "The remove.rare command reads one of the following file types: list, rabund, sabund or shared file. It outputs a new file after removing the rare otus.\n"; helpString += "The groups parameter allows you to specify which of the groups you would like analyzed. Default=all. You may separate group names with dashes.\n"; helpString += "The label parameter is used to analyze specific labels in your input. default=all. You may separate label names with dashes.\n"; @@ -72,6 +73,7 @@ string RemoveRareCommand::getOutputFileNameTag(string type, string inputName="") else if (type == "sabund") { outputFileName = "pick" + m->getExtension(inputName); } else if (type == "shared") { outputFileName = "pick" + m->getExtension(inputName); } else if (type == "group") { outputFileName = "pick" + m->getExtension(inputName); } + else if (type == "count") { outputFileName = "pick" + m->getExtension(inputName); } else if (type == "list") { outputFileName = "pick" + m->getExtension(inputName); } else { m->mothurOut("[ERROR]: No definition for type " + type + " output file tag.\n"); m->control_pressed = true; } } @@ -93,6 +95,7 @@ RemoveRareCommand::RemoveRareCommand(){ outputTypes["sabund"] = tempOutNames; outputTypes["list"] = tempOutNames; outputTypes["group"] = tempOutNames; + outputTypes["count"] = tempOutNames; outputTypes["shared"] = tempOutNames; } catch(exception& e) { @@ -131,6 +134,7 @@ RemoveRareCommand::RemoveRareCommand(string option) { outputTypes["list"] = tempOutNames; outputTypes["group"] = tempOutNames; outputTypes["shared"] = tempOutNames; + outputTypes["count"] = tempOutNames; //if the user changes the output directory command factory will send this info to us in the output parameter outputDir = validParameter.validFile(parameters, "outputdir", false); if (outputDir == "not found"){ outputDir = ""; } @@ -179,6 +183,14 @@ RemoveRareCommand::RemoveRareCommand(string option) { //if the user has not given a path then, add inputdir. else leave path alone. if (path == "") { parameters["shared"] = inputDir + it->second; } } + + it = parameters.find("count"); + //user has given a template file + if(it != parameters.end()){ + path = m->hasPath(it->second); + //if the user has not given a path then, add inputdir. else leave path alone. + if (path == "") { parameters["count"] = inputDir + it->second; } + } } @@ -207,6 +219,15 @@ RemoveRareCommand::RemoveRareCommand(string option) { if (sharedfile == "not open") { sharedfile = ""; abort = true; } else if (sharedfile == "not found") { sharedfile = ""; } else { m->setSharedFile(sharedfile); } + + countfile = validParameter.validFile(parameters, "count", true); + if (countfile == "not open") { countfile = ""; abort = true; } + else if (countfile == "not found") { countfile = ""; } + else { m->setCountTableFile(countfile); } + + if ((groupfile != "") && (countfile != "")) { + m->mothurOut("[ERROR]: you may only use one of the following: group or count."); m->mothurOutEndLine(); abort=true; + } if ((sharedfile == "") && (listfile == "") && (rabundfile == "") && (sabundfile == "")) { //is there are current file available for any of these? @@ -252,7 +273,7 @@ RemoveRareCommand::RemoveRareCommand(string option) { if (byGroup && (sharedfile == "")) { m->mothurOut("The byGroup parameter is only valid with a shared file."); m->mothurOutEndLine(); } - if ((groupfile != "") && (listfile == "")) { m->mothurOut("A groupfile is only valid with a list file."); m->mothurOutEndLine(); groupfile = ""; } + if (((groupfile != "") || (countfile != "")) && (listfile == "")) { m->mothurOut("A group or count file is only valid with a list file."); m->mothurOutEndLine(); groupfile = ""; countfile = ""; } } } @@ -310,6 +331,11 @@ int RemoveRareCommand::execute(){ if (itTypes != outputTypes.end()) { if ((itTypes->second).size() != 0) { current = (itTypes->second)[0]; m->setSharedFile(current); } } + + itTypes = outputTypes.find("count"); + if (itTypes != outputTypes.end()) { + if ((itTypes->second).size() != 0) { current = (itTypes->second)[0]; m->setCountTableFile(current); } + } } return 0; @@ -327,7 +353,9 @@ int RemoveRareCommand::processList(){ string thisOutputDir = outputDir; if (outputDir == "") { thisOutputDir += m->hasPath(listfile); } string outputFileName = thisOutputDir + m->getRootName(m->getSimpleName(listfile)) + getOutputFileNameTag("list", listfile); - string outputGroupFileName = thisOutputDir + m->getRootName(m->getSimpleName(groupfile)) + getOutputFileNameTag("group", groupfile); + string outputGroupFileName = thisOutputDir + m->getRootName(m->getSimpleName(groupfile)) + getOutputFileNameTag("group", groupfile); + string outputCountFileName = thisOutputDir + m->getRootName(m->getSimpleName(countfile)) + getOutputFileNameTag("count", countfile); + ofstream out, outGroup; m->openOutputFile(outputFileName, out); @@ -374,13 +402,21 @@ int RemoveRareCommand::processList(){ //if groupfile is given then use it GroupMap* groupMap; + CountTable ct; if (groupfile != "") { groupMap = new GroupMap(groupfile); groupMap->readMap(); SharedUtil util; vector namesGroups = groupMap->getNamesOfGroups(); util.setGroups(Groups, namesGroups); m->openOutputFile(outputGroupFileName, outGroup); - } + }else if (countfile != "") { + ct.readTable(countfile); + if (ct.hasGroupInfo()) { + vector namesGroups = ct.getNamesOfGroups(); + SharedUtil util; + util.setGroups(Groups, namesGroups); + } + } if (list != NULL) { @@ -397,6 +433,7 @@ int RemoveRareCommand::processList(){ vector names; string saveBinNames = binnames; m->splitAtComma(binnames, names); + int binsize = names.size(); vector newGroupFile; if (groupfile != "") { @@ -412,14 +449,38 @@ int RemoveRareCommand::processList(){ saveBinNames += names[k] + ","; } } - names = newNames; + names = newNames; binsize = names.size(); saveBinNames = saveBinNames.substr(0, saveBinNames.length()-1); - } + }else if (countfile != "") { + saveBinNames = ""; + binsize = 0; + for(int k = 0; k < names.size(); k++) { + if (ct.hasGroupInfo()) { + vector thisSeqsGroups = ct.getGroups(names[k]); + + int thisSeqsCount = 0; + for (int n = 0; n < thisSeqsGroups.size(); n++) { + if (m->inUsersGroups(thisSeqsGroups[n], Groups)) { + thisSeqsCount += ct.getGroupCount(names[k], thisSeqsGroups[n]); + } + } + binsize += thisSeqsCount; + //if you don't have any seqs from the groups the user wants, then remove you. + if (thisSeqsCount == 0) { newGroupFile.push_back(names[k]); } + else { saveBinNames += names[k] + ","; } + }else { + binsize += ct.getNumSeqs(names[k]); + saveBinNames += names[k] + ","; + } + } + saveBinNames = saveBinNames.substr(0, saveBinNames.length()-1); + } - if (names.size() > nseqs) { //keep bin + if (binsize > nseqs) { //keep bin newList.push_back(saveBinNames); - for(int k = 0; k < newGroupFile.size(); k++) { outGroup << newGroupFile[k] << endl; } - } + if (groupfile != "") { for(int k = 0; k < newGroupFile.size(); k++) { outGroup << newGroupFile[k] << endl; } } + else if (countfile != "") { for(int k = 0; k < newGroupFile.size(); k++) { ct.remove(newGroupFile[k]); } } + }else { if (countfile != "") { for(int k = 0; k < names.size(); k++) { ct.remove(names[k]); } } } } //print new listvector @@ -431,6 +492,17 @@ int RemoveRareCommand::processList(){ out.close(); if (groupfile != "") { outGroup.close(); outputTypes["group"].push_back(outputGroupFileName); outputNames.push_back(outputGroupFileName); } + if (countfile != "") { + if (ct.hasGroupInfo()) { + vector allGroups = ct.getNamesOfGroups(); + for (int i = 0; i < allGroups.size(); i++) { + if (!m->inUsersGroups(allGroups[i], Groups)) { ct.removeGroup(allGroups[i]); } + } + + } + ct.printTable(outputCountFileName); + outputTypes["count"].push_back(outputCountFileName); outputNames.push_back(outputCountFileName); + } if (wroteSomething == false) { m->mothurOut("Your file contains only rare sequences."); m->mothurOutEndLine(); } outputTypes["list"].push_back(outputFileName); outputNames.push_back(outputFileName); diff --git a/removerarecommand.h b/removerarecommand.h index 2d70ba7..7b4c6fb 100644 --- a/removerarecommand.h +++ b/removerarecommand.h @@ -36,7 +36,7 @@ public: void help() { m->mothurOut(getHelpString()); } private: - string sabundfile, rabundfile, sharedfile, groupfile, listfile, outputDir, groups, label; + string sabundfile, rabundfile, sharedfile, groupfile, countfile, listfile, outputDir, groups, label; int nseqs, allLines; bool abort, byGroup; vector outputNames, Groups; diff --git a/rftreenode.cpp b/rftreenode.cpp new file mode 100644 index 0000000..170cfb1 --- /dev/null +++ b/rftreenode.cpp @@ -0,0 +1,92 @@ +// +// rftreenode.cpp +// Mothur +// +// Created by Sarah Westcott on 10/2/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#include "rftreenode.hpp" + +/***********************************************************************/ +RFTreeNode::RFTreeNode(vector< vector > bootstrappedTrainingSamples, + vector globalDiscardedFeatureIndices, + int numFeatures, + int numSamples, + int numOutputClasses, + int generation) + +: bootstrappedTrainingSamples(bootstrappedTrainingSamples), +globalDiscardedFeatureIndices(globalDiscardedFeatureIndices), +numFeatures(numFeatures), +numSamples(numSamples), +numOutputClasses(numOutputClasses), +generation(generation), +isLeaf(false), +outputClass(-1), +splitFeatureIndex(-1), +splitFeatureValue(-1), +splitFeatureEntropy(-1.0), +ownEntropy(-1.0), +bootstrappedFeatureVectors(numFeatures, vector(numSamples, 0)), +bootstrappedOutputVector(numSamples, 0), +leftChildNode(NULL), +rightChildNode(NULL), +parentNode(NULL) { + m = MothurOut::getInstance(); + + for (int i = 0; i < numSamples; i++) { // just doing a simple transpose of the matrix + if (m->control_pressed) { break; } + for (int j = 0; j < numFeatures; j++) { bootstrappedFeatureVectors[j][i] = bootstrappedTrainingSamples[i][j]; } + } + + for (int i = 0; i < numSamples; i++) { if (m->control_pressed) { break; } bootstrappedOutputVector[i] = bootstrappedTrainingSamples[i][numFeatures]; } + + createLocalDiscardedFeatureList(); + updateNodeEntropy(); +} +/***********************************************************************/ +int RFTreeNode::createLocalDiscardedFeatureList(){ + try { + + for (int i = 0; i < numFeatures; i++) { + if (m->control_pressed) { return 0; } + vector::iterator it = find(globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end(), i); + if (it == globalDiscardedFeatureIndices.end()){ // NOT FOUND + double standardDeviation = m->getStandardDeviation(bootstrappedFeatureVectors[i]); + if (standardDeviation <= 0){ localDiscardedFeatureIndices.push_back(i); } + } + } + + return 0; + } + catch(exception& e) { + m->errorOut(e, "RFTreeNode", "createLocalDiscardedFeatureList"); + exit(1); + } +} +/***********************************************************************/ +int RFTreeNode::updateNodeEntropy() { + try { + + vector classCounts(numOutputClasses, 0); + for (int i = 0; i < bootstrappedOutputVector.size(); i++) { classCounts[bootstrappedOutputVector[i]]++; } + int totalClassCounts = accumulate(classCounts.begin(), classCounts.end(), 0); + double nodeEntropy = 0.0; + for (int i = 0; i < classCounts.size(); i++) { + if (m->control_pressed) { return 0; } + if (classCounts[i] == 0) continue; + double probability = (double)classCounts[i] / (double)totalClassCounts; + nodeEntropy += -(probability * log2(probability)); + } + ownEntropy = nodeEntropy; + + return 0; + } + catch(exception& e) { + m->errorOut(e, "RFTreeNode", "updateNodeEntropy"); + exit(1); + } +} + +/***********************************************************************/ diff --git a/rftreenode.hpp b/rftreenode.hpp new file mode 100755 index 0000000..8987ebc --- /dev/null +++ b/rftreenode.hpp @@ -0,0 +1,91 @@ +// +// rftreenode.hpp +// rrf-fs-prototype +// +// Created by Abu Zaher Faridee on 5/29/12. +// Copyright (c) 2012 Schloss Lab. All rights reserved. +// + +#ifndef rrf_fs_prototype_treenode_hpp +#define rrf_fs_prototype_treenode_hpp + +#include "mothurout.h" +#include "macros.h" + +class RFTreeNode{ + +public: + + RFTreeNode(vector< vector > bootstrappedTrainingSamples, vector globalDiscardedFeatureIndices, int numFeatures, int numSamples, int numOutputClasses, int generation); + + virtual ~RFTreeNode(){} + + // getters + // we need to return const reference so that we have the actual value and not a copy, + // plus we do not modify the value as well + const int getSplitFeatureIndex() { return splitFeatureIndex; } + // TODO: check if this works properly or returs a shallow copy of the data + const vector< vector >& getBootstrappedTrainingSamples() { return bootstrappedTrainingSamples; } + const int getSplitFeatureValue() { return splitFeatureValue; } + const int getGeneration() { return generation; } + const bool checkIsLeaf() { return isLeaf; } + // TODO: fix this const pointer dillema + // we do not want to modify the data pointer by getLeftChildNode + RFTreeNode* getLeftChildNode() { return leftChildNode; } + RFTreeNode* getRightChildNode() { return rightChildNode; } + const int getOutputClass() { return outputClass; } + const int getNumSamples() { return numSamples; } + const int getNumFeatures() { return numFeatures; } + const vector& getLocalDiscardedFeatureIndices() { return localDiscardedFeatureIndices; } + const vector< vector >& getBootstrappedFeatureVectors() { return bootstrappedFeatureVectors; } + const vector& getBootstrappedOutputVector() { return bootstrappedOutputVector; } + const vector& getFeatureSubsetIndices() { return featureSubsetIndices; } + const double getOwnEntropy() { return ownEntropy; } + + // setters + void setIsLeaf(bool isLeaf) { this->isLeaf = isLeaf; } + void setOutputClass(int outputClass) { this->outputClass = outputClass; } + void setFeatureSubsetIndices(vector featureSubsetIndices) { this->featureSubsetIndices = featureSubsetIndices; } + void setLeftChildNode(RFTreeNode* leftChildNode) { this->leftChildNode = leftChildNode; } + void setRightChildNode(RFTreeNode* rightChildNode) { this->rightChildNode = rightChildNode; } + void setParentNode(RFTreeNode* parentNode) { this->parentNode = parentNode; } + void setSplitFeatureIndex(int splitFeatureIndex) { this->splitFeatureIndex = splitFeatureIndex; } + void setSplitFeatureValue(int splitFeatureValue) { this->splitFeatureValue = splitFeatureValue; } + void setSplitFeatureEntropy(double splitFeatureEntropy) { this->splitFeatureEntropy = splitFeatureEntropy; } + + // TODO: need to remove this mechanism of friend class + //NOTE: friend classes can be useful for testing purposes, but I would avoid using them otherwise. + friend class DecisionTree; + friend class AbstractDecisionTree; + +private: + vector > bootstrappedTrainingSamples; + vector globalDiscardedFeatureIndices; + vector localDiscardedFeatureIndices; + vector > bootstrappedFeatureVectors; + vector bootstrappedOutputVector; + vector featureSubsetIndices; + + int numFeatures; + int numSamples; + int numOutputClasses; + int generation; + bool isLeaf; + int outputClass; + int splitFeatureIndex; + int splitFeatureValue; + double splitFeatureEntropy; + double ownEntropy; + + RFTreeNode* leftChildNode; + RFTreeNode* rightChildNode; + RFTreeNode* parentNode; + + MothurOut* m; + + int createLocalDiscardedFeatureList(); + int updateNodeEntropy(); + +}; + +#endif diff --git a/screenseqscommand.cpp b/screenseqscommand.cpp index 312f475..2b5ebc1 100644 --- a/screenseqscommand.cpp +++ b/screenseqscommand.cpp @@ -1039,7 +1039,7 @@ int ScreenSeqsCommand::screenCountFile(set badSeqNames){ //we were unable to remove some of the bad sequences if (badSeqNames.size() != 0) { for (it = badSeqNames.begin(); it != badSeqNames.end(); it++) { - m->mothurOut("Your groupfile does not include the sequence " + *it + " please correct."); + m->mothurOut("Your count file does not include the sequence " + *it + " please correct."); m->mothurOutEndLine(); } } diff --git a/sharedcommand.cpp b/sharedcommand.cpp index 3980106..542f8d3 100644 --- a/sharedcommand.cpp +++ b/sharedcommand.cpp @@ -188,7 +188,11 @@ SharedCommand::SharedCommand(string option) { countfile = validParameter.validFile(parameters, "count", true); if (countfile == "not open") { countfile = ""; abort = true; } else if (countfile == "not found") { countfile = ""; } - else { m->setCountTableFile(countfile); } + else { + m->setCountTableFile(countfile); + CountTable temp; + if (!temp.testGroups(countfile)) { m->mothurOut("[ERROR]: Your count file does not have group info, aborting."); m->mothurOutEndLine(); abort=true; } + } if ((biomfile == "") && (listfile == "")) { //is there are current file available for either of these? diff --git a/sharedrabundvector.h b/sharedrabundvector.h index 792543e..419d15a 100644 --- a/sharedrabundvector.h +++ b/sharedrabundvector.h @@ -24,7 +24,6 @@ An individual which knows the OTU from which it came, the group it is in and its abundance. */ -//class GlobalData; class SharedRAbundVector : public DataVector { diff --git a/sortseqscommand.cpp b/sortseqscommand.cpp index ee7bf73..b0af154 100644 --- a/sortseqscommand.cpp +++ b/sortseqscommand.cpp @@ -15,8 +15,9 @@ vector SortSeqsCommand::setParameters(){ try { CommandParameter pfasta("fasta", "InputTypes", "", "", "none", "FNGLT", "none",false,false); parameters.push_back(pfasta); CommandParameter pflow("flow", "InputTypes", "", "", "none", "FNGLT", "none",false,false); parameters.push_back(pflow); - CommandParameter pname("name", "InputTypes", "", "", "none", "FNGLT", "none",false,false); parameters.push_back(pname); - CommandParameter pgroup("group", "InputTypes", "", "", "none", "FNGLT", "none",false,false); parameters.push_back(pgroup); + CommandParameter pname("name", "InputTypes", "", "", "NameCount", "FNGLT", "none",false,false); parameters.push_back(pname); + CommandParameter pcount("count", "InputTypes", "", "", "NameCount-CountGroup", "FNGLT", "none",false,false); parameters.push_back(pcount); + CommandParameter pgroup("group", "InputTypes", "", "", "CountGroup", "FNGLT", "none",false,false); parameters.push_back(pgroup); CommandParameter ptaxonomy("taxonomy", "InputTypes", "", "", "none", "FNGLT", "none",false,false); parameters.push_back(ptaxonomy); CommandParameter pqfile("qfile", "InputTypes", "", "", "none", "FNGLT", "none",false,false); parameters.push_back(pqfile); CommandParameter plarge("large", "Boolean", "", "F", "", "", "",false,false); parameters.push_back(plarge); @@ -37,8 +38,8 @@ vector SortSeqsCommand::setParameters(){ string SortSeqsCommand::getHelpString(){ try { string helpString = ""; - helpString += "The sort.seqs command puts the sequences in the same order for the following file types: accnos fasta, name, group, taxonomy, flow or quality file.\n"; - helpString += "The sort.seqs command parameters are accnos, fasta, name, group, taxonomy, flow, qfile and large.\n"; + helpString += "The sort.seqs command puts the sequences in the same order for the following file types: accnos fasta, name, group, count, taxonomy, flow or quality file.\n"; + helpString += "The sort.seqs command parameters are accnos, fasta, name, group, count, taxonomy, flow, qfile and large.\n"; helpString += "The accnos file allows you to specify the order you want the files in. If none is provided, mothur will use the order of the first file it reads.\n"; helpString += "The large parameters is used to indicate your files are too large to fit in RAM.\n"; helpString += "The sort.seqs command should be in the following format: sort.seqs(fasta=yourFasta).\n"; @@ -65,6 +66,7 @@ string SortSeqsCommand::getOutputFileNameTag(string type, string inputName=""){ if (type == "fasta") { outputFileName = "sorted" + m->getExtension(inputName); } else if (type == "taxonomy") { outputFileName = "sorted" + m->getExtension(inputName); } else if (type == "name") { outputFileName = "sorted" + m->getExtension(inputName); } + else if (type == "count") { outputFileName = "sorted" + m->getExtension(inputName); } else if (type == "group") { outputFileName = "sorted" + m->getExtension(inputName); } else if (type == "flow") { outputFileName = "sorted" + m->getExtension(inputName); } else if (type == "qfile") { outputFileName = "sorted" + m->getExtension(inputName); } @@ -87,6 +89,7 @@ SortSeqsCommand::SortSeqsCommand(){ outputTypes["fasta"] = tempOutNames; outputTypes["taxonomy"] = tempOutNames; outputTypes["name"] = tempOutNames; + outputTypes["count"] = tempOutNames; outputTypes["group"] = tempOutNames; outputTypes["qfile"] = tempOutNames; outputTypes["flow"] = tempOutNames; @@ -127,6 +130,7 @@ SortSeqsCommand::SortSeqsCommand(string option) { outputTypes["group"] = tempOutNames; outputTypes["qfile"] = tempOutNames; outputTypes["flow"] = tempOutNames; + outputTypes["count"] = tempOutNames; //if the user changes the output directory command factory will send this info to us in the output parameter outputDir = validParameter.validFile(parameters, "outputdir", false); if (outputDir == "not found"){ outputDir = ""; } @@ -191,6 +195,14 @@ SortSeqsCommand::SortSeqsCommand(string option) { //if the user has not given a path then, add inputdir. else leave path alone. if (path == "") { parameters["flow"] = inputDir + it->second; } } + + it = parameters.find("count"); + //user has given a template file + if(it != parameters.end()){ + path = m->hasPath(it->second); + //if the user has not given a path then, add inputdir. else leave path alone. + if (path == "") { parameters["count"] = inputDir + it->second; } + } } @@ -229,16 +241,31 @@ SortSeqsCommand::SortSeqsCommand(string option) { if (qualfile == "not open") { abort = true; } else if (qualfile == "not found") { qualfile = ""; } else { m->setQualFile(qualfile); } + + countfile = validParameter.validFile(parameters, "count", true); + if (countfile == "not open") { countfile = ""; abort = true; } + else if (countfile == "not found") { countfile = ""; } + else { m->setCountTableFile(countfile); } + + if ((namefile != "") && (countfile != "")) { + m->mothurOut("[ERROR]: you may only use one of the following: name or count."); m->mothurOutEndLine(); abort = true; + } + + if ((groupfile != "") && (countfile != "")) { + m->mothurOut("[ERROR]: you may only use one of the following: group or count."); m->mothurOutEndLine(); abort=true; + } string temp = validParameter.validFile(parameters, "large", false); if (temp == "not found") { temp = "f"; } large = m->isTrue(temp); - if ((fastafile == "") && (namefile == "") && (groupfile == "") && (taxfile == "") && (flowfile == "") && (qualfile == "")) { m->mothurOut("You must provide at least one of the following: fasta, name, group, taxonomy, flow or quality."); m->mothurOutEndLine(); abort = true; } + if ((fastafile == "") && (namefile == "") && (countfile == "") && (groupfile == "") && (taxfile == "") && (flowfile == "") && (qualfile == "")) { m->mothurOut("You must provide at least one of the following: fasta, name, group, count, taxonomy, flow or quality."); m->mothurOutEndLine(); abort = true; } - if ((fastafile != "") && (namefile == "")) { - vector files; files.push_back(fastafile); - parser.getNameFile(files); - } + if (countfile == "") { + if ((fastafile != "") && (namefile == "")) { + vector files; files.push_back(fastafile); + parser.getNameFile(files); + } + } } } @@ -267,6 +294,7 @@ int SortSeqsCommand::execute(){ if (qualfile != "") { readQual(); } if (namefile != "") { readName(); } if (groupfile != "") { readGroup(); } + if (countfile != "") { readCount(); } if (taxfile != "") { readTax(); } if (m->control_pressed) { for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } @@ -308,7 +336,12 @@ int SortSeqsCommand::execute(){ itTypes = outputTypes.find("flow"); if (itTypes != outputTypes.end()) { if ((itTypes->second).size() != 0) { current = (itTypes->second)[0]; m->setFlowFile(current); } - } + } + + itTypes = outputTypes.find("count"); + if (itTypes != outputTypes.end()) { + if ((itTypes->second).size() != 0) { current = (itTypes->second)[0]; m->setCountTableFile(current); } + } } return 0; @@ -927,7 +960,88 @@ int SortSeqsCommand::readName(){ exit(1); } } - +//********************************************************************************************************************** +int SortSeqsCommand::readCount(){ + try { + string thisOutputDir = outputDir; + if (outputDir == "") { thisOutputDir += m->hasPath(countfile); } + string outputFileName = thisOutputDir + m->getRootName(m->getSimpleName(countfile)) + getOutputFileNameTag("count", countfile); + outputTypes["count"].push_back(outputFileName); outputNames.push_back(outputFileName); + + ofstream out; + m->openOutputFile(outputFileName, out); + + ifstream in; + m->openInputFile(countfile, in); + string firstCol, rest; + + if (names.size() != 0) {//this is not the first file we are reading so we need to use the order we already have + + vector seqs; seqs.resize(names.size(), ""); + + string headers = m->getline(in); m->gobble(in); + + while(!in.eof()){ + if (m->control_pressed) { in.close(); out.close(); m->mothurRemove(outputFileName); return 0; } + + in >> firstCol; m->gobble(in); + rest = m->getline(in); m->gobble(in); + + if (firstCol != "") { + map::iterator it = names.find(firstCol); + if (it != names.end()) { //we found it, so put it in the vector in the right place. + seqs[it->second] = firstCol + '\t' + rest; + }else { //if we cant find it then add it to the end + names[firstCol] = seqs.size(); + seqs.push_back((firstCol + '\t' + rest)); + m->mothurOut(firstCol + " was not in the contained the file which determined the order, adding it to the end.\n"); + } + } + } + in.close(); + + int count = 0; + out << headers << endl; + for (int i = 0; i < seqs.size(); i++) { + if (seqs[i] != "") { out << seqs[i] << endl; count++; } + } + out.close(); + + m->mothurOut("Ordered " + toString(count) + " sequences from " + countfile + ".\n"); + + }else { //read in file to fill names + int count = 0; + + string headers = m->getline(in); m->gobble(in); + out << headers << endl; + + while(!in.eof()){ + if (m->control_pressed) { in.close(); out.close(); m->mothurRemove(outputFileName); return 0; } + + in >> firstCol; m->gobble(in); + rest = m->getline(in); m->gobble(in); + + if (firstCol != "") { + //if this name is in the accnos file + names[firstCol] = count; + count++; + out << firstCol << '\t' << rest << endl; + } + m->gobble(in); + } + in.close(); + out.close(); + + m->mothurOut("\nUsing " + countfile + " to determine the order. It contains " + toString(count) + " representative sequences.\n"); + } + + return 0; + } + catch(exception& e) { + m->errorOut(e, "SortSeqsCommand", "readCount"); + exit(1); + } +} //********************************************************************************************************************** int SortSeqsCommand::readGroup(){ try { diff --git a/sortseqscommand.h b/sortseqscommand.h index 6d9c5ed..4ba8e42 100644 --- a/sortseqscommand.h +++ b/sortseqscommand.h @@ -36,7 +36,7 @@ public: private: map names; - string accnosfile, fastafile, namefile, groupfile, taxfile, qualfile, flowfile, outputDir; + string accnosfile, fastafile, namefile, groupfile, countfile, taxfile, qualfile, flowfile, outputDir; bool abort, large; vector outputNames; @@ -45,6 +45,7 @@ private: int readName(); int readGroup(); int readTax(); + int readCount(); int readQual(); }; diff --git a/splitabundcommand.cpp b/splitabundcommand.cpp index bc1cdb3..48fada8 100644 --- a/splitabundcommand.cpp +++ b/splitabundcommand.cpp @@ -8,13 +8,15 @@ */ #include "splitabundcommand.h" +#include "sharedutilities.h" //********************************************************************************************************************** vector SplitAbundCommand::setParameters(){ try { CommandParameter pfasta("fasta", "InputTypes", "", "", "none", "none", "none",false,true); parameters.push_back(pfasta); - CommandParameter pname("name", "InputTypes", "", "", "none", "FNGLT", "none",false,false); parameters.push_back(pname); - CommandParameter pgroup("group", "InputTypes", "", "", "none", "none", "none",false,false); parameters.push_back(pgroup); + CommandParameter pname("name", "InputTypes", "", "", "NameCount", "FNGLT", "none",false,false); parameters.push_back(pname); + CommandParameter pcount("count", "InputTypes", "", "", "NameCount-CountGroup", "none", "none",false,false); parameters.push_back(pcount); + CommandParameter pgroup("group", "InputTypes", "", "", "CountGroup", "none", "none",false,false); parameters.push_back(pgroup); CommandParameter plist("list", "InputTypes", "", "", "none", "FNGLT", "none",false,false); parameters.push_back(plist); CommandParameter plabel("label", "String", "", "", "", "", "",false,false); parameters.push_back(plabel); CommandParameter pcutoff("cutoff", "Number", "", "0", "", "", "",false,true); parameters.push_back(pcutoff); @@ -37,8 +39,8 @@ string SplitAbundCommand::getHelpString(){ try { string helpString = ""; helpString += "The split.abund command reads a fasta file and a list or a names file splits the sequences into rare and abundant groups. \n"; - helpString += "The split.abund command parameters are fasta, list, name, cutoff, group, label, groups, cutoff and accnos.\n"; - helpString += "The fasta and a list or name parameter are required, and you must provide a cutoff value.\n"; + helpString += "The split.abund command parameters are fasta, list, name, count, cutoff, group, label, groups, cutoff and accnos.\n"; + helpString += "The fasta and a list or name or count parameter are required, and you must provide a cutoff value.\n"; helpString += "The cutoff parameter is used to qualify what is abundant and rare.\n"; helpString += "The group parameter allows you to parse a group file into rare and abundant groups.\n"; helpString += "The label parameter is used to read specific labels in your listfile you want to use.\n"; @@ -69,6 +71,7 @@ string SplitAbundCommand::getOutputFileNameTag(string type, string inputName="") if (type == "fasta") { outputFileName = "fasta"; } else if (type == "list") { outputFileName = "list"; } else if (type == "name") { outputFileName = "names"; } + else if (type == "count") { outputFileName = "count_table"; } else if (type == "group") { outputFileName = "groups"; } else if (type == "accnos") { outputFileName = "accnos"; } else { m->mothurOut("[ERROR]: No definition for type " + type + " output file tag.\n"); m->control_pressed = true; } @@ -88,6 +91,7 @@ SplitAbundCommand::SplitAbundCommand(){ vector tempOutNames; outputTypes["list"] = tempOutNames; outputTypes["name"] = tempOutNames; + outputTypes["count"] = tempOutNames; outputTypes["accnos"] = tempOutNames; outputTypes["group"] = tempOutNames; outputTypes["fasta"] = tempOutNames; @@ -126,7 +130,8 @@ SplitAbundCommand::SplitAbundCommand(string option) { outputTypes["name"] = tempOutNames; outputTypes["accnos"] = tempOutNames; outputTypes["group"] = tempOutNames; - outputTypes["fasta"] = tempOutNames; + outputTypes["fasta"] = tempOutNames; + outputTypes["count"] = tempOutNames; //if the user changes the input directory command factory will send this info to us in the output parameter string inputDir = validParameter.validFile(parameters, "inputdir", false); @@ -165,6 +170,13 @@ SplitAbundCommand::SplitAbundCommand(string option) { if (path == "") { parameters["name"] = inputDir + it->second; } } + it = parameters.find("count"); + //user has given a template file + if(it != parameters.end()){ + path = m->hasPath(it->second); + //if the user has not given a path then, add inputdir. else leave path alone. + if (path == "") { parameters["count"] = inputDir + it->second; } + } } @@ -194,35 +206,52 @@ SplitAbundCommand::SplitAbundCommand(string option) { if (groupfile == "not open") { groupfile = ""; abort = true; } else if (groupfile == "not found") { groupfile = ""; } else { - groupMap = new GroupMap(groupfile); - - int error = groupMap->readMap(); + int error = groupMap.readMap(groupfile); if (error == 1) { abort = true; } m->setGroupFile(groupfile); } + countfile = validParameter.validFile(parameters, "count", true); + if (countfile == "not open") { countfile = ""; abort = true; } + else if (countfile == "not found") { countfile = ""; } + else { + m->setCountTableFile(countfile); + ct.readTable(countfile); + } + + if ((namefile != "") && (countfile != "")) { + m->mothurOut("[ERROR]: you may only use one of the following: name or count."); m->mothurOutEndLine(); abort = true; + } + + if ((groupfile != "") && (countfile != "")) { + m->mothurOut("[ERROR]: you may only use one of the following: group or count."); m->mothurOutEndLine(); abort=true; + } + groups = validParameter.validFile(parameters, "groups", false); if (groups == "not found") { groups = ""; } - else if (groups == "all") { - if (groupfile != "") { Groups = groupMap->getNamesOfGroups(); } - else { m->mothurOut("You cannot select groups without a valid groupfile, I will disregard your groups selection. "); m->mothurOutEndLine(); groups = ""; } - }else { - m->splitAtDash(groups, Groups); - } + else { m->splitAtDash(groups, Groups); } - if ((groupfile == "") && (groups != "")) { m->mothurOut("You cannot select groups without a valid groupfile, I will disregard your groups selection. "); m->mothurOutEndLine(); groups = ""; Groups.clear(); } + if (((groupfile == "") && (countfile == ""))&& (groups != "")) { m->mothurOut("You cannot select groups without a valid group or count file, I will disregard your groups selection. "); m->mothurOutEndLine(); groups = ""; Groups.clear(); } + if (countfile != "") { + if (!ct.hasGroupInfo()) { m->mothurOut("You cannot pick groups without group info in your count file; I will disregard your groups selection."); m->mothurOutEndLine(); groups = ""; Groups.clear(); } + } + //do you have all files needed - if ((listfile == "") && (namefile == "")) { + if ((listfile == "") && (namefile == "") && (countfile == "")) { namefile = m->getNameFile(); if (namefile != "") { m->mothurOut("Using " + namefile + " as input file for the name parameter."); m->mothurOutEndLine(); } else { listfile = m->getListFile(); if (listfile != "") { m->mothurOut("Using " + listfile + " as input file for the list parameter."); m->mothurOutEndLine(); } - else { m->mothurOut("You have no current list or namefile and the list or name parameter is required."); m->mothurOutEndLine(); abort = true; } + else { + countfile = m->getCountTableFile(); + if (countfile != "") { m->mothurOut("Using " + countfile + " as input file for the count parameter."); m->mothurOutEndLine(); } + else { m->mothurOut("You have no current list, count or namefile and one is required."); m->mothurOutEndLine(); abort = true; } + } } } - + //check for optional parameter and set defaults // ...at some point should added some additional type checking... label = validParameter.validFile(parameters, "label", false); @@ -248,14 +277,20 @@ SplitAbundCommand::SplitAbundCommand(string option) { } } //********************************************************************************************************************** -SplitAbundCommand::~SplitAbundCommand(){ - if (groupfile != "") { delete groupMap; } -} +SplitAbundCommand::~SplitAbundCommand(){} //********************************************************************************************************************** int SplitAbundCommand::execute(){ try { if (abort == true) { if (calledHelp) { return 0; } return 2; } + + if (Groups.size() != 0) { + vector allGroups; + if (countfile != "") { allGroups = ct.getNamesOfGroups(); } + else { allGroups = groupMap.getNamesOfGroups(); } + SharedUtil util; + util.setGroups(Groups, allGroups); + } if (listfile != "") { //you are using a listfile to determine abundance if (outputDir == "") { outputDir = m->hasPath(listfile); } @@ -264,19 +299,19 @@ int SplitAbundCommand::execute(){ set processedLabels; set userLabels = labels; - input = new InputData(listfile, "list"); - list = input->getListVector(); + InputData input(listfile, "list"); + ListVector* list = input.getListVector(); string lastLabel = list->getLabel(); //do you have a namefile or do we need to similate one? if (namefile != "") { readNamesFile(); } else { createNameMap(list); } - if (m->control_pressed) { delete input; delete list; for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } + if (m->control_pressed) { delete list; for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } while((list != NULL) && ((allLines == 1) || (userLabels.size() != 0))) { - if (m->control_pressed) { delete input; delete list; for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } + if (m->control_pressed) { delete list; for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } if(allLines == 1 || labels.count(list->getLabel()) == 1){ @@ -291,7 +326,7 @@ int SplitAbundCommand::execute(){ string saveLabel = list->getLabel(); delete list; - list = input->getListVector(lastLabel); //get new list vector to process + list = input.getListVector(lastLabel); //get new list vector to process m->mothurOut(list->getLabel()); m->mothurOutEndLine(); splitList(list); @@ -307,10 +342,10 @@ int SplitAbundCommand::execute(){ lastLabel = list->getLabel(); delete list; - list = input->getListVector(); //get new list vector to process + list = input.getListVector(); //get new list vector to process } - if (m->control_pressed) { delete input; for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } + if (m->control_pressed) { for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } //output error messages about any remaining user labels set::iterator it; @@ -326,12 +361,12 @@ int SplitAbundCommand::execute(){ } - if (m->control_pressed) { delete input; for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } + if (m->control_pressed) { for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } //run last label if you need to if (needToRun == true) { if (list != NULL) { delete list; } - list = input->getListVector(lastLabel); //get new list vector to process + list = input.getListVector(lastLabel); //get new list vector to process m->mothurOut(list->getLabel()); m->mothurOutEndLine(); splitList(list); @@ -339,11 +374,9 @@ int SplitAbundCommand::execute(){ delete list; } - delete input; - if (m->control_pressed) { for (int i = 0; i < outputNames.size(); i++) { m->mothurRemove(outputNames[i]); } return 0; } - }else { //you are using the namefile to determine abundance + }else if (namefile != "") { //you are using the namefile to determine abundance if (outputDir == "") { outputDir = m->hasPath(namefile); } splitNames(); @@ -353,7 +386,14 @@ int SplitAbundCommand::execute(){ if (groupfile != "") { parseGroup(tag); } if (accnos) { writeAccnos(tag); } if (fastafile != "") { parseFasta(tag); } - } + }else { + //split by countfile + string tag = ""; + splitCount(); + + if (accnos) { writeAccnos(tag); } + if (fastafile != "") { parseFasta(tag); } + } //set fasta file as new current fastafile string current = ""; @@ -381,6 +421,11 @@ int SplitAbundCommand::execute(){ if (itTypes != outputTypes.end()) { if ((itTypes->second).size() != 0) { current = (itTypes->second)[0]; m->setAccnosFile(current); } } + + itTypes = outputTypes.find("count"); + if (itTypes != outputTypes.end()) { + if ((itTypes->second).size() != 0) { current = (itTypes->second)[0]; m->setCountTableFile(current); } + } m->mothurOutEndLine(); m->mothurOut("Output File Names: "); m->mothurOutEndLine(); @@ -401,6 +446,7 @@ int SplitAbundCommand::splitList(ListVector* thisList) { abundNames.clear(); //get rareNames and abundNames + int numRareBins = 0; for (int i = 0; i < thisList->getNumBins(); i++) { if (m->control_pressed) { return 0; } @@ -409,8 +455,15 @@ int SplitAbundCommand::splitList(ListVector* thisList) { vector names; m->splitAtComma(bin, names); //parses bin into individual sequence names int size = names.size(); - + + //if countfile is not blank we assume the list file is unique, otherwise we assume it includes all seqs + if (countfile != "") { + size = 0; + for (int j = 0; j < names.size(); j++) { size += ct.getNumSeqs(names[j]); } + } + if (size <= cutoff) { + numRareBins++; for (int j = 0; j < names.size(); j++) { rareNames.insert(names[j]); } }else{ for (int j = 0; j < names.size(); j++) { abundNames.insert(names[j]); } @@ -419,13 +472,14 @@ int SplitAbundCommand::splitList(ListVector* thisList) { string tag = thisList->getLabel() + "."; - - writeList(thisList, tag); - + + writeList(thisList, tag, numRareBins); + if (groupfile != "") { parseGroup(tag); } if (accnos) { writeAccnos(tag); } if (fastafile != "") { parseFasta(tag); } - + if (countfile != "") { parseCount(tag); } + return 0; } @@ -435,24 +489,13 @@ int SplitAbundCommand::splitList(ListVector* thisList) { } } /**********************************************************************************************************************/ -int SplitAbundCommand::writeList(ListVector* thisList, string tag) { +int SplitAbundCommand::writeList(ListVector* thisList, string tag, int numRareBins) { try { map filehandles; if (Groups.size() == 0) { - SAbundVector* sabund = new SAbundVector(); - *sabund = thisList->getSAbundVector(); - - //find out how many bins are rare and how many are abundant so you can process the list vector one bin at a time - // and don't have to store the bins until you are done with the whole vector, this save alot of space. - int numRareBins = 0; - for (int i = 0; i <= sabund->getMaxRank(); i++) { - if (i > cutoff) { break; } - numRareBins += sabund->get(i); - } int numAbundBins = thisList->getNumBins() - numRareBins; - delete sabund; ofstream aout; ofstream rout; @@ -471,9 +514,15 @@ int SplitAbundCommand::writeList(ListVector* thisList, string tag) { for (int i = 0; i < thisList->getNumBins(); i++) { if (m->control_pressed) { break; } - string bin = list->get(i); - - int size = m->getNumNames(bin); + string bin = thisList->get(i); + vector names; + m->splitAtComma(bin, names); + + int size = names.size(); + if (countfile != "") { + size = 0; + for (int j = 0; j < names.size(); j++) { size += ct.getNumSeqs(names[j]); } + } if (size <= cutoff) { rout << bin << '\t'; } else { aout << bin << '\t'; } @@ -499,8 +548,8 @@ int SplitAbundCommand::writeList(ListVector* thisList, string tag) { temp2 = new ofstream; filehandles[Groups[i]+".abund"] = temp2; - string rareGroupFileName = fileroot + Groups[i] + tag + ".rare." + getOutputFileNameTag("list"); - string abundGroupFileName = fileroot + Groups[i] + tag + ".abund." + getOutputFileNameTag("list"); + string rareGroupFileName = fileroot + Groups[i] +"."+ tag + "rare." + getOutputFileNameTag("list"); + string abundGroupFileName = fileroot + Groups[i] +"."+ tag + "abund." + getOutputFileNameTag("list"); m->openOutputFile(rareGroupFileName, *(filehandles[Groups[i]+".rare"])); m->openOutputFile(abundGroupFileName, *(filehandles[Groups[i]+".abund"])); outputNames.push_back(rareGroupFileName); outputTypes["list"].push_back(rareGroupFileName); @@ -520,7 +569,7 @@ int SplitAbundCommand::writeList(ListVector* thisList, string tag) { if (m->control_pressed) { break; } map groupBins; - string bin = list->get(i); + string bin = thisList->get(i); vector names; m->splitAtComma(bin, names); //parses bin into individual sequence names @@ -534,19 +583,34 @@ int SplitAbundCommand::writeList(ListVector* thisList, string tag) { rareAbund = ".abund"; } - string group = groupMap->getGroup(names[j]); - - if (m->inUsersGroups(group, Groups)) { //only add if this is in a group we want - itGroup = groupBins.find(group+rareAbund); - if(itGroup == groupBins.end()) { - groupBins[group+rareAbund] = names[j]; //add first name - groupNumBins[group+rareAbund]++; - }else{ //add another name - groupBins[group+rareAbund] += "," + names[j]; - } - }else if(group == "not found") { - m->mothurOut(names[j] + " is not in your groupfile. Ignoring."); m->mothurOutEndLine(); - } + if (countfile == "") { + string group = groupMap.getGroup(names[j]); + + if (m->inUsersGroups(group, Groups)) { //only add if this is in a group we want + itGroup = groupBins.find(group+rareAbund); + if(itGroup == groupBins.end()) { + groupBins[group+rareAbund] = names[j]; //add first name + groupNumBins[group+rareAbund]++; + }else{ //add another name + groupBins[group+rareAbund] += "," + names[j]; + } + }else if(group == "not found") { + m->mothurOut(names[j] + " is not in your groupfile. Ignoring."); m->mothurOutEndLine(); + } + }else { + vector thisSeqsGroups = ct.getGroups(names[j]); + for (int k = 0; k < thisSeqsGroups.size(); k++) { + if (m->inUsersGroups(thisSeqsGroups[k], Groups)) { //only add if this is in a group we want + itGroup = groupBins.find(thisSeqsGroups[k]+rareAbund); + if(itGroup == groupBins.end()) { + groupBins[thisSeqsGroups[k]+rareAbund] = names[j]; //add first name + groupNumBins[thisSeqsGroups[k]+rareAbund]++; + }else{ //add another name + groupBins[thisSeqsGroups[k]+rareAbund] += "," + names[j]; + } + } + } + } } @@ -572,6 +636,37 @@ int SplitAbundCommand::writeList(ListVector* thisList, string tag) { } } /**********************************************************************************************************************/ +int SplitAbundCommand::splitCount() { //countfile + try { + rareNames.clear(); + abundNames.clear(); + + vector allNames = ct.getNamesOfSeqs(); + for (int i = 0; i < allNames.size(); i++) { + + if (m->control_pressed) { return 0; } + + int size = ct.getNumSeqs(allNames[i]); + nameMap[allNames[i]] = allNames[i]; + + if (size <= cutoff) { + rareNames.insert(allNames[i]); + }else{ + abundNames.insert(allNames[i]); + } + } + + //write out split count files + parseCount(""); + + return 0; + } + catch(exception& e) { + m->errorOut(e, "SplitAbundCommand", "splitCount"); + exit(1); + } +} +/**********************************************************************************************************************/ int SplitAbundCommand::splitNames() { //namefile try { @@ -658,6 +753,115 @@ int SplitAbundCommand::createNameMap(ListVector* thisList) { } } /**********************************************************************************************************************/ +int SplitAbundCommand::parseCount(string tag) { //namefile + try { + + map filehandles; + + if (Groups.size() == 0) { + string rare = outputDir + m->getRootName(m->getSimpleName(countfile)) + tag + "rare." + getOutputFileNameTag("count"); + outputNames.push_back(rare); outputTypes["count"].push_back(rare); + + string abund = outputDir + m->getRootName(m->getSimpleName(countfile)) + tag + "abund." + getOutputFileNameTag("count"); + outputNames.push_back(abund); outputTypes["count"].push_back(abund); + + CountTable rareTable; + CountTable abundTable; + if (ct.hasGroupInfo()) { + vector ctGroups = ct.getNamesOfGroups(); + for (int i = 0; i < ctGroups.size(); i++) { rareTable.addGroup(ctGroups[i]); abundTable.addGroup(ctGroups[i]); } + } + + if (rareNames.size() != 0) { + for (set::iterator itRare = rareNames.begin(); itRare != rareNames.end(); itRare++) { + if (ct.hasGroupInfo()) { + vector groupCounts = ct.getGroupCounts(*itRare); + rareTable.push_back(*itRare, groupCounts); + }else { + int groupCounts = ct.getNumSeqs(*itRare); + rareTable.push_back(*itRare, groupCounts); + } + } + if (rareTable.hasGroupInfo()) { + vector ctGroups = rareTable.getNamesOfGroups(); + for (int i = 0; i < ctGroups.size(); i++) { + if (rareTable.getGroupCount(ctGroups[i]) == 0) { rareTable.removeGroup(ctGroups[i]); } + } + } + rareTable.printTable(rare); + } + + + if (abundNames.size() != 0) { + for (set::iterator itAbund = abundNames.begin(); itAbund != abundNames.end(); itAbund++) { + if (ct.hasGroupInfo()) { + vector groupCounts = ct.getGroupCounts(*itAbund); + abundTable.push_back(*itAbund, groupCounts); + }else { + int groupCounts = ct.getNumSeqs(*itAbund); + abundTable.push_back(*itAbund, groupCounts); + } + } + if (abundTable.hasGroupInfo()) { + vector ctGroups = abundTable.getNamesOfGroups(); + for (int i = 0; i < ctGroups.size(); i++) { + if (abundTable.getGroupCount(ctGroups[i]) == 0) { abundTable.removeGroup(ctGroups[i]); } + } + } + abundTable.printTable(abund); + } + + }else{ //parse names by abundance and group + map countTableMap; + map::iterator it3; + + for (int i=0; iaddGroup(Groups[i]); + countTableMap[Groups[i]+".rare"] = rareCt; + CountTable* abundCt = new CountTable(); + abundCt->addGroup(Groups[i]); + countTableMap[Groups[i]+".abund"] = abundCt; + } + + vector allNames = ct.getNamesOfSeqs(); + for (int i = 0; i < allNames.size(); i++) { + string rareAbund; + if (rareNames.count(allNames[i]) != 0) { //you are a rare name + rareAbund = ".rare"; + }else{ //you are a abund name + rareAbund = ".abund"; + } + + vector thisSeqsGroups = ct.getGroups(allNames[i]); + for (int j = 0; j < thisSeqsGroups.size(); j++) { + if (m->inUsersGroups(thisSeqsGroups[j], Groups)) { //only add if this is in a group we want + int num = ct.getGroupCount(allNames[i], thisSeqsGroups[j]); + vector nums; nums.push_back(num); + countTableMap[thisSeqsGroups[j]+rareAbund]->push_back(allNames[i], nums); + } + } + } + + + for (it3 = countTableMap.begin(); it3 != countTableMap.end(); it3++) { + string fileroot = outputDir + m->getRootName(m->getSimpleName(countfile)); + string filename = fileroot + it3->first + "." + getOutputFileNameTag("count"); + outputNames.push_back(filename); outputTypes["count"].push_back(filename); + (it3->second)->printTable(filename); + delete it3->second; + } + } + + return 0; + + } + catch(exception& e) { + m->errorOut(e, "SplitAbundCommand", "parseCount"); + exit(1); + } +} +/**********************************************************************************************************************/ int SplitAbundCommand::writeNames() { //namefile try { @@ -723,7 +927,7 @@ int SplitAbundCommand::writeNames() { //namefile map::iterator itout; for (int i = 0; i < names.size(); i++) { - string group = groupMap->getGroup(names[i]); + string group = groupMap.getGroup(names[i]); if (m->inUsersGroups(group, Groups)) { //only add if this is in a group we want itout = outputStrings.find(group+rareAbund); @@ -803,7 +1007,7 @@ int SplitAbundCommand::writeAccnos(string tag) { //write rare for (set::iterator itRare = rareNames.begin(); itRare != rareNames.end(); itRare++) { - string group = groupMap->getGroup(*itRare); + string group = groupMap.getGroup(*itRare); if (m->inUsersGroups(group, Groups)) { //only add if this is in a group we want *(filehandles[group+".rare"]) << *itRare << endl; @@ -812,7 +1016,7 @@ int SplitAbundCommand::writeAccnos(string tag) { //write abund for (set::iterator itAbund = abundNames.begin(); itAbund != abundNames.end(); itAbund++) { - string group = groupMap->getGroup(*itAbund); + string group = groupMap.getGroup(*itAbund); if (m->inUsersGroups(group, Groups)) { //only add if this is in a group we want *(filehandles[group+".abund"]) << *itAbund << endl; @@ -860,7 +1064,7 @@ int SplitAbundCommand::parseGroup(string tag) { //namefile for (int i = 0; i < names.size(); i++) { - string group = groupMap->getGroup(names[i]); + string group = groupMap.getGroup(names[i]); if (group == "not found") { m->mothurOut(names[i] + " is not in your groupfile, ignoring, please correct."); m->mothurOutEndLine(); @@ -907,7 +1111,7 @@ int SplitAbundCommand::parseGroup(string tag) { //namefile for (int i = 0; i < names.size(); i++) { - string group = groupMap->getGroup(names[i]); + string group = groupMap.getGroup(names[i]); if (m->inUsersGroups(group, Groups)) { //only add if this is in a group we want *(filehandles[group+rareAbund]) << names[i] << '\t' << group << endl; @@ -964,7 +1168,7 @@ int SplitAbundCommand::parseFasta(string tag) { //namefile itNames = nameMap.find(seq.getName()); if (itNames == nameMap.end()) { - m->mothurOut(seq.getName() + " is not in your namesfile, ignoring."); m->mothurOutEndLine(); + m->mothurOut(seq.getName() + " is not in your names or list file, ignoring."); m->mothurOutEndLine(); }else{ if (rareNames.count(seq.getName()) != 0) { //you are a rare name seq.printSequence(rout); @@ -1008,7 +1212,7 @@ int SplitAbundCommand::parseFasta(string tag) { //namefile map::iterator itNames = nameMap.find(seq.getName()); if (itNames == nameMap.end()) { - m->mothurOut(seq.getName() + " is not in your namesfile, ignoring."); m->mothurOutEndLine(); + m->mothurOut(seq.getName() + " is not in your names or list file, ignoring."); m->mothurOutEndLine(); }else{ vector names; m->splitAtComma(itNames->second, names); //parses bin into individual sequence names @@ -1019,17 +1223,25 @@ int SplitAbundCommand::parseFasta(string tag) { //namefile }else{ //you are a abund name rareAbund = ".abund"; } - - for (int i = 0; i < names.size(); i++) { - - string group = groupMap->getGroup(seq.getName()); - - if (m->inUsersGroups(group, Groups)) { //only add if this is in a group we want - seq.printSequence(*(filehandles[group+rareAbund])); - }else if(group == "not found") { - m->mothurOut(seq.getName() + " is not in your groupfile. Ignoring."); m->mothurOutEndLine(); - } - } + + if (countfile == "") { + for (int i = 0; i < names.size(); i++) { + string group = groupMap.getGroup(seq.getName()); + + if (m->inUsersGroups(group, Groups)) { //only add if this is in a group we want + seq.printSequence(*(filehandles[group+rareAbund])); + }else if(group == "not found") { + m->mothurOut(seq.getName() + " is not in your groupfile. Ignoring."); m->mothurOutEndLine(); + } + } + }else { + vector thisSeqsGroups = ct.getGroups(names[0]); //we only need names[0], because there is no namefile + for (int i = 0; i < thisSeqsGroups.size(); i++) { + if (m->inUsersGroups(thisSeqsGroups[i], Groups)) { //only add if this is in a group we want + seq.printSequence(*(filehandles[thisSeqsGroups[i]+rareAbund])); + } + } + } } } } diff --git a/splitabundcommand.h b/splitabundcommand.h index 232c36b..d054264 100644 --- a/splitabundcommand.h +++ b/splitabundcommand.h @@ -22,6 +22,7 @@ also allow an option where a user can give a group file with the list or names f #include "inputdata.h" #include "listvector.hpp" #include "sequence.hpp" +#include "counttable.h" /***************************************************************************************/ @@ -47,24 +48,24 @@ private: int splitList(ListVector*); int splitNames(); //namefile int writeNames(); - int writeList(ListVector*, string); + int writeList(ListVector*, string, int); int writeAccnos(string); int parseGroup(string); int parseFasta(string); + int parseCount(string); + int splitCount(); int readNamesFile(); //namefile int createNameMap(ListVector*); vector outputNames; - ListVector* list; - GroupMap* groupMap; - InputData* input; + GroupMap groupMap; + CountTable ct; - string outputDir, listfile, namefile, groupfile, label, groups, fastafile, inputFile; + string outputDir, listfile, namefile, groupfile, countfile, label, groups, fastafile, inputFile; set labels, rareNames, abundNames; vector Groups; bool abort, allLines, accnos; int cutoff; - //map wroteListFile; map nameMap; diff --git a/subsample.cpp b/subsample.cpp index c55accd..392f97b 100644 --- a/subsample.cpp +++ b/subsample.cpp @@ -264,7 +264,164 @@ int SubSample::getSample(SAbundVector*& sabund, int size) { m->errorOut(e, "SubSampleCommand", "getSample"); exit(1); } -} +} +//********************************************************************************************************************** +CountTable SubSample::getSample(CountTable& ct, int size, vector Groups) { + try { + if (!ct.hasGroupInfo()) { m->mothurOut("[ERROR]: Cannot subsample by group because your count table doesn't have group information.\n"); m->control_pressed = true; } + + CountTable sampledCt; + map > tempCount; + for (int i = 0; i < Groups.size(); i++) { + sampledCt.addGroup(Groups[i]); + + vector names = ct.getNamesOfSeqs(Groups[i]); + vector allNames; + for (int j = 0; j < names.size(); j++) { + + if (m->control_pressed) { return sampledCt; } + + int num = ct. getGroupCount(names[j], Groups[i]); + for (int k = 0; k < num; k++) { allNames.push_back(names[j]); } + } + + random_shuffle(allNames.begin(), allNames.end()); + + if (allNames.size() < size) { m->mothurOut("[ERROR]: You have selected a size that is larger than "+Groups[i]+" number of sequences.\n"); m->control_pressed = true; } + else{ + for (int j = 0; j < size; j++) { + + if (m->control_pressed) { return sampledCt; } + + map >::iterator it = tempCount.find(allNames[j]); + + if (it == tempCount.end()) { //we have not seen this sequence at all yet + vector tempGroups; tempGroups.resize(Groups.size(), 0); + tempGroups[i]++; + tempCount[allNames[j]] = tempGroups; + }else{ + tempCount[allNames[j]][i]++; + } + } + } + } + + //build count table + for (map >::iterator it = tempCount.begin(); it != tempCount.end();) { + sampledCt.push_back(it->first, it->second); + tempCount.erase(it++); + } + + return sampledCt; + } + catch(exception& e) { + m->errorOut(e, "SubSampleCommand", "getSample"); + exit(1); + } +} +//********************************************************************************************************************** +CountTable SubSample::getSample(CountTable& ct, int size, vector Groups, bool pickedGroups) { + try { + CountTable sampledCt; + if (!ct.hasGroupInfo() && pickedGroups) { m->mothurOut("[ERROR]: Cannot subsample with groups because your count table doesn't have group information.\n"); m->control_pressed = true; return sampledCt; } + + if (ct.hasGroupInfo()) { + map > tempCount; + vector allNames; + map groupMap; + + vector myGroups; + if (pickedGroups) { myGroups = Groups; } + else { myGroups = ct.getNamesOfGroups(); } + + for (int i = 0; i < myGroups.size(); i++) { + sampledCt.addGroup(myGroups[i]); + groupMap[myGroups[i]] = i; + + vector names = ct.getNamesOfSeqs(myGroups[i]); + for (int j = 0; j < names.size(); j++) { + + if (m->control_pressed) { return sampledCt; } + + int num = ct. getGroupCount(names[j], myGroups[i]); + for (int k = 0; k < num; k++) { + item temp(names[j], myGroups[i]); + allNames.push_back(temp); + } + } + } + + random_shuffle(allNames.begin(), allNames.end()); + + if (allNames.size() < size) { + if (pickedGroups) { m->mothurOut("[ERROR]: You have selected a size that is larger than the number of sequences.\n"); } + else { m->mothurOut("[ERROR]: You have selected a size that is larger than the number of sequences in the groups you chose.\n"); } + m->control_pressed = true; return sampledCt; } + else{ + for (int j = 0; j < size; j++) { + + if (m->control_pressed) { return sampledCt; } + + map >::iterator it = tempCount.find(allNames[j].name); + + if (it == tempCount.end()) { //we have not seen this sequence at all yet + vector tempGroups; tempGroups.resize(myGroups.size(), 0); + tempGroups[groupMap[allNames[j].group]]++; + tempCount[allNames[j].name] = tempGroups; + }else{ + tempCount[allNames[j].name][groupMap[allNames[j].group]]++; + } + } + } + + //build count table + for (map >::iterator it = tempCount.begin(); it != tempCount.end();) { + sampledCt.push_back(it->first, it->second); + tempCount.erase(it++); + } + + //remove empty groups + for (int i = 0; i < myGroups.size(); i++) { if (sampledCt.getGroupCount(myGroups[i]) == 0) { sampledCt.removeGroup(myGroups[i]); } } + + }else { + vector names = ct.getNamesOfSeqs(); + map nameMap; + vector allNames; + + for (int i = 0; i < names.size(); i++) { + int num = ct.getNumSeqs(names[i]); + for (int j = 0; j < num; j++) { allNames.push_back(names[i]); } + } + + if (allNames.size() < size) { m->mothurOut("[ERROR]: You have selected a size that is larger than the number of sequences.\n"); m->control_pressed = true; return sampledCt; } + else { + random_shuffle(allNames.begin(), allNames.end()); + + for (int j = 0; j < size; j++) { + if (m->control_pressed) { return sampledCt; } + + map::iterator it = nameMap.find(allNames[j]); + + //we have not seen this sequence at all yet + if (it == nameMap.end()) { nameMap[allNames[j]] = 1; } + else{ nameMap[allNames[j]]++; } + } + + //build count table + for (map::iterator it = nameMap.begin(); it != nameMap.end();) { + sampledCt.push_back(it->first, it->second); + nameMap.erase(it++); + } + } + } + + return sampledCt; + } + catch(exception& e) { + m->errorOut(e, "SubSampleCommand", "getSample"); + exit(1); + } +} //********************************************************************************************************************** diff --git a/subsample.h b/subsample.h index 55abed9..fdf8576 100644 --- a/subsample.h +++ b/subsample.h @@ -15,6 +15,14 @@ #include "tree.h" #include "counttable.h" +struct item { + string name; + string group; + + item() {} + item(string n, string g) : name(n), group(g) {} + ~item() {} +}; //subsampling overwrites the sharedRabunds. If you need to reuse the original use the getSamplePreserve function. @@ -28,6 +36,8 @@ class SubSample { vector getSample(vector&, int); //returns the bin labels for the subsample, mothurOuts binlabels are preserved so you can run this multiple times. Overwrites original vector passed in, if you need to preserve it deep copy first. Tree* getSample(Tree*, CountTable*, CountTable*, int); //creates new subsampled tree. Uses first counttable to fill new counttable with sabsampled seqs. Sets groups of seqs not in subsample to "doNotIncludeMe". int getSample(SAbundVector*&, int); //destroys sabundvector passed in, so copy it if you need it + CountTable getSample(CountTable&, int, vector); //subsample a countTable bygroup(same number sampled from each group, returns subsampled countTable + CountTable getSample(CountTable&, int, vector, bool); //subsample a countTable. If you want to only sample from specific groups, pass in groups in the vector and set bool=true, otherwise set bool=false. private: diff --git a/subsamplecommand.cpp b/subsamplecommand.cpp index 8c5761d..e1793f4 100644 --- a/subsamplecommand.cpp +++ b/subsamplecommand.cpp @@ -16,8 +16,9 @@ vector SubSampleCommand::setParameters(){ try { CommandParameter pfasta("fasta", "InputTypes", "", "", "none", "FLSSR", "none",false,false); parameters.push_back(pfasta); - CommandParameter pname("name", "InputTypes", "", "", "none", "none", "none",false,false); parameters.push_back(pname); - CommandParameter pgroup("group", "InputTypes", "", "", "none", "none", "none",false,false); parameters.push_back(pgroup); + CommandParameter pname("name", "InputTypes", "", "", "NameCount", "none", "none",false,false); parameters.push_back(pname); + CommandParameter pcount("count", "InputTypes", "", "", "NameCount-CountGroup", "none", "none",false,false); parameters.push_back(pcount); + CommandParameter pgroup("group", "InputTypes", "", "", "CountGroup", "none", "none",false,false); parameters.push_back(pgroup); CommandParameter plist("list", "InputTypes", "", "", "none", "FLSSR", "none",false,false); parameters.push_back(plist); CommandParameter pshared("shared", "InputTypes", "", "", "none", "FLSSR", "none",false,false); parameters.push_back(pshared); CommandParameter prabund("rabund", "InputTypes", "", "", "none", "FLSSR", "none",false,false); parameters.push_back(prabund); @@ -43,7 +44,7 @@ string SubSampleCommand::getHelpString(){ try { string helpString = ""; helpString += "The sub.sample command is designed to be used as a way to normalize your data, or create a smaller set from your original set.\n"; - helpString += "The sub.sample command parameters are fasta, name, list, group, rabund, sabund, shared, groups, size, persample and label. You must provide a fasta, list, sabund, rabund or shared file as an input file.\n"; + helpString += "The sub.sample command parameters are fasta, name, list, group, count, rabund, sabund, shared, groups, size, persample and label. You must provide a fasta, list, sabund, rabund or shared file as an input file.\n"; helpString += "The namefile is only used with the fasta file, not with the listfile, because the list file should contain all sequences.\n"; helpString += "The groups parameter allows you to specify which of the groups in your groupfile you would like included. The group names are separated by dashes.\n"; helpString += "The label parameter allows you to select what distance levels you would like, and are also separated by dashes.\n"; @@ -76,6 +77,7 @@ string SubSampleCommand::getOutputFileNameTag(string type, string inputName=""){ if (type == "fasta") { outputFileName = "subsample" + m->getExtension(inputName); } else if (type == "sabund") { outputFileName = "subsample" + m->getExtension(inputName); } else if (type == "name") { outputFileName = "subsample" + m->getExtension(inputName); } + else if (type == "count") { outputFileName = "subsample" + m->getExtension(inputName); } else if (type == "group") { outputFileName = "subsample" + m->getExtension(inputName); } else if (type == "list") { outputFileName = "subsample" + m->getExtension(inputName); } else if (type == "rabund") { outputFileName = "subsample" + m->getExtension(inputName); } @@ -103,6 +105,7 @@ SubSampleCommand::SubSampleCommand(){ outputTypes["fasta"] = tempOutNames; outputTypes["name"] = tempOutNames; outputTypes["group"] = tempOutNames; + outputTypes["count"] = tempOutNames; } catch(exception& e) { m->errorOut(e, "SubSampleCommand", "GetRelAbundCommand"); @@ -142,6 +145,7 @@ SubSampleCommand::SubSampleCommand(string option) { outputTypes["fasta"] = tempOutNames; outputTypes["name"] = tempOutNames; outputTypes["group"] = tempOutNames; + outputTypes["count"] = tempOutNames; //if the user changes the output directory command factory will send this info to us in the output parameter outputDir = validParameter.validFile(parameters, "outputdir", false); if (outputDir == "not found"){ outputDir = ""; } @@ -206,6 +210,14 @@ SubSampleCommand::SubSampleCommand(string option) { //if the user has not given a path then, add inputdir. else leave path alone. if (path == "") { parameters["name"] = inputDir + it->second; } } + + it = parameters.find("count"); + //user has given a template file + if(it != parameters.end()){ + path = m->hasPath(it->second); + //if the user has not given a path then, add inputdir. else leave path alone. + if (path == "") { parameters["count"] = inputDir + it->second; } + } } //check for required parameters @@ -244,6 +256,22 @@ SubSampleCommand::SubSampleCommand(string option) { else if (groupfile == "not found") { groupfile = ""; } else { m->setGroupFile(groupfile); } + countfile = validParameter.validFile(parameters, "count", true); + if (countfile == "not open") { countfile = ""; abort = true; } + else if (countfile == "not found") { countfile = ""; } + else { + m->setCountTableFile(countfile); + ct.readTable(countfile); + } + + if ((namefile != "") && (countfile != "")) { + m->mothurOut("[ERROR]: you may only use one of the following: name or count."); m->mothurOutEndLine(); abort = true; + } + + if ((groupfile != "") && (countfile != "")) { + m->mothurOut("[ERROR]: you may only use one of the following: group or count."); m->mothurOutEndLine(); abort=true; + } + //check for optional parameter and set defaults // ...at some point should added some additional type checking... label = validParameter.validFile(parameters, "label", false); @@ -267,26 +295,34 @@ SubSampleCommand::SubSampleCommand(string option) { temp = validParameter.validFile(parameters, "persample", false); if (temp == "not found"){ temp = "f"; } persample = m->isTrue(temp); - if (groupfile == "") { persample = false; } + if ((groupfile == "") && (countfile == "")) { persample = false; } + if (countfile != "") { + if (!ct.hasGroupInfo()) { + persample = false; + if (pickedGroups) { m->mothurOut("You cannot pick groups without group info in your count file."); m->mothurOutEndLine(); abort = true; } + } + } if ((namefile != "") && (fastafile == "")) { m->mothurOut("You may only use a namefile with a fastafile."); m->mothurOutEndLine(); abort = true; } if ((fastafile == "") && (listfile == "") && (sabundfile == "") && (rabundfile == "") && (sharedfile == "")) { m->mothurOut("You must provide a fasta, list, sabund, rabund or shared file as an input file."); m->mothurOutEndLine(); abort = true; } - if (pickedGroups && ((groupfile == "") && (sharedfile == ""))) { - m->mothurOut("You cannot pick groups without a valid group file or shared file."); m->mothurOutEndLine(); abort = true; } + if (pickedGroups && ((groupfile == "") && (sharedfile == "") && (countfile == ""))) { + m->mothurOut("You cannot pick groups without a valid group, count or shared file."); m->mothurOutEndLine(); abort = true; } - if ((groupfile != "") && ((fastafile == "") && (listfile == ""))) { - m->mothurOut("Group file only valid with listfile or fastafile."); m->mothurOutEndLine(); abort = true; } + if (((groupfile != "") || (countfile != "")) && ((fastafile == "") && (listfile == ""))) { + m->mothurOut("Group or count files are only valid with listfile or fastafile."); m->mothurOutEndLine(); abort = true; } - if ((groupfile != "") && ((fastafile != "") && (listfile != ""))) { - m->mothurOut("A new group file can only be made from the subsample of a listfile or fastafile, not both. Please correct."); m->mothurOutEndLine(); abort = true; } + if (((groupfile != "") || (countfile != "")) && ((fastafile != "") && (listfile != ""))) { + m->mothurOut("A new group or count file can only be made from the subsample of a listfile or fastafile, not both. Please correct."); m->mothurOutEndLine(); abort = true; } - if ((fastafile != "") && (namefile == "")) { - vector files; files.push_back(fastafile); - parser.getNameFile(files); - } + if (countfile == "") { + if ((fastafile != "") && (namefile == "")) { + vector files; files.push_back(fastafile); + parser.getNameFile(files); + } + } } } @@ -353,6 +389,11 @@ int SubSampleCommand::execute(){ if (itTypes != outputTypes.end()) { if ((itTypes->second).size() != 0) { current = (itTypes->second)[0]; m->setSabundFile(current); } } + + itTypes = outputTypes.find("count"); + if (itTypes != outputTypes.end()) { + if ((itTypes->second).size() != 0) { current = (itTypes->second)[0]; m->setCountTableFile(current); } + } m->mothurOutEndLine(); @@ -374,49 +415,67 @@ int SubSampleCommand::getSubSampleFasta() { if (namefile != "") { readNames(); } //fills names with all names in namefile. else { getNames(); }//no name file, so get list of names to pick from - GroupMap* groupMap; + GroupMap groupMap; if (groupfile != "") { - - groupMap = new GroupMap(groupfile); - groupMap->readMap(); + groupMap.readMap(groupfile); //takes care of user setting groupNames that are invalid or setting groups=all - SharedUtil* util = new SharedUtil(); - vector namesGroups = groupMap->getNamesOfGroups(); - util->setGroups(Groups, namesGroups); - delete util; + SharedUtil util; + vector namesGroups = groupMap.getNamesOfGroups(); + util.setGroups(Groups, namesGroups); //file mismatch quit - if (names.size() != groupMap->getNumSeqs()) { - m->mothurOut("[ERROR]: your fasta file contains " + toString(names.size()) + " sequences, and your groupfile contains " + toString(groupMap->getNumSeqs()) + ", please correct."); + if (names.size() != groupMap.getNumSeqs()) { + m->mothurOut("[ERROR]: your fasta file contains " + toString(names.size()) + " sequences, and your groupfile contains " + toString(groupMap.getNumSeqs()) + ", please correct."); m->mothurOutEndLine(); - delete groupMap; return 0; } - } + }else if (countfile != "") { + if (ct.hasGroupInfo()) { + SharedUtil util; + vector namesGroups = ct.getNamesOfGroups(); + util.setGroups(Groups, namesGroups); + } + + //file mismatch quit + if (names.size() != ct.getNumUniqueSeqs()) { + m->mothurOut("[ERROR]: your fasta file contains " + toString(names.size()) + " sequences, and your count file contains " + toString(ct.getNumUniqueSeqs()) + " unique sequences, please correct."); + m->mothurOutEndLine(); + return 0; + } + } if (m->control_pressed) { return 0; } - //make sure that if your picked groups size is not too big - int thisSize = names.size(); + int thisSize = 0; + if (countfile == "") { thisSize = names.size(); } + else { thisSize = ct. getNumSeqs(); } //all seqs not just unique + if (persample) { if (size == 0) { //user has not set size, set size = smallest samples size - size = groupMap->getNumSeqs(Groups[0]); + if (countfile == "") { size = groupMap.getNumSeqs(Groups[0]); } + else { size = ct.getGroupCount(Groups[0]); } + for (int i = 1; i < Groups.size(); i++) { - int thisSize = groupMap->getNumSeqs(Groups[i]); + int thisSize = 0; + if (countfile == "") { thisSize = groupMap.getNumSeqs(Groups[i]); } + else { thisSize = ct.getGroupCount(Groups[i]); } if (thisSize < size) { size = thisSize; } } }else { //make sure size is not too large vector newGroups; for (int i = 0; i < Groups.size(); i++) { - int thisSize = groupMap->getNumSeqs(Groups[i]); + int thisSize = 0; + if (countfile == "") { thisSize = groupMap.getNumSeqs(Groups[i]); } + else { thisSize = ct.getGroupCount(Groups[i]); } if (thisSize >= size) { newGroups.push_back(Groups[i]); } else { m->mothurOut("You have selected a size that is larger than " + Groups[i] + " number of sequences, removing " + Groups[i] + "."); m->mothurOutEndLine(); } } Groups = newGroups; + if (newGroups.size() == 0) { m->mothurOut("[ERROR]: all groups removed."); m->mothurOutEndLine(); m->control_pressed = true; } } m->mothurOut("Sampling " + toString(size) + " from each group."); m->mothurOutEndLine(); @@ -424,7 +483,8 @@ int SubSampleCommand::getSubSampleFasta() { if (pickedGroups) { int total = 0; for(int i = 0; i < Groups.size(); i++) { - total += groupMap->getNumSeqs(Groups[i]); + if (countfile == "") { total += groupMap.getNumSeqs(Groups[i]); } + else { total += ct.getGroupCount(Groups[i]); } } if (size == 0) { //user has not set size, set size = 10% samples size @@ -442,64 +502,87 @@ int SubSampleCommand::getSubSampleFasta() { } if (size == 0) { //user has not set size, set size = 10% samples size - size = int (names.size() * 0.10); - } - - if (size > thisSize) { m->mothurOut("Your fasta file only contains " + toString(thisSize) + " sequences. Setting size to " + toString(thisSize) + "."); m->mothurOutEndLine(); - size = thisSize; + if (countfile == "") { size = int (names.size() * 0.10); } + else { size = int (ct.getNumSeqs() * 0.10); } } - if (!pickedGroups) { m->mothurOut("Sampling " + toString(size) + " from " + toString(thisSize) + "."); m->mothurOutEndLine(); } + + if (size > thisSize) { m->mothurOut("Your fasta file only contains " + toString(thisSize) + " sequences. Setting size to " + toString(thisSize) + "."); m->mothurOutEndLine(); + size = thisSize; + } + + if (!pickedGroups) { m->mothurOut("Sampling " + toString(size) + " from " + toString(thisSize) + "."); m->mothurOutEndLine(); } } random_shuffle(names.begin(), names.end()); set subset; //dont want repeat sequence names added if (persample) { - //initialize counts - map groupCounts; - map::iterator itGroupCounts; - for (int i = 0; i < Groups.size(); i++) { groupCounts[Groups[i]] = 0; } + if (countfile == "") { + //initialize counts + map groupCounts; + map::iterator itGroupCounts; + for (int i = 0; i < Groups.size(); i++) { groupCounts[Groups[i]] = 0; } - for (int j = 0; j < names.size(); j++) { + for (int j = 0; j < names.size(); j++) { - if (m->control_pressed) { return 0; } + if (m->control_pressed) { return 0; } - string group = groupMap->getGroup(names[j]); - if (group == "not found") { m->mothurOut("[ERROR]: " + names[j] + " is not in your groupfile. please correct."); m->mothurOutEndLine(); group = "NOTFOUND"; } - else{ - itGroupCounts = groupCounts.find(group); - if (itGroupCounts != groupCounts.end()) { - if (groupCounts[group] < size) { subset.insert(names[j]); groupCounts[group]++; } - } - } - } + string group = groupMap.getGroup(names[j]); + if (group == "not found") { m->mothurOut("[ERROR]: " + names[j] + " is not in your groupfile. please correct."); m->mothurOutEndLine(); group = "NOTFOUND"; } + else{ + itGroupCounts = groupCounts.find(group); + if (itGroupCounts != groupCounts.end()) { + if (groupCounts[group] < size) { subset.insert(names[j]); groupCounts[group]++; } + } + } + } + }else { + SubSample sample; + CountTable sampledCt = sample.getSample(ct, size, Groups); + vector sampledSeqs = sampledCt.getNamesOfSeqs(); + for (int i = 0; i < sampledSeqs.size(); i++) { subset.insert(sampledSeqs[i]); } + + string countOutputDir = outputDir; + if (outputDir == "") { countOutputDir += m->hasPath(countfile); } + string countOutputFileName = countOutputDir + m->getRootName(m->getSimpleName(countfile)) + getOutputFileNameTag("count", countfile); + outputTypes["count"].push_back(countOutputFileName); outputNames.push_back(countOutputFileName); + sampledCt.printTable(countOutputFileName); + } }else { - - //randomly select a subset of those names to include in the subsample - //since names was randomly shuffled just grab the next one - for (int j = 0; j < names.size(); j++) { - - if (m->control_pressed) { return 0; } - - if (groupfile != "") { //if there is a groupfile given fill in group info - string group = groupMap->getGroup(names[j]); - if (group == "not found") { m->mothurOut("[ERROR]: " + names[j] + " is not in your groupfile. please correct."); m->mothurOutEndLine(); group = "NOTFOUND"; } - - if (pickedGroups) { //if hte user picked groups, we only want to keep the names of sequences from those groups - if (m->inUsersGroups(group, Groups)) { - subset.insert(names[j]); - } - }else{ - subset.insert(names[j]); - } - }else{ //save everyone, group - subset.insert(names[j]); - } - - //do we have enough?? - if (subset.size() == size) { break; } - } + if (countfile == "") { + //randomly select a subset of those names to include in the subsample + //since names was randomly shuffled just grab the next one + for (int j = 0; j < names.size(); j++) { + + if (m->control_pressed) { return 0; } + + if (groupfile != "") { //if there is a groupfile given fill in group info + string group = groupMap.getGroup(names[j]); + if (group == "not found") { m->mothurOut("[ERROR]: " + names[j] + " is not in your groupfile. please correct."); m->mothurOutEndLine(); group = "NOTFOUND"; } + + if (pickedGroups) { //if hte user picked groups, we only want to keep the names of sequences from those groups + if (m->inUsersGroups(group, Groups)) { subset.insert(names[j]); } + }else{ subset.insert(names[j]); } + }else{ //save everyone, group + subset.insert(names[j]); + } + + //do we have enough?? + if (subset.size() == size) { break; } + } + }else { + SubSample sample; + CountTable sampledCt = sample.getSample(ct, size, Groups, pickedGroups); + vector sampledSeqs = sampledCt.getNamesOfSeqs(); + for (int i = 0; i < sampledSeqs.size(); i++) { subset.insert(sampledSeqs[i]); } + + string countOutputDir = outputDir; + if (outputDir == "") { countOutputDir += m->hasPath(countfile); } + string countOutputFileName = countOutputDir + m->getRootName(m->getSimpleName(countfile)) + getOutputFileNameTag("count", countfile); + outputTypes["count"].push_back(countOutputFileName); outputNames.push_back(countOutputFileName); + sampledCt.printTable(countOutputFileName); + } } if (subset.size() == 0) { m->mothurOut("The size you selected is too large, skipping fasta file."); m->mothurOutEndLine(); return 0; } @@ -858,67 +941,76 @@ int SubSampleCommand::getSubSampleList() { //if the users enters label "0.06" and there is no "0.06" in their file use the next lowest label. set processedLabels; set userLabels = labels; - + ofstream outGroup; - GroupMap* groupMap; + GroupMap groupMap; if (groupfile != "") { - - groupMap = new GroupMap(groupfile); - groupMap->readMap(); + groupMap.readMap(groupfile); //takes care of user setting groupNames that are invalid or setting groups=all - SharedUtil* util = new SharedUtil(); - vector namesGroups = groupMap->getNamesOfGroups(); - util->setGroups(Groups, namesGroups); - delete util; + SharedUtil util; vector namesGroups = groupMap.getNamesOfGroups(); util.setGroups(Groups, namesGroups); //create outputfiles string groupOutputDir = outputDir; if (outputDir == "") { groupOutputDir += m->hasPath(groupfile); } string groupOutputFileName = groupOutputDir + m->getRootName(m->getSimpleName(groupfile)) + "subsample" + m->getExtension(groupfile); - m->openOutputFile(groupOutputFileName, outGroup); outputTypes["group"].push_back(groupOutputFileName); outputNames.push_back(groupOutputFileName); //file mismatch quit - if (list->getNumSeqs() != groupMap->getNumSeqs()) { - m->mothurOut("[ERROR]: your list file contains " + toString(list->getNumSeqs()) + " sequences, and your groupfile contains " + toString(groupMap->getNumSeqs()) + ", please correct."); + if (list->getNumSeqs() != groupMap.getNumSeqs()) { + m->mothurOut("[ERROR]: your list file contains " + toString(list->getNumSeqs()) + " sequences, and your groupfile contains " + toString(groupMap.getNumSeqs()) + ", please correct."); + m->mothurOutEndLine(); delete list; delete input; out.close(); outGroup.close(); return 0; + } + }else if (countfile != "") { + if (ct.hasGroupInfo()) { + SharedUtil util; + vector namesGroups = ct.getNamesOfGroups(); + util.setGroups(Groups, namesGroups); + } + + //file mismatch quit + if (list->getNumSeqs() != ct.getNumUniqueSeqs()) { + m->mothurOut("[ERROR]: your list file contains " + toString(list->getNumSeqs()) + " sequences, and your count file contains " + toString(ct.getNumUniqueSeqs()) + " unique sequences, please correct."); m->mothurOutEndLine(); - delete groupMap; - delete list; - delete input; - out.close(); - outGroup.close(); return 0; - } - } - + } + } + //make sure that if your picked groups size is not too big if (persample) { if (size == 0) { //user has not set size, set size = smallest samples size - size = groupMap->getNumSeqs(Groups[0]); + if (countfile == "") { size = groupMap.getNumSeqs(Groups[0]); } + else { size = ct.getGroupCount(Groups[0]); } + for (int i = 1; i < Groups.size(); i++) { - int thisSize = groupMap->getNumSeqs(Groups[i]); + int thisSize = 0; + if (countfile == "") { thisSize = groupMap.getNumSeqs(Groups[i]); } + else { thisSize = ct.getGroupCount(Groups[i]); } if (thisSize < size) { size = thisSize; } } }else { //make sure size is not too large vector newGroups; for (int i = 0; i < Groups.size(); i++) { - int thisSize = groupMap->getNumSeqs(Groups[i]); + int thisSize = 0; + if (countfile == "") { thisSize = groupMap.getNumSeqs(Groups[i]); } + else { thisSize = ct.getGroupCount(Groups[i]); } if (thisSize >= size) { newGroups.push_back(Groups[i]); } else { m->mothurOut("You have selected a size that is larger than " + Groups[i] + " number of sequences, removing " + Groups[i] + "."); m->mothurOutEndLine(); } } Groups = newGroups; + if (newGroups.size() == 0) { m->mothurOut("[ERROR]: all groups removed."); m->mothurOutEndLine(); m->control_pressed = true; } } - m->mothurOut("Sampling " + toString(size) + " from each group."); m->mothurOutEndLine(); + m->mothurOut("Sampling " + toString(size) + " from each group."); m->mothurOutEndLine(); }else{ - if (pickedGroups) { + if (pickedGroups) { int total = 0; for(int i = 0; i < Groups.size(); i++) { - total += groupMap->getNumSeqs(Groups[i]); + if (countfile == "") { total += groupMap.getNumSeqs(Groups[i]); } + else { total += ct.getGroupCount(Groups[i]); } } if (size == 0) { //user has not set size, set size = 10% samples size @@ -926,122 +1018,110 @@ int SubSampleCommand::getSubSampleList() { } if (total < size) { - m->mothurOut("Your size is too large for the number of groups you selected. Adjusting to " + toString(int (total * 0.10)) + "."); m->mothurOutEndLine(); + if (size != 0) { + m->mothurOut("Your size is too large for the number of groups you selected. Adjusting to " + toString(int (total * 0.10)) + "."); m->mothurOutEndLine(); + } size = int (total * 0.10); } m->mothurOut("Sampling " + toString(size) + " from " + toString(total) + "."); m->mothurOutEndLine(); - }else{ - - if (size == 0) { //user has not set size, set size = 10% samples size - size = int (list->getNumSeqs() * 0.10); + }else { + if (size == 0) { //user has not set size, set size = 10% samples size + if (countfile == "") { size = int (list->getNumSeqs() * 0.10); } + else { size = int (ct.getNumSeqs() * 0.10); } } - int thisSize = list->getNumSeqs(); + int thisSize = 0; + if (countfile == "") { thisSize = list->getNumSeqs(); } + else { thisSize = ct.getNumSeqs(); } + if (size > thisSize) { m->mothurOut("Your list file only contains " + toString(thisSize) + " sequences. Setting size to " + toString(thisSize) + "."); m->mothurOutEndLine(); size = thisSize; } - m->mothurOut("Sampling " + toString(size) + " from " + toString(list->getNumSeqs()) + "."); m->mothurOutEndLine(); - } - } - - - //fill names - for (int i = 0; i < list->getNumBins(); i++) { - string binnames = list->get(i); - - //parse names - string individual = ""; - int length = binnames.length(); - for(int j=0;jgetGroup(individual); - if (group == "not found") { m->mothurOut("[ERROR]: " + individual + " is not in your groupfile. please correct."); m->mothurOutEndLine(); group = "NOTFOUND"; } - - if (pickedGroups) { //if hte user picked groups, we only want to keep the names of sequences from those groups - if (m->inUsersGroups(group, Groups)) { - names.push_back(individual); - } - }else{ - names.push_back(individual); - } - }else{ //save everyone, group - names.push_back(individual); - } - individual = ""; - } - else{ - individual += binnames[j]; - } - } - //save last name - if (groupfile != "") { //if there is a groupfile given fill in group info - string group = groupMap->getGroup(individual); - if (group == "not found") { m->mothurOut("[ERROR]: " + individual + " is not in your groupfile. please correct."); m->mothurOutEndLine(); group = "NOTFOUND"; } - - if (pickedGroups) { //if hte user picked groups, we only want to keep the names of sequences from those groups - if (m->inUsersGroups(group, Groups)) { - names.push_back(individual); - } - }else{ - names.push_back(individual); - } - }else{ //save everyone, group - names.push_back(individual); - } - } - - random_shuffle(names.begin(), names.end()); - - //randomly select a subset of those names to include in the subsample - set subset; //dont want repeat sequence names added - if (persample) { - //initialize counts - map groupCounts; - map::iterator itGroupCounts; - for (int i = 0; i < Groups.size(); i++) { groupCounts[Groups[i]] = 0; } - - for (int j = 0; j < names.size(); j++) { - - if (m->control_pressed) { return 0; } - - string group = groupMap->getGroup(names[j]); - if (group == "not found") { m->mothurOut("[ERROR]: " + names[j] + " is not in your groupfile. please correct."); m->mothurOutEndLine(); group = "NOTFOUND"; } - else{ - itGroupCounts = groupCounts.find(group); - if (itGroupCounts != groupCounts.end()) { - if (groupCounts[group] < size) { subset.insert(names[j]); groupCounts[group]++; } - } - } - } - }else{ - for (int j = 0; j < size; j++) { - - if (m->control_pressed) { break; } - - subset.insert(names[j]); - } - } - - if (groupfile != "") { - //write out new groupfile - for (set::iterator it = subset.begin(); it != subset.end(); it++) { - string group = groupMap->getGroup(*it); - if (group == "not found") { group = "NOTFOUND"; } - - outGroup << *it << '\t' << group << endl; - } - outGroup.close(); delete groupMap; - } + m->mothurOut("Sampling " + toString(size) + " from " + toString(thisSize) + "."); m->mothurOutEndLine(); + } + } + set subset; //dont want repeat sequence names added + if (countfile == "") { + //fill names + for (int i = 0; i < list->getNumBins(); i++) { + string binnames = list->get(i); + vector thisBin; + m->splitAtComma(binnames, thisBin); + + for(int j=0;jmothurOut("[ERROR]: " + thisBin[j] + " is not in your groupfile. please correct."); m->mothurOutEndLine(); group = "NOTFOUND"; } + + //if hte user picked groups, we only want to keep the names of sequences from those groups + if (pickedGroups) { if (m->inUsersGroups(group, Groups)) { names.push_back(thisBin[j]); } } + else{ names.push_back(thisBin[j]); } + }//save everyone, group + else{ names.push_back(thisBin[j]); } + } + } + + random_shuffle(names.begin(), names.end()); + + //randomly select a subset of those names to include in the subsample + if (persample) { + //initialize counts + map groupCounts; + map::iterator itGroupCounts; + for (int i = 0; i < Groups.size(); i++) { groupCounts[Groups[i]] = 0; } + + for (int j = 0; j < names.size(); j++) { + + if (m->control_pressed) { delete list; delete input; return 0; } + + string group = groupMap.getGroup(names[j]); + if (group == "not found") { m->mothurOut("[ERROR]: " + names[j] + " is not in your groupfile. please correct."); m->mothurOutEndLine(); group = "NOTFOUND"; } + else{ + itGroupCounts = groupCounts.find(group); + if (itGroupCounts != groupCounts.end()) { + if (groupCounts[group] < size) { subset.insert(names[j]); groupCounts[group]++; } + } + } + } + }else{ + for (int j = 0; j < size; j++) { + if (m->control_pressed) { break; } + subset.insert(names[j]); + } + } + + if (groupfile != "") { + //write out new groupfile + for (set::iterator it = subset.begin(); it != subset.end(); it++) { + string group = groupMap.getGroup(*it); + if (group == "not found") { group = "NOTFOUND"; } + outGroup << *it << '\t' << group << endl; + } + outGroup.close(); + } + }else { + SubSample sample; CountTable sampledCt; + + if (persample) { sampledCt = sample.getSample(ct, size, Groups); } + else { sampledCt = sample.getSample(ct, size, Groups, pickedGroups); } + + vector sampledSeqs = sampledCt.getNamesOfSeqs(); + for (int i = 0; i < sampledSeqs.size(); i++) { subset.insert(sampledSeqs[i]); } + + string countOutputDir = outputDir; + if (outputDir == "") { countOutputDir += m->hasPath(countfile); } + string countOutputFileName = countOutputDir + m->getRootName(m->getSimpleName(countfile)) + getOutputFileNameTag("count", countfile); + outputTypes["count"].push_back(countOutputFileName); outputNames.push_back(countOutputFileName); + sampledCt.printTable(countOutputFileName); + } //as long as you are not at the end of the file or done wih the lines you want while((list != NULL) && ((allLines == 1) || (userLabels.size() != 0))) { - if (m->control_pressed) { delete list; delete input; out.close(); return 0; } + if (m->control_pressed) { delete list; delete input; out.close(); return 0; } if(allLines == 1 || labels.count(list->getLabel()) == 1){ @@ -1132,22 +1212,12 @@ int SubSampleCommand::processList(ListVector*& list, ofstream& out, set& if (m->control_pressed) { break; } - string binnames = list->get(i); - - //parse names - string individual = ""; - string newNames = ""; - int length = binnames.length(); - for(int j=0;jget(i); + vector binnames; + m->splitAtComma(bin, binnames); + string newNames = ""; + for(int j=0;j labels; //holds labels to be used string groups, label, outputDir; vector Groups, outputNames; int size; vector names; map > nameMap; + CountTable ct; int getSubSampleShared(); int getSubSampleList(); diff --git a/trimflowscommand.cpp b/trimflowscommand.cpp index 9f603c4..296a6fe 100644 --- a/trimflowscommand.cpp +++ b/trimflowscommand.cpp @@ -441,6 +441,8 @@ int TrimFlowsCommand::driverCreateTrim(string flowFileName, string trimFlowFileN } + if (m->debug) { m->mothurOut("[DEBUG]: " + currSeq.getName() + " " + currSeq.getUnaligned() + "\n"); } + if(barcodes.size() != 0){ success = trimOligos.stripBarcode(currSeq, barcodeIndex); if(success > bdiffs) { trashCode += 'b'; } -- 2.39.2