//
//  abstractdecisiontree.cpp
//  Mothur
//
//  Created by Sarah Westcott on 10/1/12.
//  Copyright (c) 2012 Schloss Lab. All rights reserved.
//

#include "abstractdecisiontree.hpp"

/**************************************************************************************************/

AbstractDecisionTree::AbstractDecisionTree(vector<vector<int> >& baseDataSet,
                                           vector<int> globalDiscardedFeatureIndices,
                                           OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
                                           string treeSplitCriterion)

                    : baseDataSet(baseDataSet),
                      numSamples((int)baseDataSet.size()),
                      numFeatures((int)(baseDataSet[0].size() - 1)),
                      numOutputClasses(0),
                      rootNode(NULL),
                      nodeIdCount(0),
                      globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
                      optimumFeatureSubsetSize(optimumFeatureSubsetSelector.getOptimumFeatureSubsetSize(numFeatures)),
                      treeSplitCriterion(treeSplitCriterion) {

    try {
        // TODO: instead of calculating this for every DecisionTree,
        // calculate it once in the RandomForest class and pass the values in
        m = MothurOut::getInstance();
        for (int i = 0; i < numSamples; i++) {
            if (m->control_pressed) { break; }
            int outcome = baseDataSet[i][numFeatures];          // the class label lives in the last column
            vector<int>::iterator it = find(outputClasses.begin(), outputClasses.end(), outcome);
            if (it == outputClasses.end()) {                    // find() returns outputClasses.end() if the element is not found
                outputClasses.push_back(outcome);
                numOutputClasses++;
            }
        }

        if (m->debug) {
            //m->mothurOut("outputClasses = " + toStringVectorInt(outputClasses));
            m->mothurOut("numOutputClasses = " + toString(numOutputClasses) + '\n');
        }

    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "AbstractDecisionTree");
        exit(1);
    }
}
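/*
 * A worked illustration (hypothetical values, not mothur data) of the layout the
 * constructor assumes: each row of baseDataSet is one sample, with the feature
 * values first and the class label in the last column.
 *
 *     feature0  feature1  label
 *        5         0        1
 *        2         7        0
 *        9         3        1
 *
 * Here baseDataSet[0].size() is 3, so numFeatures = 2, and the loop above
 * collects the distinct labels {1, 0}, giving numOutputClasses = 2.
 */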
/**************************************************************************************************/
int AbstractDecisionTree::createBootStrappedSamples(){
    try {
        vector<bool> isInTrainingSamples(numSamples, false);

        for (int i = 0; i < numSamples; i++) {
            if (m->control_pressed) { return 0; }
            // TODO: optimize the rand() function call + double check that it's working properly
            int randomIndex = rand() % numSamples;
            bootstrappedTrainingSamples.push_back(baseDataSet[randomIndex]);
            isInTrainingSamples[randomIndex] = true;
        }

        for (int i = 0; i < numSamples; i++) {
            if (m->control_pressed) { return 0; }
            if (isInTrainingSamples[i]){ bootstrappedTrainingSampleIndices.push_back(i); }
            else{
                bootstrappedTestSamples.push_back(baseDataSet[i]);
                bootstrappedTestSampleIndices.push_back(i);
            }
        }

        // transpose the test samples so that testSampleFeatureVectors[i] holds feature i of every test sample
        if (bootstrappedTestSamples.empty()) { return 0; }      // every sample was drawn into the training set; nothing to transpose
        for (int i = 0; i < (int)bootstrappedTestSamples[0].size(); i++) {
            if (m->control_pressed) { return 0; }

            vector<int> tmpFeatureVector(bootstrappedTestSamples.size(), 0);
            for (int j = 0; j < (int)bootstrappedTestSamples.size(); j++) {
                if (m->control_pressed) { return 0; }

                tmpFeatureVector[j] = bootstrappedTestSamples[j][i];
            }
            testSampleFeatureVectors.push_back(tmpFeatureVector);
        }

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "createBootStrappedSamples");
        exit(1);
    }
}
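/*
 * Sampling with replacement leaves any given sample out of the training set with
 * probability (1 - 1/n)^n, roughly 36.8% for large n; those out-of-bag rows become
 * bootstrappedTestSamples. The TODO above asks for something better than rand();
 * a minimal sketch of one alternative, assuming a C++11 <random> toolchain (which
 * this codebase does not otherwise require):
 *
 *     #include <random>
 *
 *     static std::mt19937 rng(std::random_device{}());   // seeded Mersenne Twister
 *
 *     static int randomSampleIndex(int numSamples) {
 *         // uniform over [0, numSamples - 1], without the modulo bias of rand() % n
 *         std::uniform_int_distribution<int> dist(0, numSamples - 1);
 *         return dist(rng);
 *     }
 */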
/**************************************************************************************************/
int AbstractDecisionTree::getMinEntropyOfFeature(vector<int> featureVector,
                                                 vector<int> outputVector,
                                                 double& minEntropy,
                                                 int& featureSplitValue,
                                                 double& intrinsicValue){
    try {

        vector< pair<int, int> > featureOutputPair(featureVector.size(), pair<int, int>(0, 0));

        for (int i = 0; i < (int)featureVector.size(); i++) {
            if (m->control_pressed) { return 0; }

            featureOutputPair[i].first = featureVector[i];
            featureOutputPair[i].second = outputVector[i];
        }
        // sort by feature value, with an explicit comparator for compiler portability
        IntPairVectorSorter intPairVectorSorter;
        sort(featureOutputPair.begin(), featureOutputPair.end(), intPairVectorSorter);

        vector<int> splitPoints;
        vector<int> uniqueFeatureValues(1, featureOutputPair[0].first);

        for (int i = 0; i < (int)featureOutputPair.size(); i++) {

            if (m->control_pressed) { return 0; }
            int featureValue = featureOutputPair[i].first;

            vector<int>::iterator it = find(uniqueFeatureValues.begin(), uniqueFeatureValues.end(), featureValue);
            if (it == uniqueFeatureValues.end()){                 // NOT FOUND
                uniqueFeatureValues.push_back(featureValue);
                splitPoints.push_back(i);
            }
        }

        int bestSplitIndex = -1;
        if (splitPoints.size() == 0){
            // TODO: trying out C++'s infinity; don't know if this will work properly.
            // TODO: in the caller of this function, check the value of minEntropy and compare it
            // to infinity so that no wrong calculation is done
            minEntropy = numeric_limits<double>::infinity();                          // OUTPUT
            intrinsicValue = numeric_limits<double>::infinity();                      // OUTPUT
            featureSplitValue = -1;                                                   // OUTPUT
        }else{
            getBestSplitAndMinEntropy(featureOutputPair, splitPoints, minEntropy, bestSplitIndex, intrinsicValue);  // OUTPUT
            featureSplitValue = featureOutputPair[splitPoints[bestSplitIndex]].first;    // OUTPUT
        }

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "getMinEntropyOfFeature");
        exit(1);
    }
}
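/*
 * A small worked example (hypothetical values) of how split points are found.
 * After sorting by feature value, a split point is recorded at the first index
 * of every feature value except the smallest:
 *
 *     sorted feature values:  1  1  3  3  3  7
 *     indices:                0  1  2  3  4  5
 *     uniqueFeatureValues:    {1, 3, 7}
 *     splitPoints:            {2, 5}
 *
 * If every value is identical, splitPoints stays empty and the feature is
 * reported as unsplittable via minEntropy = infinity.
 */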
/**************************************************************************************************/
double AbstractDecisionTree::calcIntrinsicValue(int numLessThanValueAtSplitPoint, int numGreaterThanValueAtSplitPoint, int numSamples) {
    try {
        double upperSplitEntropy = 0.0, lowerSplitEntropy = 0.0;
        if (numLessThanValueAtSplitPoint > 0) {
            upperSplitEntropy = numLessThanValueAtSplitPoint * log2((double) numLessThanValueAtSplitPoint / (double) numSamples);
        }

        if (numGreaterThanValueAtSplitPoint > 0) {
            lowerSplitEntropy = numGreaterThanValueAtSplitPoint * log2((double) numGreaterThanValueAtSplitPoint / (double) numSamples);
        }

        double intrinsicValue = - ((double)(upperSplitEntropy + lowerSplitEntropy) / (double)numSamples);
        return intrinsicValue;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "calcIntrinsicValue");
        exit(1);
    }
}
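/*
 * calcIntrinsicValue computes the C4.5-style intrinsic value of a binary split,
 *
 *     IV = - (n_l/N) * log2(n_l/N) - (n_g/N) * log2(n_g/N)
 *
 * where n_l and n_g are the sizes of the two partitions and N = n_l + n_g.
 * The code folds the 1/N factor out of the sum, which is algebraically the same.
 * For example, with n_l = 2, n_g = 6, N = 8:
 *
 *     IV = -(0.25 * log2(0.25) + 0.75 * log2(0.75))
 *        = -(0.25 * -2.0 + 0.75 * -0.415)
 *        ≈ 0.811
 */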
/**************************************************************************************************/

int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< pair<int, int> > featureOutputPairs, vector<int> splitPoints,
                                                    double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue){
    try {

        int numSamples = (int)featureOutputPairs.size();
        vector<double> entropies;
        vector<double> intrinsicValues;

        for (int i = 0; i < (int)splitPoints.size(); i++) {
            if (m->control_pressed) { return 0; }
            int index = splitPoints[i];
            int valueAtSplitPoint = featureOutputPairs[index].first;

            int numLessThanValueAtSplitPoint = 0;
            int numGreaterThanValueAtSplitPoint = 0;

            for (int j = 0; j < (int)featureOutputPairs.size(); j++) {
                if (m->control_pressed) { return 0; }
                pair<int, int> record = featureOutputPairs[j];
                if (record.first < valueAtSplitPoint){ numLessThanValueAtSplitPoint++; }
                else{ numGreaterThanValueAtSplitPoint++; }
            }

            double upperEntropyOfSplit = calcSplitEntropy(featureOutputPairs, index, numOutputClasses, true);
            double lowerEntropyOfSplit = calcSplitEntropy(featureOutputPairs, index, numOutputClasses, false);

            double totalEntropy = (numLessThanValueAtSplitPoint * upperEntropyOfSplit + numGreaterThanValueAtSplitPoint * lowerEntropyOfSplit) / (double)numSamples;
            double intrinsicValue = calcIntrinsicValue(numLessThanValueAtSplitPoint, numGreaterThanValueAtSplitPoint, numSamples);
            entropies.push_back(totalEntropy);
            intrinsicValues.push_back(intrinsicValue);

        }

        // set output values
        vector<double>::iterator it = min_element(entropies.begin(), entropies.end());
        minEntropy = *it;                                                         // OUTPUT
        minEntropyIndex = (int)(it - entropies.begin());                          // OUTPUT
        relatedIntrinsicValue = intrinsicValues[minEntropyIndex];                 // OUTPUT

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "getBestSplitAndMinEntropy");
        exit(1);
    }
}
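/*
 * Each candidate split is scored by the size-weighted entropy of its two partitions,
 *
 *     E_split = (n_l * E_upper + n_g * E_lower) / N
 *
 * and the split with the smallest E_split wins. The matching intrinsic value is
 * returned alongside it so that a caller implementing C4.5's gain ratio
 * (information gain divided by intrinsic value) has both pieces available.
 */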
/**************************************************************************************************/

double AbstractDecisionTree::calcSplitEntropy(vector< pair<int, int> > featureOutputPairs, int splitIndex, int numOutputClasses, bool isUpperSplit = true) {
    try {
        vector<int> classCounts(numOutputClasses, 0);

        if (isUpperSplit) {
            for (int i = 0; i < splitIndex; i++) {
                if (m->control_pressed) { return 0; }
                classCounts[featureOutputPairs[i].second]++;
            }
        } else {
            for (int i = splitIndex; i < (int)featureOutputPairs.size(); i++) {
                if (m->control_pressed) { return 0; }
                classCounts[featureOutputPairs[i].second]++;
            }
        }

        int totalClassCounts = accumulate(classCounts.begin(), classCounts.end(), 0);

        double splitEntropy = 0.0;

        for (int i = 0; i < (int)classCounts.size(); i++) {
            if (m->control_pressed) { return 0; }
            if (classCounts[i] == 0) { continue; }
            double probability = (double) classCounts[i] / (double) totalClassCounts;
            splitEntropy += -(probability * log2(probability));
        }

        return splitEntropy;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "calcSplitEntropy");
        exit(1);
    }
}
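/*
 * This is the Shannon entropy H = -sum_c p_c * log2(p_c) of the class labels on
 * one side of the split. Two quick checks of the arithmetic:
 *
 *     classCounts = {4, 4}  ->  H = -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1.0 bit
 *     classCounts = {8, 0}  ->  H = 0.0 (a pure partition; the zero count is skipped)
 */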

/**************************************************************************************************/

int AbstractDecisionTree::getSplitPopulation(RFTreeNode* node, vector< vector<int> >& leftChildSamples, vector< vector<int> >& rightChildSamples){
    try {
        // TODO: there is a possibility of optimization here if we can recycle the samples in each node:
        // we just need pointers to the samples, i.e. vector<int>, used everywhere instead of creating
        // the samples over and over again. We would need to make the data const so it is not modified
        // by the calling functions. Currently purgeTreeNodesDataRecursively() serves the same purpose,
        // but it could be avoided altogether by reusing the same data across the classes.

        int splitFeatureGlobalIndex = node->getSplitFeatureIndex();

        for (int i = 0; i < (int)node->getBootstrappedTrainingSamples().size(); i++) {
            if (m->control_pressed) { return 0; }
            vector<int> sample = node->getBootstrappedTrainingSamples()[i];

            if (sample[splitFeatureGlobalIndex] < node->getSplitFeatureValue()) { leftChildSamples.push_back(sample); }
            else { rightChildSamples.push_back(sample); }
        }

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "getSplitPopulation");
        exit(1);
    }
}
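/*
 * The partition rule, on a hypothetical node splitting feature 1 at value 5:
 *
 *     sample {3, 2, 0}  ->  2 < 5   ->  leftChildSamples
 *     sample {9, 7, 1}  ->  7 >= 5  ->  rightChildSamples
 *
 * so the left child receives the samples strictly below the split value and the
 * right child receives the rest, mirroring the record.first < valueAtSplitPoint
 * test used when the split was scored.
 */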
/**************************************************************************************************/
// TODO: verify checkIfAlreadyClassified() code
// TODO: use bootstrappedOutputVector for easier calculation instead of using getBootstrappedTrainingSamples()
bool AbstractDecisionTree::checkIfAlreadyClassified(RFTreeNode* treeNode, int& outputClass) {
    try {

        vector<int> tempOutputClasses;
        for (int i = 0; i < (int)treeNode->getBootstrappedTrainingSamples().size(); i++) {
            if (m->control_pressed) { return false; }
            int sampleOutputClass = treeNode->getBootstrappedTrainingSamples()[i][numFeatures];
            vector<int>::iterator it = find(tempOutputClasses.begin(), tempOutputClasses.end(), sampleOutputClass);
            if (it == tempOutputClasses.end()) {               // NOT FOUND
                tempOutputClasses.push_back(sampleOutputClass);
            }
        }

        // exactly one distinct label means the node is pure and can become a leaf;
        // checking for one (rather than "< 2") also guards against indexing into an empty vector
        if (tempOutputClasses.size() == 1) { outputClass = tempOutputClasses[0]; return true; }
        else { outputClass = -1; return false; }

    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "checkIfAlreadyClassified");
        exit(1);
    }
}
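/*
 * A pure-node check in miniature (hypothetical labels): a node whose bootstrapped
 * samples carry labels {1, 1, 1} yields tempOutputClasses = {1}, so the function
 * returns true with outputClass = 1 and the node can become a leaf; labels
 * {1, 0, 1} yield {1, 0}, so it returns false and the node must be split further.
 */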

/**************************************************************************************************/