//
//  abstractdecisiontree.cpp
//  Mothur
//
//  Created by Sarah Westcott on 10/1/12.
//  Copyright (c) 2012 Schloss Lab. All rights reserved.
//

#include "abstractdecisiontree.hpp"

/**************************************************************************************************/

AbstractDecisionTree::AbstractDecisionTree(vector< vector<int> > baseDataSet,
                     vector<int> globalDiscardedFeatureIndices,
                     OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
                     string treeSplitCriterion) : baseDataSet(baseDataSet),
    numSamples((int)baseDataSet.size()),
    numFeatures((int)(baseDataSet[0].size() - 1)),
    numOutputClasses(0),
    rootNode(NULL),
    globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
    optimumFeatureSubsetSize(optimumFeatureSubsetSelector.getOptimumFeatureSubsetSize(numFeatures)),
    treeSplitCriterion(treeSplitCriterion) {

    try {
        // TODO: instead of calculating this for every DecisionTree,
        // calculate this once in the RandomForest class and pass the values in
        // (see the sketch after this constructor)
        m = MothurOut::getInstance();
        for (int i = 0; i < numSamples; i++) {
            if (m->control_pressed) { break; }
            // the last column of each sample holds its output class label
            int outcome = baseDataSet[i][numFeatures];
            vector<int>::iterator it = find(outputClasses.begin(), outputClasses.end(), outcome);
            if (it == outputClasses.end()){       // find() returns outputClasses.end() if the element is not found
                outputClasses.push_back(outcome);
                numOutputClasses++;
            }
        }

        if (m->debug) {
            //m->mothurOut("outputClasses = " + toStringVectorInt(outputClasses));
            m->mothurOut("numOutputClasses = " + toString(numOutputClasses) + '\n');
        }

    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "AbstractDecisionTree");
        exit(1);
    }
}
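/*
 Sketch only (re: the TODO above): the unique class labels could be computed once, e.g. in the
 RandomForest class, and handed to every tree instead of being rescanned here. The helper below
 is hypothetical -- it is not part of this file -- and assumes the same data layout (class label
 in the last column of each sample).

     vector<int> findOutputClasses(const vector< vector<int> >& dataSet, int numFeatures) {
         vector<int> classes;
         for (int i = 0; i < (int)dataSet.size(); i++) {
             int outcome = dataSet[i][numFeatures];
             if (find(classes.begin(), classes.end(), outcome) == classes.end()) {
                 classes.push_back(outcome);
             }
         }
         return classes;
     }
*/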
/**************************************************************************************************/
int AbstractDecisionTree::createBootStrappedSamples(){
    try {
        vector<bool> isInTrainingSamples(numSamples, false);

        // draw numSamples rows with replacement to build the bootstrapped training set
        for (int i = 0; i < numSamples; i++) {
            if (m->control_pressed) { return 0; }
            // TODO: optimize the rand() function call + double check if it's working properly
            // (see the sketch after this function)
            int randomIndex = rand() % numSamples;
            bootstrappedTrainingSamples.push_back(baseDataSet[randomIndex]);
            isInTrainingSamples[randomIndex] = true;
        }

        // rows never drawn above become the out-of-bag test set
        for (int i = 0; i < numSamples; i++) {
            if (m->control_pressed) { return 0; }
            if (isInTrainingSamples[i]){ bootstrappedTrainingSampleIndices.push_back(i); }
            else{
                bootstrappedTestSamples.push_back(baseDataSet[i]);
                bootstrappedTestSampleIndices.push_back(i);
            }
        }

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "createBootStrappedSamples");
        exit(1);
    }
}
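/*
 Sketch only (re: the rand() TODO above): rand() % numSamples is slightly biased whenever
 RAND_MAX + 1 is not a multiple of numSamples, and rand() itself is a weak generator. If a
 C++11 toolchain is available -- an assumption, not a given for this code base -- the draw
 could instead look like:

     #include <random>
     std::mt19937 rng(seed);                                   // seed is hypothetical; reuse mothur's seeding
     std::uniform_int_distribution<int> pick(0, numSamples - 1);
     int randomIndex = pick(rng);
*/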
/**************************************************************************************************/
int AbstractDecisionTree::getMinEntropyOfFeature(vector<int> featureVector, vector<int> outputVector, double& minEntropy, int& featureSplitValue, double& intrinsicValue){
    try {

        // pair up each feature value with its output class so they can be sorted together
        vector< vector<int> > featureOutputPair(featureVector.size(), vector<int>(2, 0));
        for (int i = 0; i < featureVector.size(); i++) {
            if (m->control_pressed) { return 0; }
            featureOutputPair[i][0] = featureVector[i];
            featureOutputPair[i][1] = outputVector[i];
        }
        // TODO: using the default behavior of sort(); the comparator should be specified explicitly
        // for added safety and compiler portability
        sort(featureOutputPair.begin(), featureOutputPair.end());
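        /*
         Sketch only (re: the TODO above): the default sort() here compares the inner vectors
         lexicographically (feature value first, output class second). The same ordering could be
         stated explicitly, e.g.:

             // at file scope:
             bool featureOutputPairLess(const vector<int>& a, const vector<int>& b) {
                 if (a[0] != b[0]) { return a[0] < b[0]; }
                 return a[1] < b[1];
             }

             // at the call site above:
             sort(featureOutputPair.begin(), featureOutputPair.end(), featureOutputPairLess);
        */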

        vector<int> splitPoints;
        vector<int> uniqueFeatureValues(1, featureOutputPair[0][0]);

        // record the first index at which each new (unique) feature value appears in the sorted pairs;
        // these indices are the candidate split points
        for (int i = 0; i < featureOutputPair.size(); i++) {
            if (m->control_pressed) { return 0; }
            int featureValue = featureOutputPair[i][0];
            vector<int>::iterator it = find(uniqueFeatureValues.begin(), uniqueFeatureValues.end(), featureValue);
            if (it == uniqueFeatureValues.end()){                 // NOT FOUND
                uniqueFeatureValues.push_back(featureValue);
                splitPoints.push_back(i);
            }
        }

        int bestSplitIndex = -1;
        if (splitPoints.size() == 0){
            // TODO: trying out C++'s infinity, don't know if this will work properly
            // TODO: in the caller of this function, check the value of minEntropy and compare it to inf
            // so that no wrong calculation is done (see the sketch after this function)
            minEntropy = numeric_limits<double>::infinity();                          // OUTPUT
            intrinsicValue = numeric_limits<double>::infinity();                      // OUTPUT
            featureSplitValue = -1;                                                   // OUTPUT
        }else{
            getBestSplitAndMinEntropy(featureOutputPair, splitPoints, minEntropy, bestSplitIndex, intrinsicValue);  // OUTPUT
            featureSplitValue = featureOutputPair[splitPoints[bestSplitIndex]][0];    // OUTPUT
        }

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "getMinEntropyOfFeature");
        exit(1);
    }
}
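/*
 Sketch only (re: the infinity TODO above): a caller receiving minEntropy by reference can guard
 against the "no split point" case before using the value. The surrounding caller code below is
 hypothetical; the equality check against infinity is the point of the sketch.

     double minEntropy, intrinsicValue;
     int featureSplitValue;
     getMinEntropyOfFeature(featureVector, outputVector, minEntropy, featureSplitValue, intrinsicValue);
     if (minEntropy == numeric_limits<double>::infinity()) {
         // the feature has a single unique value; skip it rather than computing gain from inf
     }
*/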
/**************************************************************************************************/
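// calcIntrinsicValue() returns the intrinsic value (split information) of a binary split:
//   IV = -( nLess/N * log2(nLess/N) + nGreater/N * log2(nGreater/N) )
// where nLess and nGreater are the sample counts on each side of the split and N is numSamples.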
double AbstractDecisionTree::calcIntrinsicValue(int numLessThanValueAtSplitPoint, int numGreaterThanValueAtSplitPoint, int numSamples) {
    try {
        double upperSplitEntropy = 0.0, lowerSplitEntropy = 0.0;
        if (numLessThanValueAtSplitPoint > 0) {
            upperSplitEntropy = numLessThanValueAtSplitPoint * log2((double) numLessThanValueAtSplitPoint / (double) numSamples);
        }

        if (numGreaterThanValueAtSplitPoint > 0) {
            lowerSplitEntropy = numGreaterThanValueAtSplitPoint * log2((double) numGreaterThanValueAtSplitPoint / (double) numSamples);
        }

        double intrinsicValue = - ((double)(upperSplitEntropy + lowerSplitEntropy) / (double)numSamples);
        return intrinsicValue;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "calcIntrinsicValue");
        exit(1);
    }
}
/**************************************************************************************************/
int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< vector<int> > featureOutputPairs, vector<int> splitPoints,
                               double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue){
    try {

        int numSamples = (int)featureOutputPairs.size();
        vector<double> entropies;
        vector<double> intrinsicValues;

        for (int i = 0; i < splitPoints.size(); i++) {
            if (m->control_pressed) { return 0; }
            int index = splitPoints[i];
            int valueAtSplitPoint = featureOutputPairs[index][0];
            int numLessThanValueAtSplitPoint = 0;
            int numGreaterThanValueAtSplitPoint = 0;

            for (int j = 0; j < featureOutputPairs.size(); j++) {
                if (m->control_pressed) { return 0; }
                vector<int> record = featureOutputPairs[j];
                if (record[0] < valueAtSplitPoint){ numLessThanValueAtSplitPoint++; }
                else{ numGreaterThanValueAtSplitPoint++; }
            }

            // total entropy of a split is the size-weighted average of the two sides' entropies
            double upperEntropyOfSplit = calcSplitEntropy(featureOutputPairs, index, numOutputClasses, true);
            double lowerEntropyOfSplit = calcSplitEntropy(featureOutputPairs, index, numOutputClasses, false);

            double totalEntropy = (numLessThanValueAtSplitPoint * upperEntropyOfSplit + numGreaterThanValueAtSplitPoint * lowerEntropyOfSplit) / (double)numSamples;
            double intrinsicValue = calcIntrinsicValue(numLessThanValueAtSplitPoint, numGreaterThanValueAtSplitPoint, numSamples);
            entropies.push_back(totalEntropy);
            intrinsicValues.push_back(intrinsicValue);

        }

        // set output values: pick the split point with the lowest weighted entropy
        vector<double>::iterator it = min_element(entropies.begin(), entropies.end());
        minEntropy = *it;                                                         // OUTPUT
        minEntropyIndex = (int)(it - entropies.begin());                          // OUTPUT
        relatedIntrinsicValue = intrinsicValues[minEntropyIndex];                 // OUTPUT

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "getBestSplitAndMinEntropy");
        exit(1);
    }
}
/**************************************************************************************************/

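// calcSplitEntropy() returns the Shannon entropy of the class labels on one side of the split:
//   H = -sum_i( p_i * log2(p_i) )
// where p_i is the fraction of samples on that side belonging to output class i. isUpperSplit
// selects the rows before (true) or from (false) splitIndex in the sorted feature/output pairs.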
double AbstractDecisionTree::calcSplitEntropy(vector< vector<int> > featureOutputPairs, int splitIndex, int numOutputClasses, bool isUpperSplit = true) {
    try {
        vector<int> classCounts(numOutputClasses, 0);

        if (isUpperSplit) {
            for (int i = 0; i < splitIndex; i++) {
                if (m->control_pressed) { return 0; }
                classCounts[featureOutputPairs[i][1]]++;
            }
        } else {
            for (int i = splitIndex; i < featureOutputPairs.size(); i++) {
                if (m->control_pressed) { return 0; }
                classCounts[featureOutputPairs[i][1]]++;
            }
        }

        int totalClassCounts = accumulate(classCounts.begin(), classCounts.end(), 0);

        double splitEntropy = 0.0;

        for (int i = 0; i < classCounts.size(); i++) {
            if (m->control_pressed) { return 0; }
            if (classCounts[i] == 0) { continue; }
            double probability = (double) classCounts[i] / (double) totalClassCounts;
            splitEntropy += -(probability * log2(probability));
        }

        return splitEntropy;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "calcSplitEntropy");
        exit(1);
    }
}

/**************************************************************************************************/

int AbstractDecisionTree::getSplitPopulation(RFTreeNode* node, vector< vector<int> >& leftChildSamples, vector< vector<int> >& rightChildSamples){
    try {
        // TODO: there is a possibility of optimization if we can recycle the samples in each node:
        // we just need pointers to the samples (i.e. vector<int>) and can use them everywhere instead of
        // creating the samples over and over again. The data would need to be const so that it is not
        // modified by the calling functions. Currently purgeTreeNodesDataRecursively() is used for the
        // same purpose, but it could be avoided altogether by re-using the same data across the classes
        // (see the sketch after this function).

        int splitFeatureGlobalIndex = node->getSplitFeatureIndex();

        for (int i = 0; i < node->getBootstrappedTrainingSamples().size(); i++) {
            if (m->control_pressed) { return 0; }
            vector<int> sample = node->getBootstrappedTrainingSamples()[i];
            if (sample[splitFeatureGlobalIndex] < node->getSplitFeatureValue()){ leftChildSamples.push_back(sample); }
            else{ rightChildSamples.push_back(sample); }
        }

        return 0;
    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "getSplitPopulation");
        exit(1);
    }
}
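/*
 Sketch only (re: the TODO above): instead of copying whole rows into leftChildSamples /
 rightChildSamples, child nodes could hold indices into a single shared (const) data set.
 The signature and the getBootstrappedSampleIndices() accessor below are hypothetical and
 would require matching changes elsewhere.

     void getSplitPopulationIndices(const RFTreeNode* node,
                                    const vector< vector<int> >& sharedDataSet,
                                    vector<int>& leftChildIndices,
                                    vector<int>& rightChildIndices) {
         int splitFeatureGlobalIndex = node->getSplitFeatureIndex();
         const vector<int>& sampleIndices = node->getBootstrappedSampleIndices();   // hypothetical accessor
         for (int i = 0; i < (int)sampleIndices.size(); i++) {
             int idx = sampleIndices[i];
             if (sharedDataSet[idx][splitFeatureGlobalIndex] < node->getSplitFeatureValue()) { leftChildIndices.push_back(idx); }
             else { rightChildIndices.push_back(idx); }
         }
     }
*/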
/**************************************************************************************************/
// TODO: verify the code of checkIfAlreadyClassified()
// TODO: use bootstrappedOutputVector for easier calculation instead of using getBootstrappedTrainingSamples()
// (see the sketch after this function)
bool AbstractDecisionTree::checkIfAlreadyClassified(RFTreeNode* treeNode, int& outputClass) {
    try {

        // collect the distinct output classes present in this node's training samples
        vector<int> tempOutputClasses;
        for (int i = 0; i < treeNode->getBootstrappedTrainingSamples().size(); i++) {
            if (m->control_pressed) { return 0; }
            int sampleOutputClass = treeNode->getBootstrappedTrainingSamples()[i][numFeatures];
            vector<int>::iterator it = find(tempOutputClasses.begin(), tempOutputClasses.end(), sampleOutputClass);
            if (it == tempOutputClasses.end()) {               // NOT FOUND
                tempOutputClasses.push_back(sampleOutputClass);
            }
        }

        // a node whose samples share a single class is already pure (classified)
        if (tempOutputClasses.size() < 2) { outputClass = tempOutputClasses[0]; return true; }
        else { outputClass = -1; return false; }

    }
    catch(exception& e) {
        m->errorOut(e, "AbstractDecisionTree", "checkIfAlreadyClassified");
        exit(1);
    }
}
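/*
 Sketch only (re: the bootstrappedOutputVector TODO above): if the node exposes its output labels
 directly (the accessor name getBootstrappedOutputVector() is assumed here, not confirmed in this
 file), the purity check reduces to scanning a flat vector<int> instead of indexing full samples:

     const vector<int>& outputs = treeNode->getBootstrappedOutputVector();
     bool isPure = true;
     for (int i = 1; i < (int)outputs.size(); i++) {
         if (outputs[i] != outputs[0]) { isPure = false; break; }
     }
     if (isPure) { outputClass = outputs[0]; return true; }
     outputClass = -1; return false;
*/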

/**************************************************************************************************/