//
//  decisiontree.cpp
//  Mothur
//
//  Created by Sarah Westcott on 10/1/12.
//  Copyright (c) 2012 Schloss Lab. All rights reserved.
//

#include "decisiontree.hpp"

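// Builds a single decision tree of the random forest. The AbstractDecisionTree base class keeps the
// data set, the split criterion and the feature-subset selector; createBootStrappedSamples() draws the
// bootstrapped training set plus the out-of-bag (OOB) test samples, and buildDecisionTree() then grows
// the tree from those training samples. The last column of every sample holds its output class.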
DecisionTree::DecisionTree(vector< vector<int> >& baseDataSet,
                           vector<int> globalDiscardedFeatureIndices,
                           OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
                           string treeSplitCriterion,
                           float featureStandardDeviationThreshold)
            : AbstractDecisionTree(baseDataSet,
                                   globalDiscardedFeatureIndices,
                                   optimumFeatureSubsetSelector,
                                   treeSplitCriterion),
            variableImportanceList(numFeatures, 0),
            featureStandardDeviationThreshold(featureStandardDeviationThreshold) {
                
    try {
        m = MothurOut::getInstance();
        createBootStrappedSamples();
        buildDecisionTree();
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "DecisionTree");
        exit(1);
    }
}

/***********************************************************************/

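// Permutation-based variable importance on the OOB test samples: for every feature that is not
// globally discarded and has enough variance, the feature's column is randomly shuffled, the shuffled
// samples are re-classified, and the drop in the number of correct predictions
//
//     variableImportanceList[i] += (numCorrect - numCorrectAfterShuffle)
//
// is recorded as that feature's importance for this tree. numCorrect is the caller-supplied baseline
// count of correct OOB predictions (calcTreeErrorRate() below computes it).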
int DecisionTree::calcTreeVariableImportanceAndError(int& numCorrect, double& treeErrorRate) {
    try {
        vector< vector<int> > randomlySampledTestData(bootstrappedTestSamples.size(), vector<int>(bootstrappedTestSamples[0].size(), 0));
        
            // TODO: is it possible to further speed up the following O(N^2) copy by using std::copy?
        for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
            for (int j = 0; j < bootstrappedTestSamples[i].size(); j++) {
                randomlySampledTestData[i][j] = bootstrappedTestSamples[i][j];
            }
        }
        
        // index of the column permuted in the previous iteration, so that randomlyShuffleAttribute()
        // can restore it before permuting the next one (features that are skipped below are never
        // shuffled, so we cannot simply assume it is i - 1)
        int previouslyShuffledFeatureIndex = -1;
        
        for (int i = 0; i < numFeatures; i++) {
            if (m->control_pressed) { return 0; }
            
                // if the index is in globalDiscardedFeatureIndices (i.e., a null feature) we don't want to shuffle it
            vector<int>::iterator it = find(globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end(), i);
            if (it == globalDiscardedFeatureIndices.end()) {        // NOT FOUND
                // if the standard deviation is very low, we know it's not a good feature at all
                // we can save some time here by discarding that feature
                
                vector<int> featureVector = testSampleFeatureVectors[i];
                if (m->getStandardDeviation(featureVector) > featureStandardDeviationThreshold) {
                    // NOTE: only shuffle the features, never shuffle the output vector;
                    // i runs from 0 to (numFeatures - 1), and the column at index numFeatures
                    // holds the output class, not a feature
                    randomlyShuffleAttribute(bootstrappedTestSamples, i, previouslyShuffledFeatureIndex, randomlySampledTestData);
                    previouslyShuffledFeatureIndex = i;

                    int numCorrectAfterShuffle = 0;
                    for (int j = 0; j < randomlySampledTestData.size(); j++) {
                        if (m->control_pressed) { return 0; }
                        
                        vector<int> shuffledSample = randomlySampledTestData[j];
                        int actualSampleOutputClass = shuffledSample[numFeatures];
                        int predictedSampleOutputClass = evaluateSample(shuffledSample);
                        if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrectAfterShuffle++; }
                    }
                    variableImportanceList[i] += (numCorrect - numCorrectAfterShuffle);
                }
            }
        }
        
        // TODO: do we need to save the variableRanks in the DecisionTree, do we need it later?
        vector< pair<int, int> > variableRanks;
        
        for (int i = 0; i < variableImportanceList.size(); i++) {
            if (m->control_pressed) { return 0; }
            if (variableImportanceList[i] > 0) {
                // TODO: is there a way to optimize the following line's code?
                pair<int, int> variableRank(0, 0);
                variableRank.first = i;
                variableRank.second = variableImportanceList[i];
                variableRanks.push_back(variableRank);
            }
        }
        VariableRankDescendingSorter variableRankDescendingSorter;
        sort(variableRanks.begin(), variableRanks.end(), variableRankDescendingSorter);
        
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "calcTreeVariableImportanceAndError");
        exit(1);
    }
}
/***********************************************************************/

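// Classifies one sample by walking the tree from the root: at every internal node, compare the
// sample's value for the node's split feature against the split value and descend left (strictly
// smaller) or right (greater or equal) until a leaf is reached; the leaf's output class is returned.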
// TODO: there must be a way to optimize this function
int DecisionTree::evaluateSample(vector<int> testSample) {
    try {
        RFTreeNode *node = rootNode;
        while (true) {
            if (m->control_pressed) { return 0; }
            
            if (node->checkIsLeaf()) { return node->getOutputClass(); }
            
            int sampleSplitFeatureValue = testSample[node->getSplitFeatureIndex()];
            if (sampleSplitFeatureValue < node->getSplitFeatureValue()) { node = node->getLeftChildNode(); }
            else { node = node->getRightChildNode(); }
        }
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "evaluateSample");
        exit(1);
    }
}
/***********************************************************************/

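// Out-of-bag error of this tree: every OOB test sample is classified, its prediction is stored in
// outOfBagEstimates (indexed by the sample's original index), and
//
//     treeErrorRate = 1 - numCorrect / numberOfOOBSamples
//
// numCorrect is returned through the reference parameter so it can serve as the baseline for the
// permutation importance calculation above.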
int DecisionTree::calcTreeErrorRate(int& numCorrect, double& treeErrorRate){
    numCorrect = 0;
    try {
        for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
            if (m->control_pressed) { return 0; }
            
            vector<int> testSample = bootstrappedTestSamples[i];
            int testSampleIndex = bootstrappedTestSampleIndices[i];
            
            int actualSampleOutputClass = testSample[numFeatures];
            int predictedSampleOutputClass = evaluateSample(testSample);
            
            if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrect++; }
            
            outOfBagEstimates[testSampleIndex] = predictedSampleOutputClass;
        }
        
        treeErrorRate = 1 - ((double)numCorrect / (double)bootstrappedTestSamples.size());
        
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "calcTreeErrorRate");
        exit(1);
    }
}

/***********************************************************************/
// TODO: optimize the algorithm: instead of transposing twice, we can extract the feature,
// shuffle it and then re-insert it in its original place, thus improving running time.
// This function randomizes the abundances for a given OTU/feature.

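// shuffledSample is a working copy of samples: the previously permuted column (prevFeatureIndex,
// or -1 for none) is restored first, then column featureIndex is permuted in place, so repeated calls
// keep at most one feature column permuted at a time (provided prevFeatureIndex is the column
// permuted by the previous call).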
void DecisionTree::randomlyShuffleAttribute(const vector< vector<int> >& samples,
                               const int featureIndex,
                               const int prevFeatureIndex,
                               vector< vector<int> >& shuffledSample) {
    try {
        // NOTE: we need (numFeatures + 1) feature vectors; the last extra vector is actually the output vector
        
        // restore the previously shuffled feature
        if (prevFeatureIndex > -1) {
            for (int j = 0; j < samples.size(); j++) {
                if (m->control_pressed) { return; }
                shuffledSample[j][prevFeatureIndex] = samples[j][prevFeatureIndex];
            }
        }
        
        // now do the shuffling
        vector<int> featureVectors(samples.size(), 0);
        for (int j = 0; j < samples.size(); j++) {
            if (m->control_pressed) { return; }
            featureVectors[j] = samples[j][featureIndex];
        }
        random_shuffle(featureVectors.begin(), featureVectors.end());
        for (int j = 0; j < samples.size(); j++) {
            if (m->control_pressed) { return; }
            shuffledSample[j][featureIndex] = featureVectors[j];
        }
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "randomlyShuffleAttribute");
        exit(1);
    }
    
}

/***********************************************************************/

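// Frees the per-node copies of the training data (samples, feature vectors, output vector and
// discarded-feature lists) once they are no longer needed, keeping only the tree structure
// (split feature, split value, children) required for classification.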
int DecisionTree::purgeTreeNodesDataRecursively(RFTreeNode* treeNode) {
    try {
        treeNode->bootstrappedTrainingSamples.clear();
        treeNode->bootstrappedFeatureVectors.clear();
        treeNode->bootstrappedOutputVector.clear();
        treeNode->localDiscardedFeatureIndices.clear();
        treeNode->globalDiscardedFeatureIndices.clear();
        
        if (treeNode->leftChildNode != NULL) { purgeTreeNodesDataRecursively(treeNode->leftChildNode); }
        if (treeNode->rightChildNode != NULL) { purgeTreeNodesDataRecursively(treeNode->rightChildNode); }
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "purgeTreeNodesDataRecursively");
        exit(1);
    }
}
/***********************************************************************/

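// Creates the root node over the full set of bootstrapped training samples (generation 0) and grows
// the tree by recursive splitting.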
void DecisionTree::buildDecisionTree(){
    try {
    
        int generation = 0;
        rootNode = new RFTreeNode(bootstrappedTrainingSamples, globalDiscardedFeatureIndices, numFeatures, numSamples, numOutputClasses, generation, nodeIdCount, featureStandardDeviationThreshold);
        nodeIdCount++;
        
        splitRecursively(rootNode);
        
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "buildDecisionTree");
        exit(1);
    }
}

/***********************************************************************/

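// Grows the subtree rooted at rootNode. A node becomes a leaf when it holds fewer than two samples
// or when all of its samples already belong to one class; otherwise a random feature subset is drawn,
// the best split is found via findAndUpdateBestFeatureToSplitOn(), the samples are partitioned into
// left (< split value) and right (>= split value) children, and both children are split recursively.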
int DecisionTree::splitRecursively(RFTreeNode* rootNode) {
    try {
       
        if (rootNode->getNumSamples() < 2){
            rootNode->setIsLeaf(true);
            rootNode->setOutputClass(rootNode->getBootstrappedTrainingSamples()[0][rootNode->getNumFeatures()]);
            return 0;
        }
        
        int classifiedOutputClass;
        bool isAlreadyClassified = checkIfAlreadyClassified(rootNode, classifiedOutputClass);
        if (isAlreadyClassified == true){
            rootNode->setIsLeaf(true);
            rootNode->setOutputClass(classifiedOutputClass);
            return 0;
        }
        if (m->control_pressed) { return 0; }
        vector<int> featureSubsetIndices = selectFeatureSubsetRandomly(globalDiscardedFeatureIndices, rootNode->getLocalDiscardedFeatureIndices());
        
            // TODO: need to check if the value is actually copied correctly
        rootNode->setFeatureSubsetIndices(featureSubsetIndices);
        if (m->control_pressed) { return 0; }
      
        findAndUpdateBestFeatureToSplitOn(rootNode);
        
        // update rootNode outputClass; this is needed for pruning
        // and only applies to internal nodes
        updateOutputClassOfNode(rootNode);
        
        if (m->control_pressed) { return 0; }
        
        vector< vector<int> > leftChildSamples;
        vector< vector<int> > rightChildSamples;
        getSplitPopulation(rootNode, leftChildSamples, rightChildSamples);
        
        if (m->control_pressed) { return 0; }
        
        // TODO: need to write code to clear this memory
        RFTreeNode* leftChildNode = new RFTreeNode(leftChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)leftChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1, nodeIdCount, featureStandardDeviationThreshold);
        nodeIdCount++;
        RFTreeNode* rightChildNode = new RFTreeNode(rightChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)rightChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1, nodeIdCount, featureStandardDeviationThreshold);
        nodeIdCount++;
        
        rootNode->setLeftChildNode(leftChildNode);
        leftChildNode->setParentNode(rootNode);
        
        rootNode->setRightChildNode(rightChildNode);
        rightChildNode->setParentNode(rootNode);
        
        // TODO: this recursive split can be parallelized later
        splitRecursively(leftChildNode);
        if (m->control_pressed) { return 0; }
        
        splitRecursively(rightChildNode);
        return 0;
        
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "splitRecursively");
        exit(1);
    }
}
/***********************************************************************/

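// Scores every feature in the node's random subset with getMinEntropyOfFeature(), which returns the
// lowest achievable split entropy, the corresponding split value and the intrinsic value. For each
// candidate,
//
//     informationGain = node->getOwnEntropy() - featureMinEntropy
//     gainRatio       = informationGain / featureIntrinsicValue
//
// and the winning feature has the maximum gain ratio (treeSplitCriterion == "gainratio") or the
// minimum entropy (treeSplitCriterion == "infogain"). The node's split feature, split value and
// split entropy are updated accordingly.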
int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
    try {

        vector< vector<int> > bootstrappedFeatureVectors = node->getBootstrappedFeatureVectors();
        if (m->control_pressed) { return 0; }
        vector<int> bootstrappedOutputVector = node->getBootstrappedOutputVector();
        if (m->control_pressed) { return 0; }
        vector<int> featureSubsetIndices = node->getFeatureSubsetIndices();
        if (m->control_pressed) { return 0; }
        
        vector<double> featureSubsetEntropies;
        vector<int> featureSubsetSplitValues;
        vector<double> featureSubsetIntrinsicValues;
        vector<double> featureSubsetGainRatios;
        
        for (int i = 0; i < featureSubsetIndices.size(); i++) {
            if (m->control_pressed) { return 0; }
            
            int tryIndex = featureSubsetIndices[i];
                       
            double featureMinEntropy;
            int featureSplitValue;
            double featureIntrinsicValue;
            
            getMinEntropyOfFeature(bootstrappedFeatureVectors[tryIndex], bootstrappedOutputVector, featureMinEntropy, featureSplitValue, featureIntrinsicValue);
            if (m->control_pressed) { return 0; }
            
            featureSubsetEntropies.push_back(featureMinEntropy);
            featureSubsetSplitValues.push_back(featureSplitValue);
            featureSubsetIntrinsicValues.push_back(featureIntrinsicValue);
            
            double featureInformationGain = node->getOwnEntropy() - featureMinEntropy;
            double featureGainRatio = (double)featureInformationGain / (double)featureIntrinsicValue;
            featureSubsetGainRatios.push_back(featureGainRatio);
            
        }
        
        vector<double>::iterator minEntropyIterator = min_element(featureSubsetEntropies.begin(), featureSubsetEntropies.end());
        vector<double>::iterator maxGainRatioIterator = max_element(featureSubsetGainRatios.begin(), featureSubsetGainRatios.end());
        double featureMinEntropy = *minEntropyIterator;
        
        // TODO: kept the following line as future reference, can be useful
        // double featureMaxGainRatio = *maxGainRatioIterator;
        
        double bestFeatureSplitEntropy = featureMinEntropy;
        int bestFeatureToSplitOnIndex = -1;
        if (treeSplitCriterion == "gainratio"){
            bestFeatureToSplitOnIndex = (int)(maxGainRatioIterator - featureSubsetGainRatios.begin());
            // if using the 'gainratio' measure, then featureMinEntropy must be re-updated, as the index
            // for 'featureMaxGainRatio' would be different
            bestFeatureSplitEntropy = featureSubsetEntropies[bestFeatureToSplitOnIndex];
        } else if (treeSplitCriterion == "infogain"){
            bestFeatureToSplitOnIndex = (int)(minEntropyIterator - featureSubsetEntropies.begin());
        } else {
            // unrecognized split criterion: report it and fall back to the minimum-entropy feature
            // so the index below is always valid
            m->mothurOut("Unrecognized treeSplitCriterion '" + treeSplitCriterion + "', defaulting to 'infogain'.\n");
            bestFeatureToSplitOnIndex = (int)(minEntropyIterator - featureSubsetEntropies.begin());
        }
        
            // TODO: is the following line needed? kept it as future reference
        // splitInformationGain = node.ownEntropy - node.splitFeatureEntropy
        
        int bestFeatureSplitValue = featureSubsetSplitValues[bestFeatureToSplitOnIndex];
        
        node->setSplitFeatureIndex(featureSubsetIndices[bestFeatureToSplitOnIndex]);
        node->setSplitFeatureValue(bestFeatureSplitValue);
        node->setSplitFeatureEntropy(bestFeatureSplitEntropy);
            // TODO: kept the following line as future reference
        // node.splitInformationGain = splitInformationGain
        
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "findAndUpdateBestFeatureToSplitOn");
        exit(1);
    }
}
/***********************************************************************/
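// Draws a random subset of feature indices for a node, skipping both the globally and the locally
// discarded features. The subset size is optimumFeatureSubsetSize, computed elsewhere by the
// OptimumFeatureSubsetSelector (random forests commonly use on the order of sqrt(numFeatures)),
// unless fewer suitable features remain; sampling is without replacement, rejecting duplicates
// and discarded indices.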
vector<int> DecisionTree::selectFeatureSubsetRandomly(vector<int> globalDiscardedFeatureIndices, vector<int> localDiscardedFeatureIndices){
    try {

        vector<int> featureSubsetIndices;
        
        vector<int> combinedDiscardedFeatureIndices;
        combinedDiscardedFeatureIndices.insert(combinedDiscardedFeatureIndices.end(), globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end());
        combinedDiscardedFeatureIndices.insert(combinedDiscardedFeatureIndices.end(), localDiscardedFeatureIndices.begin(), localDiscardedFeatureIndices.end());
        
        sort(combinedDiscardedFeatureIndices.begin(), combinedDiscardedFeatureIndices.end());
        
        int numberOfRemainingSuitableFeatures = (int)(numFeatures - combinedDiscardedFeatureIndices.size());
        int currentFeatureSubsetSize = numberOfRemainingSuitableFeatures < optimumFeatureSubsetSize ? numberOfRemainingSuitableFeatures : optimumFeatureSubsetSize;
        
        while (featureSubsetIndices.size() < currentFeatureSubsetSize) {
            
            if (m->control_pressed) { return featureSubsetIndices; }
            
            // TODO: optimize rand() call here
            int randomIndex = rand() % numFeatures;
            vector<int>::iterator it = find(featureSubsetIndices.begin(), featureSubsetIndices.end(), randomIndex);
            if (it == featureSubsetIndices.end()){    // NOT FOUND
                vector<int>::iterator it2 = find(combinedDiscardedFeatureIndices.begin(), combinedDiscardedFeatureIndices.end(), randomIndex);
                if (it2 == combinedDiscardedFeatureIndices.end()){  // NOT FOUND AGAIN
                    featureSubsetIndices.push_back(randomIndex);
                }
            }
        }
        sort(featureSubsetIndices.begin(), featureSubsetIndices.end());
        
        //#ifdef DEBUG_LEVEL_3
        //    PRINT_VAR(featureSubsetIndices);
        //#endif
        
        return featureSubsetIndices;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "selectFeatureSubsetRandomly");
        exit(1);
    }
}
/***********************************************************************/

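// Writes the tree to the mothur log, one node per line, indented by generation. Internal nodes show
// their split test ( splitValue < Xi ), predicted class and misclassification count; leaves show the
// class they assign, their sample count and their misclassification count.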
// TODO: printTree() needs a check if correct
int DecisionTree::printTree(RFTreeNode* treeNode, string caption){
    try {
        if (treeNode == NULL) { return 0; }
        
        string tabs = "";
        for (int i = 0; i < treeNode->getGeneration(); i++) { tabs += "|--"; }
        //    for (int i = 0; i < treeNode->getGeneration() - 1; i++) { tabs += "|  "; }
        //    if (treeNode->getGeneration() != 0) { tabs += "|--"; }
        
        if (treeNode->checkIsLeaf() == false){
            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " , id: " + toString(treeNode->nodeId) + " ] ( " + toString(treeNode->getSplitFeatureValue()) + " < X" + toString(treeNode->getSplitFeatureIndex()) + " ) ( predicted: " + toString(treeNode->outputClass) + " , misclassified: " + toString(treeNode->testSampleMisclassificationCount) + " )\n");
            
            printTree(treeNode->getLeftChildNode(), "left ");
            printTree(treeNode->getRightChildNode(), "right");
        } else {
            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " , id: " + toString(treeNode->nodeId) + " ] ( classified to: " + toString(treeNode->getOutputClass()) + ", samples: " + toString(treeNode->getNumSamples()) + " , misclassified: " + toString(treeNode->testSampleMisclassificationCount) + " )\n");
        }
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "printTree");
        exit(1);
    }
}
/***********************************************************************/
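// Post-order deletion of the subtree rooted at treeNode; children are released before the node itself.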
void DecisionTree::deleteTreeNodesRecursively(RFTreeNode* treeNode) {
    try {
        if (treeNode == NULL) { return; }
        deleteTreeNodesRecursively(treeNode->leftChildNode);
        deleteTreeNodesRecursively(treeNode->rightChildNode);
        delete treeNode; treeNode = NULL;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "deleteTreeNodesRecursively");
        exit(1);
    }
}
/***********************************************************************/

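// Reduced-error style pruning driven by the OOB test samples: first every test sample is routed down
// the tree and each node on its path records whether it would have misclassified that sample, then
// subtrees whose children collectively misclassify more (scaled by pruneAggressiveness) than the
// parent alone are collapsed into leaves.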
void DecisionTree::pruneTree(double pruneAggressiveness = 0.9) {
    
    // find out the number of misclassifications made by each of the nodes
    for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
        if (m->control_pressed) { return; }
        
        vector<int> testSample = bootstrappedTestSamples[i];
        updateMisclassificationCountRecursively(rootNode, testSample);
    }
    
    // do the actual pruning
    pruneRecursively(rootNode, pruneAggressiveness);
}
/***********************************************************************/

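// Bottom-up pruning: prune both children first, then collapse this node into a leaf when
//
//     subTreeMisclassificationCount * pruneAggressiveness > ownMisclassificationCount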
void DecisionTree::pruneRecursively(RFTreeNode* treeNode, double pruneAggressiveness){
    
    if (treeNode != NULL && treeNode->checkIsLeaf() == false) {
        if (m->control_pressed) { return; }
        
        pruneRecursively(treeNode->leftChildNode, pruneAggressiveness);
        pruneRecursively(treeNode->rightChildNode, pruneAggressiveness);
        
        int subTreeMisclassificationCount = treeNode->leftChildNode->getTestSampleMisclassificationCount() + treeNode->rightChildNode->getTestSampleMisclassificationCount();
        int ownMisclassificationCount = treeNode->getTestSampleMisclassificationCount();
        
        if (subTreeMisclassificationCount * pruneAggressiveness > ownMisclassificationCount) {
                // TODO: need to check the effect of these two delete calls
            delete treeNode->leftChildNode;
            treeNode->leftChildNode = NULL;
            
            delete treeNode->rightChildNode;
            treeNode->rightChildNode = NULL;
            
            treeNode->isLeaf = true;
        }
        
    }
}
/***********************************************************************/

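// Routes a test sample from treeNode down to a leaf, incrementing testSampleMisclassificationCount
// (and the per-node tally in nodeMisclassificationCounts) for every node along the path whose
// predicted output class disagrees with the sample's actual class.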
void DecisionTree::updateMisclassificationCountRecursively(RFTreeNode* treeNode, vector<int> testSample) {
    
    int actualSampleOutputClass = testSample[numFeatures];
    int nodePredictedOutputClass = treeNode->outputClass;
    
    if (actualSampleOutputClass != nodePredictedOutputClass) {
        treeNode->testSampleMisclassificationCount++;
        map<int, int>::iterator it = nodeMisclassificationCounts.find(treeNode->nodeId);
        if (it == nodeMisclassificationCounts.end()) {  // NOT FOUND
            nodeMisclassificationCounts[treeNode->nodeId] = 0;
        }
        nodeMisclassificationCounts[treeNode->nodeId]++;
    }
    
    if (treeNode->checkIsLeaf() == false) { // NOT A LEAF
        int sampleSplitFeatureValue = testSample[treeNode->splitFeatureIndex];
        if (sampleSplitFeatureValue < treeNode->splitFeatureValue) {
            updateMisclassificationCountRecursively(treeNode->leftChildNode, testSample);
        } else {
            updateMisclassificationCountRecursively(treeNode->rightChildNode, testSample);
        }
    }
}

/***********************************************************************/

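// Assigns an output class to an internal node by majority vote over its bootstrapped output vector;
// this is what allows a pruned internal node to act as a leaf.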
void DecisionTree::updateOutputClassOfNode(RFTreeNode* treeNode) {
    vector<int> counts(numOutputClasses, 0);
    for (int i = 0; i < treeNode->bootstrappedOutputVector.size(); i++) {
        int bootstrappedOutput = treeNode->bootstrappedOutputVector[i];
        counts[bootstrappedOutput]++;
    }

    // max_element() already points at the first occurrence of the largest count,
    // so its offset from counts.begin() is the majority-voted output class
    vector<int>::iterator majorityVotedOutputClassCountIterator = max_element(counts.begin(), counts.end());
    int majorityVotedOutputClass = (int)(majorityVotedOutputClassCountIterator - counts.begin());
    treeNode->setOutputClass(majorityVotedOutputClass);

}
/***********************************************************************/
