//
//  decisiontree.cpp
//  Mothur
//
//  Created by Sarah Westcott on 10/1/12.
//  Copyright (c) 2012 Schloss Lab. All rights reserved.
//

#include "decisiontree.hpp"

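// Builds one randomized tree: the AbstractDecisionTree constructor stores the
// data set and split criterion; here we draw the bootstrapped training/test
// samples and then grow the tree from the root.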
DecisionTree::DecisionTree(vector< vector<int> > baseDataSet,
             vector<int> globalDiscardedFeatureIndices,
             OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
             string treeSplitCriterion) : AbstractDecisionTree(baseDataSet,
                       globalDiscardedFeatureIndices,
                       optimumFeatureSubsetSelector,
                       treeSplitCriterion), variableImportanceList(numFeatures, 0) {
    try {
        m = MothurOut::getInstance();
        createBootStrappedSamples();
        buildDecisionTree();
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "DecisionTree");
        exit(1);
    }
}

/***********************************************************************/

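// Permutation-based variable importance: for each feature, shuffle that feature's
// values across the out-of-bag test samples and re-classify them; the drop in
// correct predictions (numCorrect - numCorrectAfterShuffle) is accumulated as the
// feature's importance score, so features the tree truly relies on score high.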
int DecisionTree::calcTreeVariableImportanceAndError() {
    try {
        
        int numCorrect;
        double treeErrorRate;
        calcTreeErrorRate(numCorrect, treeErrorRate);
        
        if (m->control_pressed) { return 0; }
        
        for (int i = 0; i < numFeatures; i++) {
            if (m->control_pressed) { return 0; }
            // NOTE: only shuffle the features, never shuffle the output vector;
            // i always stays <= (numFeatures - 1), since the entry at index
            // numFeatures holds the sample's output class
            vector< vector<int> > randomlySampledTestData = randomlyShuffleAttribute(bootstrappedTestSamples, i);
            
            int numCorrectAfterShuffle = 0;
            for (int j = 0; j < randomlySampledTestData.size(); j++) {
                if (m->control_pressed) { return 0; }
                vector<int> shuffledSample = randomlySampledTestData[j];
                int actualSampleOutputClass = shuffledSample[numFeatures];
                int predictedSampleOutputClass = evaluateSample(shuffledSample);
                if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrectAfterShuffle++; }
            }
            variableImportanceList[i] += (numCorrect - numCorrectAfterShuffle);
        }
        
        // TODO: do we need to save the variableRanks in the DecisionTree? Do we need them later?
        vector< vector<int> > variableRanks;
        for (int i = 0; i < variableImportanceList.size(); i++) {
            if (m->control_pressed) { return 0; }
            if (variableImportanceList[i] > 0) {
                // TODO: is there a way to optimize the following lines?
                vector<int> variableRank(2, 0);
                variableRank[0] = i; variableRank[1] = variableImportanceList[i];
                variableRanks.push_back(variableRank);
            }
        }
        VariableRankDescendingSorter variableRankDescendingSorter;
        sort(variableRanks.begin(), variableRanks.end(), variableRankDescendingSorter);
        
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "calcTreeVariableImportanceAndError");
        exit(1);
    }
}
/***********************************************************************/

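// Classifies a single sample by walking down from the root: at each internal
// node the sample goes left when its value for the node's split feature is
// below the split value, otherwise right; the first leaf reached supplies the
// predicted output class.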
// TODO: there must be a way to optimize this function
int DecisionTree::evaluateSample(vector<int> testSample) {
    try {
        RFTreeNode *node = rootNode;
        while (true) {
            if (m->control_pressed) { return 0; }
            if (node->checkIsLeaf()) { return node->getOutputClass(); }
            int sampleSplitFeatureValue = testSample[node->getSplitFeatureIndex()];
            if (sampleSplitFeatureValue < node->getSplitFeatureValue()) { node = node->getLeftChildNode(); }
            else { node = node->getRightChildNode(); }
        }
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "evaluateSample");
        exit(1);
    }
}
/***********************************************************************/

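// Out-of-bag error estimate: runs every bootstrapped test sample through the
// tree, records each prediction in outOfBagEstimates, and reports the fraction
// misclassified through the treeErrorRate out-parameter.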
int DecisionTree::calcTreeErrorRate(int& numCorrect, double& treeErrorRate){
    try {
        numCorrect = 0;
        for (int i = 0; i < bootstrappedTestSamples.size(); i++) {
            if (m->control_pressed) { return 0; }
            
            vector<int> testSample = bootstrappedTestSamples[i];
            int testSampleIndex = bootstrappedTestSampleIndices[i];
            
            int actualSampleOutputClass = testSample[numFeatures];
            int predictedSampleOutputClass = evaluateSample(testSample);
            
            if (actualSampleOutputClass == predictedSampleOutputClass) { numCorrect++; }
            
            outOfBagEstimates[testSampleIndex] = predictedSampleOutputClass;
        }
        
        treeErrorRate = 1 - ((double)numCorrect / (double)bootstrappedTestSamples.size());
        
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "calcTreeErrorRate");
        exit(1);
    }
}

/***********************************************************************/

// TODO: optimize this algorithm; instead of transposing twice, we can extract the
// feature, shuffle it, and re-insert it in its original place, improving running time.
// This function randomizes the abundances for a given OTU/feature.
vector< vector<int> > DecisionTree::randomlyShuffleAttribute(vector< vector<int> > samples, int featureIndex) {
    try {
        // NOTE: we need (numFeatures + 1) feature vectors; the last extra vector is actually the output vector
        vector< vector<int> > shuffledSample = samples;
        vector<int> featureVectors(samples.size(), 0);
        
        for (int j = 0; j < samples.size(); j++) {
            if (m->control_pressed) { return shuffledSample; }
            featureVectors[j] = samples[j][featureIndex];
        }
        
        random_shuffle(featureVectors.begin(), featureVectors.end());
        
        for (int j = 0; j < samples.size(); j++) {
            if (m->control_pressed) { return shuffledSample; }
            shuffledSample[j][featureIndex] = featureVectors[j];
        }
        
        return shuffledSample;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "randomlyShuffleAttribute");
        exit(1);
    }
}
/***********************************************************************/

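// Clears the per-node training data (samples, feature/output vectors, and
// discarded-feature lists) for a node and all of its descendants; the split
// feature index and value needed for classification stay on the node.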
int DecisionTree::purgeTreeNodesDataRecursively(RFTreeNode* treeNode) {
    try {
        treeNode->bootstrappedTrainingSamples.clear();
        treeNode->bootstrappedFeatureVectors.clear();
        treeNode->bootstrappedOutputVector.clear();
        treeNode->localDiscardedFeatureIndices.clear();
        treeNode->globalDiscardedFeatureIndices.clear();
        
        if (treeNode->leftChildNode != NULL) { purgeTreeNodesDataRecursively(treeNode->leftChildNode); }
        if (treeNode->rightChildNode != NULL) { purgeTreeNodesDataRecursively(treeNode->rightChildNode); }
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "purgeTreeNodesDataRecursively");
        exit(1);
    }
}
/***********************************************************************/

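// Grows the tree: builds the root node (generation 0) over the bootstrapped
// training samples, then splits it recursively until every branch ends in a leaf.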
void DecisionTree::buildDecisionTree(){
    try {
        int generation = 0;
        rootNode = new RFTreeNode(bootstrappedTrainingSamples, globalDiscardedFeatureIndices, numFeatures, numSamples, numOutputClasses, generation);
        
        splitRecursively(rootNode);
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "buildDecisionTree");
        exit(1);
    }
}

/***********************************************************************/

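// Recursive tree growth. A node becomes a leaf when it holds fewer than two
// samples or when all of its samples share one output class; otherwise a random
// feature subset is drawn, the best split is found and stored on the node, and
// the samples are partitioned into left/right children that are split in turn.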
int DecisionTree::splitRecursively(RFTreeNode* rootNode) {
    try {
        
        if (rootNode->getNumSamples() < 2){
            rootNode->setIsLeaf(true);
            rootNode->setOutputClass(rootNode->getBootstrappedTrainingSamples()[0][rootNode->getNumFeatures()]);
            return 0;
        }
        
        int classifiedOutputClass;
        bool isAlreadyClassified = checkIfAlreadyClassified(rootNode, classifiedOutputClass);
        if (isAlreadyClassified == true){
            rootNode->setIsLeaf(true);
            rootNode->setOutputClass(classifiedOutputClass);
            return 0;
        }
        if (m->control_pressed) { return 0; }
        
        vector<int> featureSubsetIndices = selectFeatureSubsetRandomly(globalDiscardedFeatureIndices, rootNode->getLocalDiscardedFeatureIndices());
        rootNode->setFeatureSubsetIndices(featureSubsetIndices);
        if (m->control_pressed) { return 0; }
        
        findAndUpdateBestFeatureToSplitOn(rootNode);
        
        if (m->control_pressed) { return 0; }
        
        vector< vector<int> > leftChildSamples;
        vector< vector<int> > rightChildSamples;
        getSplitPopulation(rootNode, leftChildSamples, rightChildSamples);
        
        if (m->control_pressed) { return 0; }
        
        // TODO: need to write code to clear this memory
        RFTreeNode* leftChildNode = new RFTreeNode(leftChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)leftChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1);
        RFTreeNode* rightChildNode = new RFTreeNode(rightChildSamples, globalDiscardedFeatureIndices, numFeatures, (int)rightChildSamples.size(), numOutputClasses, rootNode->getGeneration() + 1);
        
        rootNode->setLeftChildNode(leftChildNode);
        leftChildNode->setParentNode(rootNode);
        
        rootNode->setRightChildNode(rightChildNode);
        rightChildNode->setParentNode(rootNode);
        
        // TODO: this recursive split can be parallelized later
        splitRecursively(leftChildNode);
        if (m->control_pressed) { return 0; }
        
        splitRecursively(rightChildNode);
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "splitRecursively");
        exit(1);
    }
}
/***********************************************************************/

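// For each candidate feature in the node's subset, computes the minimum split
// entropy, the split value achieving it, and the intrinsic value, then picks the
// winner either by highest gain ratio (treeSplitCriterion == "gainRatio") or by
// lowest entropy, and records the chosen feature, value, and entropy on the node.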
int DecisionTree::findAndUpdateBestFeatureToSplitOn(RFTreeNode* node){
    try {
        vector< vector<int> > bootstrappedFeatureVectors = node->getBootstrappedFeatureVectors();
        if (m->control_pressed) { return 0; }
        vector<int> bootstrappedOutputVector = node->getBootstrappedOutputVector();
        if (m->control_pressed) { return 0; }
        vector<int> featureSubsetIndices = node->getFeatureSubsetIndices();
        if (m->control_pressed) { return 0; }
        
        vector<double> featureSubsetEntropies;
        vector<int> featureSubsetSplitValues;
        vector<double> featureSubsetIntrinsicValues;
        vector<double> featureSubsetGainRatios;
        
        for (int i = 0; i < featureSubsetIndices.size(); i++) {
            if (m->control_pressed) { return 0; }
            
            int tryIndex = featureSubsetIndices[i];
            
            double featureMinEntropy;
            int featureSplitValue;
            double featureIntrinsicValue;
            
            getMinEntropyOfFeature(bootstrappedFeatureVectors[tryIndex], bootstrappedOutputVector, featureMinEntropy, featureSplitValue, featureIntrinsicValue);
            if (m->control_pressed) { return 0; }
            
            featureSubsetEntropies.push_back(featureMinEntropy);
            featureSubsetSplitValues.push_back(featureSplitValue);
            featureSubsetIntrinsicValues.push_back(featureIntrinsicValue);
            
            double featureInformationGain = node->getOwnEntropy() - featureMinEntropy;
            double featureGainRatio = (double)featureInformationGain / (double)featureIntrinsicValue;
            featureSubsetGainRatios.push_back(featureGainRatio);
        }
        
        vector<double>::iterator minEntropyIterator = min_element(featureSubsetEntropies.begin(), featureSubsetEntropies.end());
        vector<double>::iterator maxGainRatioIterator = max_element(featureSubsetGainRatios.begin(), featureSubsetGainRatios.end());
        double featureMinEntropy = *minEntropyIterator;
        //double featureMaxGainRatio = *maxGainRatioIterator;
        
        double bestFeatureSplitEntropy = featureMinEntropy;
        int bestFeatureToSplitOnIndex = -1;
        if (treeSplitCriterion == "gainRatio"){
            bestFeatureToSplitOnIndex = (int)(maxGainRatioIterator - featureSubsetGainRatios.begin());
            // when using the 'gainRatio' measure, featureMinEntropy must be re-read, since the
            // index of the maximum gain ratio may differ from that of the minimum entropy
            bestFeatureSplitEntropy = featureSubsetEntropies[bestFeatureToSplitOnIndex];
        }
        else { bestFeatureToSplitOnIndex = (int)(minEntropyIterator - featureSubsetEntropies.begin()); }
        
        int bestFeatureSplitValue = featureSubsetSplitValues[bestFeatureToSplitOnIndex];
        
        node->setSplitFeatureIndex(featureSubsetIndices[bestFeatureToSplitOnIndex]);
        node->setSplitFeatureValue(bestFeatureSplitValue);
        node->setSplitFeatureEntropy(bestFeatureSplitEntropy);
        
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "findAndUpdateBestFeatureToSplitOn");
        exit(1);
    }
}
/***********************************************************************/
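// Draws the random feature subset for a node (the random forest "mtry" step):
// keeps sampling random feature indices, skipping duplicates and discarded
// features, until optimumFeatureSubsetSize (or all remaining suitable) features
// are collected. A possible cheaper draw, assuming the suitable indices were
// materialized up front (a sketch, not what this code does): shuffle that list
// once and take its first currentFeatureSubsetSize entries, avoiding the
// repeated find() scans below.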
vector<int> DecisionTree::selectFeatureSubsetRandomly(vector<int> globalDiscardedFeatureIndices, vector<int> localDiscardedFeatureIndices){
    try {
        vector<int> featureSubsetIndices;
        
        vector<int> combinedDiscardedFeatureIndices;
        combinedDiscardedFeatureIndices.insert(combinedDiscardedFeatureIndices.end(), globalDiscardedFeatureIndices.begin(), globalDiscardedFeatureIndices.end());
        combinedDiscardedFeatureIndices.insert(combinedDiscardedFeatureIndices.end(), localDiscardedFeatureIndices.begin(), localDiscardedFeatureIndices.end());
        
        sort(combinedDiscardedFeatureIndices.begin(), combinedDiscardedFeatureIndices.end());
        
        int numberOfRemainingSuitableFeatures = (int)(numFeatures - combinedDiscardedFeatureIndices.size());
        int currentFeatureSubsetSize = numberOfRemainingSuitableFeatures < optimumFeatureSubsetSize ? numberOfRemainingSuitableFeatures : optimumFeatureSubsetSize;
        
        while (featureSubsetIndices.size() < currentFeatureSubsetSize) {
            
            if (m->control_pressed) { return featureSubsetIndices; }
            
            // TODO: optimize the rand() call here
            int randomIndex = rand() % numFeatures;
            vector<int>::iterator it = find(featureSubsetIndices.begin(), featureSubsetIndices.end(), randomIndex);
            if (it == featureSubsetIndices.end()){    // not already picked
                vector<int>::iterator it2 = find(combinedDiscardedFeatureIndices.begin(), combinedDiscardedFeatureIndices.end(), randomIndex);
                if (it2 == combinedDiscardedFeatureIndices.end()){  // not discarded either
                    featureSubsetIndices.push_back(randomIndex);
                }
            }
        }
        sort(featureSubsetIndices.begin(), featureSubsetIndices.end());
        
        //#ifdef DEBUG_LEVEL_3
        //    PRINT_VAR(featureSubsetIndices);
        //#endif
        
        return featureSubsetIndices;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "selectFeatureSubsetRandomly");
        exit(1);
    }
}
/***********************************************************************/

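// Debug helper: prints the tree one node per line, indented by generation;
// internal nodes show their split test (value < Xindex), leaves show the class
// they predict and how many samples they hold.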
// TODO: printTree() needs a check if correct
int DecisionTree::printTree(RFTreeNode* treeNode, string caption){
    try {
        if (treeNode == NULL) { return 0; }   // guard before the dereferences below
        string tabs = "";
        for (int i = 0; i < treeNode->getGeneration(); i++) { tabs += "   "; }
        //    for (int i = 0; i < treeNode->getGeneration() - 1; i++) { tabs += "|  "; }
        //    if (treeNode->getGeneration() != 0) { tabs += "|--"; }
        
        if (treeNode->checkIsLeaf() == false){
            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " ] ( " + toString(treeNode->getSplitFeatureValue()) + " < X" + toString(treeNode->getSplitFeatureIndex()) + " )\n");
            
            printTree(treeNode->getLeftChildNode(), "leftChild");
            printTree(treeNode->getRightChildNode(), "rightChild");
        }
        else {
            m->mothurOut(tabs + caption + " [ gen: " + toString(treeNode->getGeneration()) + " ] ( classified to: " + toString(treeNode->getOutputClass()) + ", samples: " + toString(treeNode->getNumSamples()) + " )\n");
        }
        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "printTree");
        exit(1);
    }
}
/***********************************************************************/
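// Post-order teardown: children are deleted before their parent so no node is
// accessed after it has been freed.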
void DecisionTree::deleteTreeNodesRecursively(RFTreeNode* treeNode) {
    try {
        if (treeNode == NULL) { return; }
        deleteTreeNodesRecursively(treeNode->leftChildNode);
        deleteTreeNodesRecursively(treeNode->rightChildNode);
        delete treeNode;
    }
    catch(exception& e) {
        m->errorOut(e, "DecisionTree", "deleteTreeNodesRecursively");
        exit(1);
    }
}
/***********************************************************************/