X-Git-Url: https://git.donarmstrong.com/?a=blobdiff_plain;f=forest.cpp;h=3cfb5b9054cf5d43fbbb35eb7b81a6a2ba0f5833;hb=250e3b11b1c9c1e1ad458ab6c7e71ac2e67e11d9;hp=58c7f7e7d4d780b54d4c757e4653a368a60f6a74;hpb=d263f4b3a4f96c672317d317061f6adb72656427;p=mothur.git diff --git a/forest.cpp b/forest.cpp index 58c7f7e..3cfb5b9 100644 --- a/forest.cpp +++ b/forest.cpp @@ -10,16 +10,30 @@ /***********************************************************************/ Forest::Forest(const std::vector < std::vector > dataSet, - const int numDecisionTrees, - const string treeSplitCriterion = "informationGain") -: dataSet(dataSet), -numDecisionTrees(numDecisionTrees), -numSamples((int)dataSet.size()), -numFeatures((int)(dataSet[0].size() - 1)), -globalDiscardedFeatureIndices(getGlobalDiscardedFeatureIndices()), -globalVariableImportanceList(numFeatures, 0), -treeSplitCriterion(treeSplitCriterion) { + const int numDecisionTrees, + const string treeSplitCriterion = "gainratio", + const bool doPruning = false, + const float pruneAggressiveness = 0.9, + const bool discardHighErrorTrees = true, + const float highErrorTreeDiscardThreshold = 0.4, + const string optimumFeatureSubsetSelectionCriteria = "log2", + const float featureStandardDeviationThreshold = 0.0) + : dataSet(dataSet), + numDecisionTrees(numDecisionTrees), + numSamples((int)dataSet.size()), + numFeatures((int)(dataSet[0].size() - 1)), + globalVariableImportanceList(numFeatures, 0), + treeSplitCriterion(treeSplitCriterion), + doPruning(doPruning), + pruneAggressiveness(pruneAggressiveness), + discardHighErrorTrees(discardHighErrorTrees), + highErrorTreeDiscardThreshold(highErrorTreeDiscardThreshold), + optimumFeatureSubsetSelectionCriteria(optimumFeatureSubsetSelectionCriteria), + featureStandardDeviationThreshold(featureStandardDeviationThreshold) + { + m = MothurOut::getInstance(); + globalDiscardedFeatureIndices = getGlobalDiscardedFeatureIndices(); // TODO: double check if the implemenatation of 'globalOutOfBagEstimates' is correct } @@ -40,7 +54,7 @@ vector Forest::getGlobalDiscardedFeatureIndices() { for (int i = 0; i < featureVectors.size(); i++) { if (m->control_pressed) { return globalDiscardedFeatureIndices; } double standardDeviation = m->getStandardDeviation(featureVectors[i]); - if (standardDeviation <= 0){ globalDiscardedFeatureIndices.push_back(i); } + if (standardDeviation <= featureStandardDeviationThreshold){ globalDiscardedFeatureIndices.push_back(i); } } if (m->debug) {