X-Git-Url: https://git.donarmstrong.com/?a=blobdiff_plain;f=forest.cpp;h=3cfb5b9054cf5d43fbbb35eb7b81a6a2ba0f5833;hb=6b32d112bb60e9f7eb6d4407a4eed4c49b67bced;hp=179ecef19d26de2a61f68776445dfdfc72132f07;hpb=5eb72762d405db2dd83c2ec8a5d5c2eb57800ca3;p=mothur.git diff --git a/forest.cpp b/forest.cpp index 179ecef..3cfb5b9 100644 --- a/forest.cpp +++ b/forest.cpp @@ -10,14 +10,28 @@ /***********************************************************************/ Forest::Forest(const std::vector < std::vector > dataSet, - const int numDecisionTrees, - const string treeSplitCriterion = "informationGain") -: dataSet(dataSet), -numDecisionTrees(numDecisionTrees), -numSamples((int)dataSet.size()), -numFeatures((int)(dataSet[0].size() - 1)), -globalVariableImportanceList(numFeatures, 0), -treeSplitCriterion(treeSplitCriterion) { + const int numDecisionTrees, + const string treeSplitCriterion = "gainratio", + const bool doPruning = false, + const float pruneAggressiveness = 0.9, + const bool discardHighErrorTrees = true, + const float highErrorTreeDiscardThreshold = 0.4, + const string optimumFeatureSubsetSelectionCriteria = "log2", + const float featureStandardDeviationThreshold = 0.0) + : dataSet(dataSet), + numDecisionTrees(numDecisionTrees), + numSamples((int)dataSet.size()), + numFeatures((int)(dataSet[0].size() - 1)), + globalVariableImportanceList(numFeatures, 0), + treeSplitCriterion(treeSplitCriterion), + doPruning(doPruning), + pruneAggressiveness(pruneAggressiveness), + discardHighErrorTrees(discardHighErrorTrees), + highErrorTreeDiscardThreshold(highErrorTreeDiscardThreshold), + optimumFeatureSubsetSelectionCriteria(optimumFeatureSubsetSelectionCriteria), + featureStandardDeviationThreshold(featureStandardDeviationThreshold) + { + m = MothurOut::getInstance(); globalDiscardedFeatureIndices = getGlobalDiscardedFeatureIndices(); // TODO: double check if the implemenatation of 'globalOutOfBagEstimates' is correct @@ -40,7 +54,7 @@ vector Forest::getGlobalDiscardedFeatureIndices() { for (int i = 0; i < featureVectors.size(); i++) { if (m->control_pressed) { return globalDiscardedFeatureIndices; } double standardDeviation = m->getStandardDeviation(featureVectors[i]); - if (standardDeviation <= 0){ globalDiscardedFeatureIndices.push_back(i); } + if (standardDeviation <= featureStandardDeviationThreshold){ globalDiscardedFeatureIndices.push_back(i); } } if (m->debug) {