//
//  decisiontree.cpp
//  Mothur
//
//  Created by Sarah Westcott on 10/1/12.
//  Copyright (c) 2012 Schloss Lab. All rights reserved.
//
#include "decisiontree.hpp"
DecisionTree::DecisionTree(vector< vector<int> >& baseDataSet,
                           vector<int> globalDiscardedFeatureIndices,
                           OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
                           string treeSplitCriterion,
                           float featureStandardDeviationThreshold)
            : AbstractDecisionTree(baseDataSet,
                                   globalDiscardedFeatureIndices,
                                   optimumFeatureSubsetSelector,
                                   treeSplitCriterion),
            variableImportanceList(numFeatures, 0),
            featureStandardDeviationThreshold(featureStandardDeviationThreshold) {
    try {
        m = MothurOut::getInstance();
        createBootStrappedSamples();
        buildDecisionTree();
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "DecisionTree");
        exit(1);
    }
}

/***********************************************************************/
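// Permutation-based variable importance: for every usable feature i, that feature's
// column in the out-of-bag (OOB) test data is randomly shuffled and the samples are
// re-classified. The drop in correct predictions is credited to the feature:
//
//     variableImportanceList[i] += numCorrect - numCorrectAfterShuffle
//
// A hypothetical example (not from the source): if the tree classifies 90 of 100 OOB
// samples correctly, but only 72 after shuffling feature i, the importance of i
// increases by 18; an uninformative feature would score near 0.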
int DecisionTree::calcTreeVariableImportanceAndError(int& numCorrect, double& treeErrorRate) {
    try {
        vector< vector<int> > randomlySampledTestData(bootstrappedTestSamples.size(), vector<int>(bootstrappedTestSamples[0].size(), 0));

        // TODO: is it possible to further speed up the following O(N^2) copy by using std::copy?
        for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
            for (int j = 0; j < bootstrappedTestSamples[i].size(); j++) {
                randomlySampledTestData[i][j] = bootstrappedTestSamples[i][j];
            }
        }

        for (int i = 0; i < numFeatures; i++) {
            if (m->control_pressed) { return 0; }

            // if the index is in globalDiscardedFeatureIndices (i.e., a null feature), we don't want to shuffle it
            vector<int>::iterator it = find(globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end(), i);
            if (it == globalDiscardedFeatureIndices.end()) { // NOT FOUND
                // if the standard deviation is very low, we know it's not a good feature at all;
                // we can save some time here by discarding that feature

                vector<int> featureVector = testSampleFeatureVectors[i];
                if (m->getStandardDeviation(featureVector) > featureStandardDeviationThreshold) {
                    // NOTE: only shuffle the features, never shuffle the output vector;
                    // i is always <= (numFeatures - 1), since the column at index
                    // numFeatures holds the output class, not a feature
                    randomlyShuffleAttribute(bootstrappedTestSamples, i, i - 1, randomlySampledTestData);

                    int numCorrectAfterShuffle = 0;
                    for (int j = 0; j < randomlySampledTestData.size(); j++) {
                        if (m->control_pressed) { return 0; }

                        vector<int> shuffledSample = randomlySampledTestData[j];
                        int actualSampleOutputClass = shuffledSample[numFeatures];
                        int predictedSampleOutputClass = evaluateSample(shuffledSample);
                        if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrectAfterShuffle++; }
                    }
                    variableImportanceList[i] += (numCorrect - numCorrectAfterShuffle);
                }
            }
        }

        // TODO: do we need to save variableRanks in the DecisionTree? Do we need it later?
        vector< pair<int, int> > variableRanks;

        for (int i = 0; i < variableImportanceList.size(); i++) {
            if (m->control_pressed) { return 0; }
            if (variableImportanceList[i] > 0) {
                // TODO: is there a way to optimize the following lines?
                pair<int, int> variableRank(0, 0);
                variableRank.first = i;
                variableRank.second = variableImportanceList[i];
                variableRanks.push_back(variableRank);
            }
        }
        VariableRankDescendingSorter variableRankDescendingSorter;
        sort(variableRanks.begin(), variableRanks.end(), variableRankDescendingSorter);

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "calcTreeVariableImportanceAndError");
        exit(1);
    }
}

/***********************************************************************/

// TODO: there must be a way to optimize this function
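// Classifies a single sample by walking from the root to a leaf: at each internal
// node, go left when sample[splitFeatureIndex] < splitFeatureValue, otherwise go
// right; the leaf's output class is the prediction. For instance (illustrative
// values only), a node with splitFeatureIndex = 3 and splitFeatureValue = 7 sends
// a sample with testSample[3] == 5 to its left child.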
int DecisionTree::evaluateSample(vector<int> testSample) {
    try {
        RFTreeNode *node = rootNode;
        while (true) {
            if (m->control_pressed) { return 0; }

            if (node->checkIsLeaf()) { return node->getOutputClass(); }

            int sampleSplitFeatureValue = testSample[node->getSplitFeatureIndex()];
            if (sampleSplitFeatureValue < node->getSplitFeatureValue()) { node = node->getLeftChildNode(); }
            else { node = node->getRightChildNode(); }
        }
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "evaluateSample");
        exit(1);
    }
}

/***********************************************************************/
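// Computes the out-of-bag (OOB) error of this tree: every bootstrapped test sample
// is classified, and
//
//     treeErrorRate = 1 - numCorrect / numTestSamples
//
// e.g. 76 correct out of 80 OOB samples gives an error rate of 0.05 (illustrative
// numbers). The per-sample predictions are also recorded in outOfBagEstimates,
// presumably so the forest can aggregate them later.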
int DecisionTree::calcTreeErrorRate(int& numCorrect, double& treeErrorRate){
    try {
        numCorrect = 0;
        for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
            if (m->control_pressed) { return 0; }

            vector<int> testSample = bootstrappedTestSamples[i];
            int testSampleIndex = bootstrappedTestSampleIndices[i];

            int actualSampleOutputClass = testSample[numFeatures];
            int predictedSampleOutputClass = evaluateSample(testSample);

            if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrect++; }

            outOfBagEstimates[testSampleIndex] = predictedSampleOutputClass;
        }

        treeErrorRate = 1 - ((double)numCorrect / (double)bootstrappedTestSamples.size());

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "calcTreeErrorRate");
        exit(1);
    }
}

/***********************************************************************/

// TODO: optimize this algorithm; instead of transposing twice, we can extract the feature,
// shuffle it, and then re-insert it in its original place, thus improving running time.
// This function randomizes abundances for a given OTU/feature.
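// Note that shuffledSample differs from samples in at most one column at any time:
// the previously shuffled column (prevFeatureIndex) is copied back from samples
// before the new column (featureIndex) is permuted, which avoids re-copying the
// whole matrix on every call.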
void DecisionTree::randomlyShuffleAttribute(const vector< vector<int> >& samples,
                                            const int featureIndex,
                                            const int prevFeatureIndex,
                                            vector< vector<int> >& shuffledSample) {
    try {
        // NOTE: we need (numFeatures + 1) featureVectors; the last extra vector is actually the outputVector

        // restore the previously shuffled feature
        if (prevFeatureIndex > -1) {
            for (int j = 0; j < samples.size(); j++) {
                if (m->control_pressed) { return; }
                shuffledSample[j][prevFeatureIndex] = samples[j][prevFeatureIndex];
            }
        }

        // now do the shuffling
        vector<int> featureVectors(samples.size(), 0);
        for (int j = 0; j < samples.size(); j++) {
            if (m->control_pressed) { return; }
            featureVectors[j] = samples[j][featureIndex];
        }
        random_shuffle(featureVectors.begin(), featureVectors.end());
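        // NOTE: std::random_shuffle was deprecated in C++14 and removed in C++17.
        // If this file is ever built under a newer standard, a sketch of the usual
        // replacement (assuming <random> and <algorithm> are available) would be:
        //     std::mt19937 rng(std::random_device{}());
        //     std::shuffle(featureVectors.begin(), featureVectors.end(), rng);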
        for (int j = 0; j < samples.size(); j++) {
            if (m->control_pressed) { return; }
            shuffledSample[j][featureIndex] = featureVectors[j];
        }
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "randomlyShuffleAttribute");
        exit(1);
    }
}

/***********************************************************************/

int DecisionTree::purgeTreeNodesDataRecursively(RFTreeNode* treeNode) {
    try {
        treeNode->bootstrappedTrainingSamples.clear();
        treeNode->bootstrappedFeatureVectors.clear();
        treeNode->bootstrappedOutputVector.clear();
        treeNode->localDiscardedFeatureIndices.clear();
        treeNode->globalDiscardedFeatureIndices.clear();

        if (treeNode->leftChildNode != NULL) { purgeTreeNodesDataRecursively(treeNode->leftChildNode); }
        if (treeNode->rightChildNode != NULL) { purgeTreeNodesDataRecursively(treeNode->rightChildNode); }

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "purgeTreeNodesDataRecursively");
        exit(1);
    }
}

/***********************************************************************/

void DecisionTree::buildDecisionTree(){
    try {
        int generation = 0;
        rootNode = new RFTreeNode(bootstrappedTrainingSamples, globalDiscardedFeatureIndices, numFeatures, numSamples, numOutputClasses, generation, nodeIdCount, featureStandardDeviationThreshold);
        nodeIdCount++;

        splitRecursively(rootNode);
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "buildDecisionTree");
        exit(1);
    }
}

/***********************************************************************/
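// Grows the tree depth-first in the usual CART/C4.5 style: a node becomes a leaf
// when it has fewer than two samples or is already pure (all samples share one
// output class); otherwise a random feature subset is drawn, the best split is
// chosen by entropy or gain ratio, the samples are partitioned into left
// (feature value < split value) and right children, and both children are split
// recursively.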
int DecisionTree::splitRecursively(RFTreeNode* rootNode) {
    try {
        if (rootNode->getNumSamples() < 2){
            rootNode->setIsLeaf(true);
            rootNode->setOutputClass(rootNode->getBootstrappedTrainingSamples()[0][rootNode->getNumFeatures()]);
            return 0;
        }

        int classifiedOutputClass;
        bool isAlreadyClassified = checkIfAlreadyClassified(rootNode, classifiedOutputClass);
        if (isAlreadyClassified == true){
            rootNode->setIsLeaf(true);
            rootNode->setOutputClass(classifiedOutputClass);
            return 0;
        }

        if (m->control_pressed) { return 0; }
        vector<int> featureSubsetIndices = selectFeatureSubsetRandomly(globalDiscardedFeatureIndices, rootNode->getLocalDiscardedFeatureIndices());

        // TODO: need to check if the value is actually copied correctly
        rootNode->setFeatureSubsetIndices(featureSubsetIndices);
        if (m->control_pressed) { return 0; }

        findAndUpdateBestFeatureToSplitOn(rootNode);

        // update the rootNode's outputClass; this is needed for pruning
        // and applies only to internal nodes
        updateOutputClassOfNode(rootNode);

        if (m->control_pressed) { return 0; }

        vector< vector<int> > leftChildSamples;
        vector< vector<int> > rightChildSamples;
        getSplitPopulation(rootNode, leftChildSamples, rightChildSamples);

        if (m->control_pressed) { return 0; }

        // TODO: need to write code to clear this memory
        RFTreeNode* leftChildNode = new RFTreeNode(leftChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)leftChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1, nodeIdCount, featureStandardDeviationThreshold);
        nodeIdCount++;
        RFTreeNode* rightChildNode = new RFTreeNode(rightChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)rightChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1, nodeIdCount, featureStandardDeviationThreshold);
        nodeIdCount++;

        rootNode->setLeftChildNode(leftChildNode);
        leftChildNode->setParentNode(rootNode);

        rootNode->setRightChildNode(rightChildNode);
        rightChildNode->setParentNode(rootNode);

        // TODO: this recursive split can be parallelized later
        splitRecursively(leftChildNode);
        if (m->control_pressed) { return 0; }

        splitRecursively(rightChildNode);

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "splitRecursively");
        exit(1);
    }
}

/***********************************************************************/
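// For each candidate feature in the node's subset, getMinEntropyOfFeature() returns
// the lowest achievable post-split entropy plus the corresponding split value and
// intrinsic value; from those this function derives, per C4.5:
//
//     informationGain = nodeEntropy - featureMinEntropy
//     gainRatio       = informationGain / intrinsicValue
//
// With treeSplitCriterion == "gainratio" the feature with the largest gain ratio
// wins; with "infogain" the feature with the smallest post-split entropy wins.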
int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
    try {
        vector< vector<int> > bootstrappedFeatureVectors = node->getBootstrappedFeatureVectors();
        if (m->control_pressed) { return 0; }
        vector<int> bootstrappedOutputVector = node->getBootstrappedOutputVector();
        if (m->control_pressed) { return 0; }
        vector<int> featureSubsetIndices = node->getFeatureSubsetIndices();
        if (m->control_pressed) { return 0; }

        vector<double> featureSubsetEntropies;
        vector<int> featureSubsetSplitValues;
        vector<double> featureSubsetIntrinsicValues;
        vector<double> featureSubsetGainRatios;

        for (int i = 0; i < featureSubsetIndices.size(); i++) {
            if (m->control_pressed) { return 0; }

            int tryIndex = featureSubsetIndices[i];

            double featureMinEntropy;
            int featureSplitValue;
            double featureIntrinsicValue;

            getMinEntropyOfFeature(bootstrappedFeatureVectors[tryIndex], bootstrappedOutputVector, featureMinEntropy, featureSplitValue, featureIntrinsicValue);
            if (m->control_pressed) { return 0; }

            featureSubsetEntropies.push_back(featureMinEntropy);
            featureSubsetSplitValues.push_back(featureSplitValue);
            featureSubsetIntrinsicValues.push_back(featureIntrinsicValue);

            double featureInformationGain = node->getOwnEntropy() - featureMinEntropy;
            double featureGainRatio = (double)featureInformationGain / (double)featureIntrinsicValue;
            featureSubsetGainRatios.push_back(featureGainRatio);
        }

        vector<double>::iterator minEntropyIterator = min_element(featureSubsetEntropies.begin(), featureSubsetEntropies.end());
        vector<double>::iterator maxGainRatioIterator = max_element(featureSubsetGainRatios.begin(), featureSubsetGainRatios.end());
        double featureMinEntropy = *minEntropyIterator;

        // TODO: kept the following line as a future reference, it can be useful
        // double featureMaxGainRatio = *maxGainRatioIterator;

        double bestFeatureSplitEntropy = featureMinEntropy;
        int bestFeatureToSplitOnIndex = -1;
        if (treeSplitCriterion == "gainratio"){
            bestFeatureToSplitOnIndex = (int)(maxGainRatioIterator - featureSubsetGainRatios.begin());
            // if using the 'gainratio' measure, featureMinEntropy must be updated, since the
            // index of 'featureMaxGainRatio' may differ from that of the minimum entropy
            bestFeatureSplitEntropy = featureSubsetEntropies[bestFeatureToSplitOnIndex];
        } else if ( treeSplitCriterion == "infogain"){
            bestFeatureToSplitOnIndex = (int)(minEntropyIterator - featureSubsetEntropies.begin());
        } else {
            // TODO: we need an abort mechanism here
        }

        // TODO: is the following line needed? kept it as a future reference
        // splitInformationGain = node.ownEntropy - node.splitFeatureEntropy

        int bestFeatureSplitValue = featureSubsetSplitValues[bestFeatureToSplitOnIndex];

        node->setSplitFeatureIndex(featureSubsetIndices[bestFeatureToSplitOnIndex]);
        node->setSplitFeatureValue(bestFeatureSplitValue);
        node->setSplitFeatureEntropy(bestFeatureSplitEntropy);
        // TODO: kept the following line as a future reference
        // node.splitInformationGain = splitInformationGain

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "findAndUpdateBestFeatureToSplitOn");
        exit(1);
    }
}

/***********************************************************************/
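// Random-subspace step of the random forest: draws a random subset of feature
// indices (up to optimumFeatureSubsetSize of them) by rejection sampling, skipping
// indices already picked or discarded. Note that rand() % numFeatures carries a
// slight modulo bias; if the TODO below is ever addressed, a sketch using the
// C++11 <random> facilities (assuming they are available) could replace it:
//     static std::mt19937 rng(std::random_device{}());
//     std::uniform_int_distribution<int> dist(0, numFeatures - 1);
//     int randomIndex = dist(rng);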
vector<int> DecisionTree::selectFeatureSubsetRandomly(vector<int> globalDiscardedFeatureIndices, vector<int> localDiscardedFeatureIndices){
    try {
        vector<int> featureSubsetIndices;

        vector<int> combinedDiscardedFeatureIndices;
        combinedDiscardedFeatureIndices.insert(combinedDiscardedFeatureIndices.end(), globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end());
        combinedDiscardedFeatureIndices.insert(combinedDiscardedFeatureIndices.end(), localDiscardedFeatureIndices.begin(), localDiscardedFeatureIndices.end());

        sort(combinedDiscardedFeatureIndices.begin(), combinedDiscardedFeatureIndices.end());

        int numberOfRemainingSuitableFeatures = (int)(numFeatures - combinedDiscardedFeatureIndices.size());
        int currentFeatureSubsetSize = numberOfRemainingSuitableFeatures < optimumFeatureSubsetSize ? numberOfRemainingSuitableFeatures : optimumFeatureSubsetSize;

        while (featureSubsetIndices.size() < currentFeatureSubsetSize) {
            if (m->control_pressed) { return featureSubsetIndices; }

            // TODO: optimize the rand() call here
            int randomIndex = rand() % numFeatures;
            vector<int>::iterator it = find(featureSubsetIndices.begin(), featureSubsetIndices.end(), randomIndex);
            if (it == featureSubsetIndices.end()){ // NOT FOUND
                vector<int>::iterator it2 = find(combinedDiscardedFeatureIndices.begin(), combinedDiscardedFeatureIndices.end(), randomIndex);
                if (it2 == combinedDiscardedFeatureIndices.end()){ // NOT FOUND AGAIN
                    featureSubsetIndices.push_back(randomIndex);
                }
            }
        }
        sort(featureSubsetIndices.begin(), featureSubsetIndices.end());

        //#ifdef DEBUG_LEVEL_3
        //    PRINT_VAR(featureSubsetIndices);
        //#endif

        return featureSubsetIndices;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "selectFeatureSubsetRandomly");
        exit(1);
    }
}

/***********************************************************************/

// TODO: printTree() needs to be checked for correctness
int DecisionTree::printTree(RFTreeNode* treeNode, string caption){
    try {
        if (treeNode == NULL) { return 0; } // guard: everything below dereferences treeNode

        string tabs = "";
        for (int i = 0; i < treeNode->getGeneration(); i++) { tabs += "|--"; }
        //    for (int i = 0; i < treeNode->getGeneration() - 1; i++) { tabs += "|  "; }
        //    if (treeNode->getGeneration() != 0) { tabs += "|--"; }

        if (treeNode->checkIsLeaf() == false){
            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " , id: " + toString(treeNode->nodeId) + " ] ( " + toString(treeNode->getSplitFeatureValue()) + " < X" + toString(treeNode->getSplitFeatureIndex()) + " ) ( predicted: " + toString(treeNode->outputClass) + " , misclassified: " + toString(treeNode->testSampleMisclassificationCount) + " )\n");

            printTree(treeNode->getLeftChildNode(), "left ");
            printTree(treeNode->getRightChildNode(), "right");
        } else {
            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " , id: " + toString(treeNode->nodeId) + " ] ( classified to: " + toString(treeNode->getOutputClass()) + ", samples: " + toString(treeNode->getNumSamples()) + " , misclassified: " + toString(treeNode->testSampleMisclassificationCount) + " )\n");
        }

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "printTree");
        exit(1);
    }
}

/***********************************************************************/

void DecisionTree::deleteTreeNodesRecursively(RFTreeNode* treeNode) {
    try {
        if (treeNode == NULL) { return; }
        deleteTreeNodesRecursively(treeNode->leftChildNode);
        deleteTreeNodesRecursively(treeNode->rightChildNode);
        delete treeNode; treeNode = NULL;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "deleteTreeNodesRecursively");
        exit(1);
    }
}

/***********************************************************************/
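// Pruning is bottom-up and error-based: each OOB test sample is first routed
// through the tree to tally per-node misclassification counts, then
// pruneRecursively() collapses any internal node whose children's combined
// misclassifications, scaled by pruneAggressiveness, exceed the node's own count.
// Illustrative numbers: with the default aggressiveness of 0.9, a node that itself
// misclassifies 10 samples is collapsed into a leaf when its children jointly
// misclassify 12 (12 * 0.9 = 10.8 > 10), i.e. the subtree predicts worse than the
// node alone.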
void DecisionTree::pruneTree(double pruneAggressiveness = 0.9) {

    // find out the number of misclassifications made by each node
    for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
        if (m->control_pressed) { return; }

        vector<int> testSample = bootstrappedTestSamples[i];
        updateMisclassificationCountRecursively(rootNode, testSample);
    }

    // do the actual pruning
    pruneRecursively(rootNode, pruneAggressiveness);
}

/***********************************************************************/

void DecisionTree::pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness){

    if (treeNode != NULL && treeNode->checkIsLeaf() == false) {
        if (m->control_pressed) { return; }

        pruneRecursively(treeNode->leftChildNode, pruneAggressiveness);
        pruneRecursively(treeNode->rightChildNode, pruneAggressiveness);

        int subTreeMisclassificationCount = treeNode->leftChildNode->getTestSampleMisclassificationCount() + treeNode->rightChildNode->getTestSampleMisclassificationCount();
        int ownMisclassificationCount = treeNode->getTestSampleMisclassificationCount();

        if (subTreeMisclassificationCount * pruneAggressiveness > ownMisclassificationCount) {
            // TODO: need to check the effect of these two delete calls
            delete treeNode->leftChildNode;
            treeNode->leftChildNode = NULL;

            delete treeNode->rightChildNode;
            treeNode->rightChildNode = NULL;

            treeNode->isLeaf = true;
        }
    }
}

/***********************************************************************/
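// Routes one OOB test sample from the root to a leaf, incrementing
// testSampleMisclassificationCount on every node along the path whose
// (majority-vote) outputClass disagrees with the sample's actual class;
// these are the per-node counts that the pruning pass compares.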
void DecisionTree::updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector<int> testSample) {

    int actualSampleOutputClass = testSample[numFeatures];
    int nodePredictedOutputClass = treeNode->outputClass;

    if (actualSampleOutputClass != nodePredictedOutputClass) {
        treeNode->testSampleMisclassificationCount++;
        map<int, int>::iterator it = nodeMisclassificationCounts.find(treeNode->nodeId);
        if (it == nodeMisclassificationCounts.end()) { // NOT FOUND
            nodeMisclassificationCounts[treeNode->nodeId] = 0;
        }
        nodeMisclassificationCounts[treeNode->nodeId]++;
    }

    if (treeNode->checkIsLeaf() == false) { // NOT A LEAF
        int sampleSplitFeatureValue = testSample[treeNode->splitFeatureIndex];
        if (sampleSplitFeatureValue < treeNode->splitFeatureValue) {
            updateMisclassificationCountRecursively(treeNode->leftChildNode, testSample);
        } else {
            updateMisclassificationCountRecursively(treeNode->rightChildNode, testSample);
        }
    }
}

/***********************************************************************/
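// Sets the node's output class by majority vote over its bootstrapped output
// vector. Note that find() locates the first occurrence of the maximum count,
// which is the same index max_element() already points at, so
// (majorityVotedOutputClassCountIterator - counts.begin()) would yield it
// directly; ties resolve toward the lowest-numbered class either way.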
void DecisionTree::updateOutputClassOfNode(RFTreeNode* treeNode) {
    vector<int> counts(numOutputClasses, 0);
    for (int i = 0; i < treeNode->bootstrappedOutputVector.size(); i++) {
        int bootstrappedOutput = treeNode->bootstrappedOutputVector[i];
        counts[bootstrappedOutput]++;
    }

    vector<int>::iterator majorityVotedOutputClassCountIterator = max_element(counts.begin(), counts.end());
    int majorityVotedOutputClassCount = *majorityVotedOutputClassCountIterator;
    vector<int>::iterator it = find(counts.begin(), counts.end(), majorityVotedOutputClassCount);
    int majorityVotedOutputClass = (int)(it - counts.begin());
    treeNode->setOutputClass(majorityVotedOutputClass);
}

/***********************************************************************/