X-Git-Url: https://git.donarmstrong.com/?p=mothur.git;a=blobdiff_plain;f=forest.cpp;h=3cfb5b9054cf5d43fbbb35eb7b81a6a2ba0f5833;hp=8ac1b79ca12c540ccb8c38498cc62b0201bc2f41;hb=b206f634aae1b4ce13978d203247fb64757d5482;hpb=5e1ab7456ec5e9e516cfa0fec6afef2c2a03a257 diff --git a/forest.cpp b/forest.cpp index 8ac1b79..3cfb5b9 100644 --- a/forest.cpp +++ b/forest.cpp @@ -10,16 +10,30 @@ /***********************************************************************/ Forest::Forest(const std::vector < std::vector > dataSet, - const int numDecisionTrees, - const string treeSplitCriterion = "informationGain") -: dataSet(dataSet), -numDecisionTrees(numDecisionTrees), -numSamples((int)dataSet.size()), -numFeatures((int)(dataSet[0].size() - 1)), -globalDiscardedFeatureIndices(getGlobalDiscardedFeatureIndices()), -globalVariableImportanceList(numFeatures, 0), -treeSplitCriterion(treeSplitCriterion) { + const int numDecisionTrees, + const string treeSplitCriterion = "gainratio", + const bool doPruning = false, + const float pruneAggressiveness = 0.9, + const bool discardHighErrorTrees = true, + const float highErrorTreeDiscardThreshold = 0.4, + const string optimumFeatureSubsetSelectionCriteria = "log2", + const float featureStandardDeviationThreshold = 0.0) + : dataSet(dataSet), + numDecisionTrees(numDecisionTrees), + numSamples((int)dataSet.size()), + numFeatures((int)(dataSet[0].size() - 1)), + globalVariableImportanceList(numFeatures, 0), + treeSplitCriterion(treeSplitCriterion), + doPruning(doPruning), + pruneAggressiveness(pruneAggressiveness), + discardHighErrorTrees(discardHighErrorTrees), + highErrorTreeDiscardThreshold(highErrorTreeDiscardThreshold), + optimumFeatureSubsetSelectionCriteria(optimumFeatureSubsetSelectionCriteria), + featureStandardDeviationThreshold(featureStandardDeviationThreshold) + { + m = MothurOut::getInstance(); + globalDiscardedFeatureIndices = getGlobalDiscardedFeatureIndices(); // TODO: double check if the implemenatation of 'globalOutOfBagEstimates' is correct } @@ -27,10 +41,11 @@ treeSplitCriterion(treeSplitCriterion) { vector Forest::getGlobalDiscardedFeatureIndices() { try { - vector globalDiscardedFeatureIndices; + //vector globalDiscardedFeatureIndices; + //globalDiscardedFeatureIndices.push_back(1); // calculate feature vectors - vector< vector > featureVectors(numFeatures, vector(numSamples, 0)); + vector< vector > featureVectors(numFeatures, vector(numSamples, 0) ); for (int i = 0; i < numSamples; i++) { if (m->control_pressed) { return globalDiscardedFeatureIndices; } for (int j = 0; j < numFeatures; j++) { featureVectors[j][i] = dataSet[i][j]; } @@ -39,7 +54,7 @@ vector Forest::getGlobalDiscardedFeatureIndices() { for (int i = 0; i < featureVectors.size(); i++) { if (m->control_pressed) { return globalDiscardedFeatureIndices; } double standardDeviation = m->getStandardDeviation(featureVectors[i]); - if (standardDeviation <= 0){ globalDiscardedFeatureIndices.push_back(i); } + if (standardDeviation <= featureStandardDeviationThreshold){ globalDiscardedFeatureIndices.push_back(i); } } if (m->debug) {