//  Created by Sarah Westcott on 10/1/12.
//  Copyright (c) 2012 Schloss Lab. All rights reserved.

#include "decisiontree.hpp"
DecisionTree::DecisionTree(vector< vector<int> > baseDataSet,
                           vector<int> globalDiscardedFeatureIndices,
                           OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
                           string treeSplitCriterion) : AbstractDecisionTree(baseDataSet,
                                                                             globalDiscardedFeatureIndices,
                                                                             optimumFeatureSubsetSelector,
                                                                             treeSplitCriterion), variableImportanceList(numFeatures, 0) {
    try {
        m = MothurOut::getInstance();
        createBootStrappedSamples();
        buildDecisionTree();
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "DecisionTree");
        exit(1);
    }
}
/***********************************************************************/
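
// Permutation-style variable importance: first measure the tree's accuracy on the
// out-of-bag (bootstrapped test) samples, then, one feature at a time, shuffle that
// feature's column and re-classify; the drop in the number of correct predictions is
// accumulated in variableImportanceList for that feature. Example with made-up numbers:
// if the tree gets numCorrect = 80 on the intact test set and only 65 after shuffling
// feature i, variableImportanceList[i] grows by 15.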
int DecisionTree::calcTreeVariableImportanceAndError() {
    try {
        int numCorrect = 0;
        double treeErrorRate = 0.0;
        calcTreeErrorRate(numCorrect, treeErrorRate);

        if (m->control_pressed) { return 0; }

        for (int i = 0; i < numFeatures; i++) {
            if (m->control_pressed) { return 0; }
            // NOTE: only shuffle the features, never the output vector; i stays in
            // [0, numFeatures - 1] because the column at index numFeatures holds the output class
            vector< vector<int> > randomlySampledTestData = randomlyShuffleAttribute(bootstrappedTestSamples, i);

            int numCorrectAfterShuffle = 0;
            for (int j = 0; j < (int)randomlySampledTestData.size(); j++) {
                if (m->control_pressed) { return 0; }
                vector<int> shuffledSample = randomlySampledTestData[j];
                int actualSampleOutputClass = shuffledSample[numFeatures];
                int predictedSampleOutputClass = evaluateSample(shuffledSample);
                if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrectAfterShuffle++; }
            }
            variableImportanceList[i] += (numCorrect - numCorrectAfterShuffle);
        }

        // TODO: do we need to save variableRanks in the DecisionTree? Do we need it later?
        vector< vector<int> > variableRanks;
        for (int i = 0; i < (int)variableImportanceList.size(); i++) {
            if (m->control_pressed) { return 0; }
            if (variableImportanceList[i] > 0) {
                // TODO: is there a way to optimize the following lines?
                vector<int> variableRank(2, 0);
                variableRank[0] = i; variableRank[1] = variableImportanceList[i];
                variableRanks.push_back(variableRank);
            }
        }
        VariableRankDescendingSorter variableRankDescendingSorter;
        sort(variableRanks.begin(), variableRanks.end(), variableRankDescendingSorter);

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "calcTreeVariableImportanceAndError");
        exit(1);
    }
}
/***********************************************************************/
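
// Walks a single sample down the tree: at each internal node, compare the sample's value
// of the node's split feature against the split value and descend left (<) or right (>=)
// until a leaf is reached; returns the leaf's output class.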
// TODO: there must be a way to optimize this function
int DecisionTree::evaluateSample(vector<int> testSample) {
    try {
        RFTreeNode *node = rootNode;
        while (true) {
            if (m->control_pressed) { return 0; }
            if (node->checkIsLeaf()) { return node->getOutputClass(); }
            int sampleSplitFeatureValue = testSample[node->getSplitFeatureIndex()];
            if (sampleSplitFeatureValue < node->getSplitFeatureValue()) { node = node->getLeftChildNode(); }
            else { node = node->getRightChildNode(); }
        }
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "evaluateSample");
        exit(1);
    }
}
/***********************************************************************/
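
// Classifies every bootstrapped (out-of-bag) test sample, counts correct predictions
// against the sample's last column (the actual output class), records each prediction in
// outOfBagEstimates, and reports treeErrorRate = 1 - numCorrect / numTestSamples; e.g.,
// 90 correct out of 100 test samples gives an error rate of 0.1.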
int DecisionTree::calcTreeErrorRate(int& numCorrect, double& treeErrorRate){
    try {
        numCorrect = 0;
        for (int i = 0; i < (int)bootstrappedTestSamples.size(); i++) {
            if (m->control_pressed) { return 0; }

            vector<int> testSample = bootstrappedTestSamples[i];
            int testSampleIndex = bootstrappedTestSampleIndices[i];

            int actualSampleOutputClass = testSample[numFeatures];
            int predictedSampleOutputClass = evaluateSample(testSample);

            if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrect++; }

            outOfBagEstimates[testSampleIndex] = predictedSampleOutputClass;
        }

        treeErrorRate = 1 - ((double)numCorrect / (double)bootstrappedTestSamples.size());

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "calcTreeErrorRate");
        exit(1);
    }
}
/***********************************************************************/
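
// NOTE: random_shuffle() used below typically draws its randomness from rand(); it was
// deprecated in C++14 and removed in C++17, where std::shuffle with an explicit engine
// is the replacement.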
// TODO: optimize this algorithm; instead of transposing twice, we can extract the feature,
// shuffle it, and then re-insert it in the original place, thus improving the running time.
// This function randomizes the abundances of a given OTU/feature across the samples.
vector< vector<int> > DecisionTree::randomlyShuffleAttribute(vector< vector<int> > samples, int featureIndex) {
    try {
        // NOTE: we need (numFeatures + 1) feature vectors; the last extra vector is actually the output vector
        vector< vector<int> > shuffledSample = samples;
        vector<int> featureVectors(samples.size(), 0);

        for (int j = 0; j < (int)samples.size(); j++) {
            if (m->control_pressed) { return shuffledSample; }
            featureVectors[j] = samples[j][featureIndex];
        }

        random_shuffle(featureVectors.begin(), featureVectors.end());

        for (int j = 0; j < (int)samples.size(); j++) {
            if (m->control_pressed) { return shuffledSample; }
            shuffledSample[j][featureIndex] = featureVectors[j];
        }

        return shuffledSample;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "randomlyShuffleAttribute");
        exit(1);
    }
}
/***********************************************************************/
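
// Clears the per-node bookkeeping (training samples, feature vectors, output vector, and
// discarded-feature indices) throughout a subtree to reclaim memory; the nodes themselves
// are released separately by deleteTreeNodesRecursively().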
int DecisionTree::purgeTreeNodesDataRecursively(RFTreeNode* treeNode) {
    try {
        treeNode->bootstrappedTrainingSamples.clear();
        treeNode->bootstrappedFeatureVectors.clear();
        treeNode->bootstrappedOutputVector.clear();
        treeNode->localDiscardedFeatureIndices.clear();
        treeNode->globalDiscardedFeatureIndices.clear();

        if (treeNode->leftChildNode != NULL) { purgeTreeNodesDataRecursively(treeNode->leftChildNode); }
        if (treeNode->rightChildNode != NULL) { purgeTreeNodesDataRecursively(treeNode->rightChildNode); }

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "purgeTreeNodesDataRecursively");
        exit(1);
    }
}
/***********************************************************************/
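
// Creates the root node from the full set of bootstrapped training samples (generation 0)
// and grows the rest of the tree by splitting recursively.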
void DecisionTree::buildDecisionTree(){
    try {
        int generation = 0;    // the root starts at generation 0; children get getGeneration() + 1
        rootNode = new RFTreeNode(bootstrappedTrainingSamples, globalDiscardedFeatureIndices, numFeatures, numSamples, numOutputClasses, generation);

        splitRecursively(rootNode);
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "buildDecisionTree");
        exit(1);
    }
}
/***********************************************************************/
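
// Recursive tree growth: a node becomes a leaf when it holds fewer than two samples or
// when all of its samples already share one output class; otherwise a random feature
// subset is drawn, the best split is chosen, the samples are partitioned, and both
// children are split in turn.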
int DecisionTree::splitRecursively(RFTreeNode* rootNode) {
    try {
        if (rootNode->getNumSamples() < 2){
            rootNode->setIsLeaf(true);
            rootNode->setOutputClass(rootNode->getBootstrappedTrainingSamples()[0][rootNode->getNumFeatures()]);
            return 0;
        }

        int classifiedOutputClass;
        bool isAlreadyClassified = checkIfAlreadyClassified(rootNode, classifiedOutputClass);
        if (isAlreadyClassified == true){
            rootNode->setIsLeaf(true);
            rootNode->setOutputClass(classifiedOutputClass);
            return 0;
        }
        if (m->control_pressed) { return 0; }
        vector<int> featureSubsetIndices = selectFeatureSubsetRandomly(globalDiscardedFeatureIndices, rootNode->getLocalDiscardedFeatureIndices());
        rootNode->setFeatureSubsetIndices(featureSubsetIndices);
        if (m->control_pressed) { return 0; }

        findAndUpdateBestFeatureToSplitOn(rootNode);

        if (m->control_pressed) { return 0; }

        vector< vector<int> > leftChildSamples;
        vector< vector<int> > rightChildSamples;
        getSplitPopulation(rootNode, leftChildSamples, rightChildSamples);

        if (m->control_pressed) { return 0; }

        // TODO: need to write code to clear this memory
        RFTreeNode* leftChildNode = new RFTreeNode(leftChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)leftChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1);
        RFTreeNode* rightChildNode = new RFTreeNode(rightChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)rightChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1);

        rootNode->setLeftChildNode(leftChildNode);
        leftChildNode->setParentNode(rootNode);

        rootNode->setRightChildNode(rightChildNode);
        rightChildNode->setParentNode(rootNode);

        // TODO: this recursive split can be parallelized later
        splitRecursively(leftChildNode);
        if (m->control_pressed) { return 0; }

        splitRecursively(rightChildNode);
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "splitRecursively");
        exit(1);
    }
}
/***********************************************************************/
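
// For each candidate feature in the node's random subset, computes the minimum-entropy
// split point (plus its intrinsic value and gain ratio), then stores on the node the
// feature whose split minimizes entropy, or maximizes gain ratio when treeSplitCriterion
// is "gainRatio".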
int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
    try {
        vector< vector<int> > bootstrappedFeatureVectors = node->getBootstrappedFeatureVectors();
        if (m->control_pressed) { return 0; }
        vector<int> bootstrappedOutputVector = node->getBootstrappedOutputVector();
        if (m->control_pressed) { return 0; }
        vector<int> featureSubsetIndices = node->getFeatureSubsetIndices();
        if (m->control_pressed) { return 0; }

        vector<double> featureSubsetEntropies;
        vector<int> featureSubsetSplitValues;
        vector<double> featureSubsetIntrinsicValues;
        vector<double> featureSubsetGainRatios;

        for (int i = 0; i < (int)featureSubsetIndices.size(); i++) {
            if (m->control_pressed) { return 0; }

            int tryIndex = featureSubsetIndices[i];

            double featureMinEntropy;
            int featureSplitValue;
            double featureIntrinsicValue;

            getMinEntropyOfFeature(bootstrappedFeatureVectors[tryIndex], bootstrappedOutputVector, featureMinEntropy, featureSplitValue, featureIntrinsicValue);
            if (m->control_pressed) { return 0; }

            featureSubsetEntropies.push_back(featureMinEntropy);
            featureSubsetSplitValues.push_back(featureSplitValue);
            featureSubsetIntrinsicValues.push_back(featureIntrinsicValue);

            double featureInformationGain = node->getOwnEntropy() - featureMinEntropy;
            double featureGainRatio = (double)featureInformationGain / (double)featureIntrinsicValue;
            featureSubsetGainRatios.push_back(featureGainRatio);
        }

        vector<double>::iterator minEntropyIterator = min_element(featureSubsetEntropies.begin(), featureSubsetEntropies.end());
        vector<double>::iterator maxGainRatioIterator = max_element(featureSubsetGainRatios.begin(), featureSubsetGainRatios.end());
        double featureMinEntropy = *minEntropyIterator;
        //double featureMaxGainRatio = *maxGainRatioIterator;

        double bestFeatureSplitEntropy = featureMinEntropy;
        int bestFeatureToSplitOnIndex = -1;
        if (treeSplitCriterion == "gainRatio"){
            bestFeatureToSplitOnIndex = (int)(maxGainRatioIterator - featureSubsetGainRatios.begin());
            // if using the 'gainRatio' measure, featureMinEntropy must be re-read, since the index
            // of 'featureMaxGainRatio' can differ from the index of the minimum entropy
            bestFeatureSplitEntropy = featureSubsetEntropies[bestFeatureToSplitOnIndex];
        }
        else { bestFeatureToSplitOnIndex = (int)(minEntropyIterator - featureSubsetEntropies.begin()); }

        int bestFeatureSplitValue = featureSubsetSplitValues[bestFeatureToSplitOnIndex];

        node->setSplitFeatureIndex(featureSubsetIndices[bestFeatureToSplitOnIndex]);
        node->setSplitFeatureValue(bestFeatureSplitValue);
        node->setSplitFeatureEntropy(bestFeatureSplitEntropy);

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "findAndUpdateBestFeatureToSplitOn");
        exit(1);
    }
}
/***********************************************************************/
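
// Draws up to optimumFeatureSubsetSize distinct feature indices uniformly at random,
// rejecting indices already picked or present in the global/local discarded lists.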
vector<int> DecisionTree::selectFeatureSubsetRandomly(vector<int> globalDiscardedFeatureIndices, vector<int> localDiscardedFeatureIndices){
    try {
        vector<int> featureSubsetIndices;

        vector<int> combinedDiscardedFeatureIndices;
        combinedDiscardedFeatureIndices.insert(combinedDiscardedFeatureIndices.end(), globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end());
        combinedDiscardedFeatureIndices.insert(combinedDiscardedFeatureIndices.end(), localDiscardedFeatureIndices.begin(), localDiscardedFeatureIndices.end());

        sort(combinedDiscardedFeatureIndices.begin(), combinedDiscardedFeatureIndices.end());

        int numberOfRemainingSuitableFeatures = (int)(numFeatures - combinedDiscardedFeatureIndices.size());
        int currentFeatureSubsetSize = numberOfRemainingSuitableFeatures < optimumFeatureSubsetSize ? numberOfRemainingSuitableFeatures : optimumFeatureSubsetSize;

        while ((int)featureSubsetIndices.size() < currentFeatureSubsetSize) {

            if (m->control_pressed) { return featureSubsetIndices; }

            // TODO: optimize the rand() call here
            int randomIndex = rand() % numFeatures;
            vector<int>::iterator it = find(featureSubsetIndices.begin(), featureSubsetIndices.end(), randomIndex);
            if (it == featureSubsetIndices.end()){    // NOT FOUND
                vector<int>::iterator it2 = find(combinedDiscardedFeatureIndices.begin(), combinedDiscardedFeatureIndices.end(), randomIndex);
                if (it2 == combinedDiscardedFeatureIndices.end()){    // NOT FOUND AGAIN
                    featureSubsetIndices.push_back(randomIndex);
                }
            }
        }
        sort(featureSubsetIndices.begin(), featureSubsetIndices.end());

        //#ifdef DEBUG_LEVEL_3
        //    PRINT_VAR(featureSubsetIndices);
        //#endif

        return featureSubsetIndices;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "selectFeatureSubsetRandomly");
        exit(1);
    }
}
/***********************************************************************/
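
// Debug helper: prints the tree depth-first, indenting by generation; internal nodes show
// their split rule and leaves show their class and sample count.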
// TODO: check that printTree() prints correctly
int DecisionTree::printTree(RFTreeNode* treeNode, string caption){
    try {
        if (treeNode == NULL) { return 0; }    // guard first: the node is dereferenced below

        string tabs = "";
        for (int i = 0; i < treeNode->getGeneration(); i++) { tabs += "   "; }
        //    for (int i = 0; i < treeNode->getGeneration() - 1; i++) { tabs += "|  "; }
        //    if (treeNode->getGeneration() != 0) { tabs += "|--"; }

        if (treeNode->checkIsLeaf() == false){
            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " ] ( " + toString(treeNode->getSplitFeatureValue()) + " < X" + toString(treeNode->getSplitFeatureIndex()) +" )\n");

            printTree(treeNode->getLeftChildNode(), "leftChild");
            printTree(treeNode->getRightChildNode(), "rightChild");
        }
        else {
            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " ] ( classified to: " + toString(treeNode->getOutputClass()) + ", samples: " + toString(treeNode->getNumSamples()) + " )\n");
        }

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "printTree");
        exit(1);
    }
}
/***********************************************************************/
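
// Post-order deletion: both subtrees are released before the node itself.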
void DecisionTree::deleteTreeNodesRecursively(RFTreeNode* treeNode) {
    try {
        if (treeNode == NULL) { return; }
        deleteTreeNodesRecursively(treeNode->leftChildNode);
        deleteTreeNodesRecursively(treeNode->rightChildNode);
        delete treeNode;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "deleteTreeNodesRecursively");
        exit(1);
    }
}
/***********************************************************************/