/***********************************************************************/
-RandomForest::RandomForest(const vector <vector<int> > dataSet,const int numDecisionTrees,
- const string treeSplitCriterion = "informationGain") : AbstractRandomForest(dataSet, numDecisionTrees, treeSplitCriterion) {
+RandomForest::RandomForest(const vector <vector<int> > dataSet,
+ const int numDecisionTrees,
+ const string treeSplitCriterion = "gainratio",
+ const bool doPruning = false,
+ const float pruneAggressiveness = 0.9,
+ const bool discardHighErrorTrees = true,
+ const float highErrorTreeDiscardThreshold = 0.4,
+ const string optimumFeatureSubsetSelectionCriteria = "log2",
+ const float featureStandardDeviationThreshold = 0.0)
+ : Forest(dataSet, numDecisionTrees, treeSplitCriterion, doPruning, pruneAggressiveness, discardHighErrorTrees, highErrorTreeDiscardThreshold, optimumFeatureSubsetSelectionCriteria, featureStandardDeviationThreshold) {
m = MothurOut::getInstance();
}
}
/***********************************************************************/
-// DONE
int RandomForest::calcForrestVariableImportance(string filename) {
try {
- // TODO: need to add try/catch operators to fix this
- // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
+ // follow the link: http://en.wikipedia.org/wiki/Dynamic_cast
//if you are going to dynamically cast, aren't you undoing the advantage of abstraction. Why abstract at all?
//could cause maintenance issues later if other types of Abstract decision trees are created that cannot be cast as a decision tree.
- for (int i = 0; i < decisionTrees.size(); i++) {
- if (m->control_pressed) { return 0; }
- DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
+ for (int i = 0; i < decisionTrees.size(); i++) {
+ if (m->control_pressed) { return 0; }
+
+ DecisionTree* decisionTree = dynamic_cast<DecisionTree*>(decisionTrees[i]);
+
+ for (int j = 0; j < numFeatures; j++) {
+ globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
+ }
+ }
- for (int j = 0; j < numFeatures; j++) {
- globalVariableImportanceList[j] += (double)decisionTree->variableImportanceList[j];
+ for (int i = 0; i < numFeatures; i++) {
+ globalVariableImportanceList[i] /= (double)numDecisionTrees;
}
- }
-
- for (int i = 0; i < numFeatures; i++) {
- cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
- globalVariableImportanceList[i] /= (double)numDecisionTrees;
- }
-
- vector< vector<double> > globalVariableRanks;
- for (int i = 0; i < globalVariableImportanceList.size(); i++) {
- if (globalVariableImportanceList[i] > 0) {
- vector<double> globalVariableRank(2, 0);
- globalVariableRank[0] = i; globalVariableRank[1] = globalVariableImportanceList[i];
- globalVariableRanks.push_back(globalVariableRank);
+
+ vector< pair<int, double> > globalVariableRanks;
+ for (int i = 0; i < globalVariableImportanceList.size(); i++) {
+ //cout << "[" << i << ',' << globalVariableImportanceList[i] << "], ";
+ if (globalVariableImportanceList[i] > 0) {
+ pair<int, double> globalVariableRank(0, 0.0);
+ globalVariableRank.first = i;
+ globalVariableRank.second = globalVariableImportanceList[i];
+ globalVariableRanks.push_back(globalVariableRank);
+ }
}
- }
-
- VariableRankDescendingSorterDouble variableRankDescendingSorter;
- sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
+
+// for (int i = 0; i < globalVariableRanks.size(); i++) {
+// cout << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
+// }
+
+
+ VariableRankDescendingSorterDouble variableRankDescendingSorter;
+ sort(globalVariableRanks.begin(), globalVariableRanks.end(), variableRankDescendingSorter);
+
ofstream out;
m->openOutputFile(filename, out);
out <<"OTU\tRank\n";
for (int i = 0; i < globalVariableRanks.size(); i++) {
- out << m->currentBinLabels[(int)globalVariableRanks[i][0]] << '\t' << globalVariableImportanceList[globalVariableRanks[i][0]] << endl;
+ out << m->currentBinLabels[(int)globalVariableRanks[i].first] << '\t' << globalVariableImportanceList[globalVariableRanks[i].first] << endl;
}
out.close();
return 0;
}
}
/***********************************************************************/
-// DONE
int RandomForest::populateDecisionTrees() {
try {
+ vector<double> errorRateImprovements;
+
for (int i = 0; i < numDecisionTrees; i++) {
+
if (m->control_pressed) { return 0; }
- if (((i+1) % 10) == 0) { m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n"); }
+ if (((i+1) % 100) == 0) { m->mothurOut("Creating " + toString(i+1) + " (th) Decision tree\n"); }
+
// TODO: need to first fix if we are going to use pointer based system or anything else
- DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector("log2"), treeSplitCriterion);
- decisionTree->calcTreeVariableImportanceAndError();
- if (m->control_pressed) { return 0; }
- updateGlobalOutOfBagEstimates(decisionTree);
- if (m->control_pressed) { return 0; }
- decisionTree->purgeDataSetsFromTree();
- if (m->control_pressed) { return 0; }
- decisionTrees.push_back(decisionTree);
+ DecisionTree* decisionTree = new DecisionTree(dataSet, globalDiscardedFeatureIndices, OptimumFeatureSubsetSelector(optimumFeatureSubsetSelectionCriteria), treeSplitCriterion, featureStandardDeviationThreshold);
+
+ if (m->debug && doPruning) {
+ m->mothurOut("Before pruning\n");
+ decisionTree->printTree(decisionTree->rootNode, "ROOT");
+ }
+
+ int numCorrect;
+ double treeErrorRate;
+
+ decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
+ double prePrunedErrorRate = treeErrorRate;
+
+ if (m->debug) {
+ m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
+ }
+
+ if (doPruning) {
+ decisionTree->pruneTree(pruneAggressiveness);
+ if (m->debug) {
+ m->mothurOut("After pruning\n");
+ decisionTree->printTree(decisionTree->rootNode, "ROOT");
+ }
+ decisionTree->calcTreeErrorRate(numCorrect, treeErrorRate);
+ }
+ double postPrunedErrorRate = treeErrorRate;
+
+
+ decisionTree->calcTreeVariableImportanceAndError(numCorrect, treeErrorRate);
+ double errorRateImprovement = (prePrunedErrorRate - postPrunedErrorRate) / prePrunedErrorRate;
+
+ if (m->debug) {
+ m->mothurOut("treeErrorRate: " + toString(treeErrorRate) + " numCorrect: " + toString(numCorrect) + "\n");
+ if (doPruning) {
+ m->mothurOut("errorRateImprovement: " + toString(errorRateImprovement) + "\n");
+ }
+ }
+
+
+ if (discardHighErrorTrees) {
+ if (treeErrorRate < highErrorTreeDiscardThreshold) {
+ updateGlobalOutOfBagEstimates(decisionTree);
+ decisionTree->purgeDataSetsFromTree();
+ decisionTrees.push_back(decisionTree);
+ if (doPruning) {
+ errorRateImprovements.push_back(errorRateImprovement);
+ }
+ } else {
+ delete decisionTree;
+ }
+ } else {
+ updateGlobalOutOfBagEstimates(decisionTree);
+ decisionTree->purgeDataSetsFromTree();
+ decisionTrees.push_back(decisionTree);
+ if (doPruning) {
+ errorRateImprovements.push_back(errorRateImprovement);
+ }
+ }
+ }
+
+ double avgErrorRateImprovement = -1.0;
+ if (errorRateImprovements.size() > 0) {
+ avgErrorRateImprovement = accumulate(errorRateImprovements.begin(), errorRateImprovements.end(), 0.0);
+// cout << "Total " << avgErrorRateImprovement << " size " << errorRateImprovements.size() << endl;
+ avgErrorRateImprovement /= errorRateImprovements.size();
}
- if (m->debug) {
- // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
+ if (m->debug && doPruning) {
+ m->mothurOut("avgErrorRateImprovement:" + toString(avgErrorRateImprovement) + "\n");
}
+ // m->mothurOut("globalOutOfBagEstimates = " + toStringVectorMap(globalOutOfBagEstimates)+ "\n");
+
return 0;
}
}
/***********************************************************************/
// TODO: need to finalize between reference and pointer for DecisionTree [partially solved]
-// TODO: make this pure virtual in superclass
+// DONE: make this pure virtual in superclass
// DONE
int RandomForest::updateGlobalOutOfBagEstimates(DecisionTree* decisionTree) {
try {