// abstractdecisiontree.cpp
// Created by Sarah Westcott on 10/1/12.
// Copyright (c) 2012 Schloss Lab. All rights reserved.
9 #include "abstractdecisiontree.hpp"
11 /**************************************************************************************************/
13 AbstractDecisionTree::AbstractDecisionTree(vector<vector<int> >baseDataSet,
14 vector<int> globalDiscardedFeatureIndices,
15 OptimumFeatureSubsetSelector optimumFeatureSubsetSelector,
16 string treeSplitCriterion) : baseDataSet(baseDataSet),
17 numSamples((int)baseDataSet.size()),
18 numFeatures((int)(baseDataSet[0].size() - 1)),
21 globalDiscardedFeatureIndices(globalDiscardedFeatureIndices),
22 optimumFeatureSubsetSize(optimumFeatureSubsetSelector.getOptimumFeatureSubsetSize(numFeatures)),
23 treeSplitCriterion(treeSplitCriterion) {
26 // TODO: istead of calculating this for every DecisionTree
27 // clacualte this once in the RandomForest class and pass the values
28 m = MothurOut::getInstance();
29 for (int i = 0; i < numSamples; i++) {
30 if (m->control_pressed) { break; }
31 int outcome = baseDataSet[i][numFeatures];
32 vector<int>::iterator it = find(outputClasses.begin(), outputClasses.end(), outcome);
33 if (it == outputClasses.end()){ // find() will return classes.end() if the element is not found
34 outputClasses.push_back(outcome);
40 //m->mothurOut("outputClasses = " + toStringVectorInt(outputClasses));
41 m->mothurOut("numOutputClasses = " + toString(numOutputClasses) + '\n');
46 m->errorOut(e, "AbstractDecisionTree", "AbstractDecisionTree");
50 /**************************************************************************************************/
51 int AbstractDecisionTree::createBootStrappedSamples(){
53 vector<bool> isInTrainingSamples(numSamples, false);
55 for (int i = 0; i < numSamples; i++) {
56 if (m->control_pressed) { return 0; }
57 // TODO: optimize the rand() function call + double check if it's working properly
58 int randomIndex = rand() % numSamples;
59 bootstrappedTrainingSamples.push_back(baseDataSet[randomIndex]);
60 isInTrainingSamples[randomIndex] = true;
63 for (int i = 0; i < numSamples; i++) {
64 if (m->control_pressed) { return 0; }
65 if (isInTrainingSamples[i]){ bootstrappedTrainingSampleIndices.push_back(i); }
67 bootstrappedTestSamples.push_back(baseDataSet[i]);
68 bootstrappedTestSampleIndices.push_back(i);
75 m->errorOut(e, "AbstractDecisionTree", "createBootStrappedSamples");
79 /**************************************************************************************************/
80 int AbstractDecisionTree::getMinEntropyOfFeature(vector<int> featureVector, vector<int> outputVector, double& minEntropy, int& featureSplitValue, double& intrinsicValue){
83 vector< vector<int> > featureOutputPair(featureVector.size(), vector<int>(2, 0));
84 for (int i = 0; i < featureVector.size(); i++) {
85 if (m->control_pressed) { return 0; }
86 featureOutputPair[i][0] = featureVector[i];
87 featureOutputPair[i][1] = outputVector[i];
89 // TODO: using default behavior to sort(), need to specify the comparator for added safety and compiler portability
90 sort(featureOutputPair.begin(), featureOutputPair.end());
93 vector<int> splitPoints;
94 vector<int> uniqueFeatureValues(1, featureOutputPair[0][0]);
96 for (int i = 0; i < featureOutputPair.size(); i++) {
97 if (m->control_pressed) { return 0; }
98 int featureValue = featureOutputPair[i][0];
99 vector<int>::iterator it = find(uniqueFeatureValues.begin(), uniqueFeatureValues.end(), featureValue);
100 if (it == uniqueFeatureValues.end()){ // NOT FOUND
101 uniqueFeatureValues.push_back(featureValue);
102 splitPoints.push_back(i);
108 int bestSplitIndex = -1;
109 if (splitPoints.size() == 0){
110 // TODO: trying out C++'s infitinity, don't know if this will work properly
111 // TODO: check the caller function of this function, there check the value if minEntropy and comapre to inf
112 // so that no wrong calculation is done
113 minEntropy = numeric_limits<double>::infinity(); // OUTPUT
114 intrinsicValue = numeric_limits<double>::infinity(); // OUTPUT
115 featureSplitValue = -1; // OUTPUT
117 getBestSplitAndMinEntropy(featureOutputPair, splitPoints, minEntropy, bestSplitIndex, intrinsicValue); // OUTPUT
118 featureSplitValue = featureOutputPair[splitPoints[bestSplitIndex]][0]; // OUTPUT
123 catch(exception& e) {
124 m->errorOut(e, "AbstractDecisionTree", "getMinEntropyOfFeature");
128 /**************************************************************************************************/
129 double AbstractDecisionTree::calcIntrinsicValue(int numLessThanValueAtSplitPoint, int numGreaterThanValueAtSplitPoint, int numSamples) {
131 double upperSplitEntropy = 0.0, lowerSplitEntropy = 0.0;
132 if (numLessThanValueAtSplitPoint > 0) {
133 upperSplitEntropy = numLessThanValueAtSplitPoint * log2((double) numLessThanValueAtSplitPoint / (double) numSamples);
136 if (numGreaterThanValueAtSplitPoint > 0) {
137 lowerSplitEntropy = numGreaterThanValueAtSplitPoint * log2((double) numGreaterThanValueAtSplitPoint / (double) numSamples);
140 double intrinsicValue = - ((double)(upperSplitEntropy + lowerSplitEntropy) / (double)numSamples);
141 return intrinsicValue;
143 catch(exception& e) {
144 m->errorOut(e, "AbstractDecisionTree", "calcIntrinsicValue");
148 /**************************************************************************************************/
149 int AbstractDecisionTree::getBestSplitAndMinEntropy(vector< vector<int> > featureOutputPairs, vector<int> splitPoints,
150 double& minEntropy, int& minEntropyIndex, double& relatedIntrinsicValue){
153 int numSamples = (int)featureOutputPairs.size();
154 vector<double> entropies;
155 vector<double> intrinsicValues;
157 for (int i = 0; i < splitPoints.size(); i++) {
158 if (m->control_pressed) { return 0; }
159 int index = splitPoints[i];
160 int valueAtSplitPoint = featureOutputPairs[index][0];
161 int numLessThanValueAtSplitPoint = 0;
162 int numGreaterThanValueAtSplitPoint = 0;
164 for (int j = 0; j < featureOutputPairs.size(); j++) {
165 if (m->control_pressed) { return 0; }
166 vector<int> record = featureOutputPairs[j];
167 if (record[0] < valueAtSplitPoint){ numLessThanValueAtSplitPoint++; }
168 else{ numGreaterThanValueAtSplitPoint++; }
171 double upperEntropyOfSplit = calcSplitEntropy(featureOutputPairs, index, numOutputClasses, true);
172 double lowerEntropyOfSplit = calcSplitEntropy(featureOutputPairs, index, numOutputClasses, false);
174 double totalEntropy = (numLessThanValueAtSplitPoint * upperEntropyOfSplit + numGreaterThanValueAtSplitPoint * lowerEntropyOfSplit) / (double)numSamples;
175 double intrinsicValue = calcIntrinsicValue(numLessThanValueAtSplitPoint, numGreaterThanValueAtSplitPoint, numSamples);
176 entropies.push_back(totalEntropy);
177 intrinsicValues.push_back(intrinsicValue);
182 vector<double>::iterator it = min_element(entropies.begin(), entropies.end());
183 minEntropy = *it; // OUTPUT
184 minEntropyIndex = (int)(it - entropies.begin()); // OUTPUT
185 relatedIntrinsicValue = intrinsicValues[minEntropyIndex]; // OUTPUT
189 catch(exception& e) {
190 m->errorOut(e, "AbstractDecisionTree", "getBestSplitAndMinEntropy");
194 /**************************************************************************************************/
196 double AbstractDecisionTree::calcSplitEntropy(vector< vector<int> > featureOutputPairs, int splitIndex, int numOutputClasses, bool isUpperSplit = true) {
198 vector<int> classCounts(numOutputClasses, 0);
201 for (int i = 0; i < splitIndex; i++) {
202 if (m->control_pressed) { return 0; }
203 classCounts[featureOutputPairs[i][1]]++;
206 for (int i = splitIndex; i < featureOutputPairs.size(); i++) {
207 if (m->control_pressed) { return 0; }
208 classCounts[featureOutputPairs[i][1]]++;
212 int totalClassCounts = accumulate(classCounts.begin(), classCounts.end(), 0);
214 double splitEntropy = 0.0;
216 for (int i = 0; i < classCounts.size(); i++) {
217 if (m->control_pressed) { return 0; }
218 if (classCounts[i] == 0) { continue; }
219 double probability = (double) classCounts[i] / (double) totalClassCounts;
220 splitEntropy += -(probability * log2(probability));
225 catch(exception& e) {
226 m->errorOut(e, "AbstractDecisionTree", "calcSplitEntropy");
231 /**************************************************************************************************/
233 int AbstractDecisionTree::getSplitPopulation(RFTreeNode* node, vector< vector<int> >& leftChildSamples, vector< vector<int> >& rightChildSamples){
235 // TODO: there is a possibility of optimization if we can recycle the samples in each nodes
236 // we just need to pointers to the samples i.e. vector<int> and use it everywhere and not create the sample
237 // sample over and over again
238 // we need to make this const so that it is not modified by all the function calling
239 // currently purgeTreeNodesDataRecursively() is used for the same purpose, but this can be avoided altogher
240 // if re-using the same data over the classes
242 int splitFeatureGlobalIndex = node->getSplitFeatureIndex();
244 for (int i = 0; i < node->getBootstrappedTrainingSamples().size(); i++) {
245 if (m->control_pressed) { return 0; }
246 vector<int> sample = node->getBootstrappedTrainingSamples()[i];
247 if (m->control_pressed) { return 0; }
248 if (sample[splitFeatureGlobalIndex] < node->getSplitFeatureValue()){ leftChildSamples.push_back(sample); }
249 else{ rightChildSamples.push_back(sample); }
254 catch(exception& e) {
255 m->errorOut(e, "AbstractDecisionTree", "getSplitPopulation");
259 /**************************************************************************************************/
// TODO: verify the checkIfAlreadyClassified() implementation
// TODO: use bootstrappedOutputVector for easier calculation instead of using getBootstrappedTrainingSamples()
262 bool AbstractDecisionTree::checkIfAlreadyClassified(RFTreeNode* treeNode, int& outputClass) {
265 vector<int> tempOutputClasses;
266 for (int i = 0; i < treeNode->getBootstrappedTrainingSamples().size(); i++) {
267 if (m->control_pressed) { return 0; }
268 int sampleOutputClass = treeNode->getBootstrappedTrainingSamples()[i][numFeatures];
269 vector<int>::iterator it = find(tempOutputClasses.begin(), tempOutputClasses.end(), sampleOutputClass);
270 if (it == tempOutputClasses.end()) { // NOT FOUND
271 tempOutputClasses.push_back(sampleOutputClass);
275 if (tempOutputClasses.size() < 2) { outputClass = tempOutputClasses[0]; return true; }
276 else { outputClass = -1; return false; }
279 catch(exception& e) {
280 m->errorOut(e, "AbstractDecisionTree", "checkIfAlreadyClassified");
285 /**************************************************************************************************/