]> git.donarmstrong.com Git - mothur.git/blob - calcsparcc.cpp
working on pam
[mothur.git] / calcsparcc.cpp
1 //
2 //  runSparcc.cpp
3 //  PDSSparCC
4 //
5 //  Created by Patrick Schloss on 10/31/12.
6 //  Copyright (c) 2012 University of Michigan. All rights reserved.
7 //
8
9 #include "calcsparcc.h"
10 #include "linearalgebra.h"
11
12 /**************************************************************************************************/
13
14 CalcSparcc::CalcSparcc(vector<vector<float> > sharedVector, int maxIterations, int numSamplings, string method){
15     try {
16         m = MothurOut::getInstance();
17         numOTUs = (int)sharedVector[0].size();
18         numGroups = (int)sharedVector.size();
19         normalizationMethod = method;
20         int numOTUs = (int)sharedVector[0].size();
21         
22         addPseudoCount(sharedVector);
23         
24         vector<vector<vector<float> > > allCorrelations(numSamplings);
25         
26         //    float cycClockStart = clock();
27         //    unsigned long long cycTimeStart = time(NULL);
28         
29         for(int i=0;i<numSamplings;i++){
30             if (m->control_pressed) { break; }
31             vector<float> logFractions =  getLogFractions(sharedVector, method);
32             getT_Matrix(logFractions);     //this step is slow...
33             getT_Vector();
34             getD_Matrix();
35             vector<float> basisVariances = getBasisVariances();     //this step is slow...
36             vector<vector<float> > correlation = getBasisCorrelations(basisVariances);
37             
38             excluded.resize(numOTUs);
39             for(int j=0;j<numOTUs;j++){ excluded[j].assign(numOTUs, 0); }
40             
41             float maxRho = 1;
42             int excludeRow = -1;
43             int excludeColumn = -1;
44             
45             int iter = 0;
46             while(maxRho > 0.10 && iter < maxIterations){
47                 maxRho = getExcludedPairs(correlation, excludeRow, excludeColumn);
48                 excludeValues(excludeRow, excludeColumn);
49                 vector<float> excludedBasisVariances = getBasisVariances();
50                 correlation = getBasisCorrelations(excludedBasisVariances);
51                 iter++;
52             }
53             allCorrelations[i] = correlation;
54         }
55         
56         if (!m->control_pressed) {
57             if(numSamplings > 1){
58                 getMedian(allCorrelations);
59             }
60             else{
61                 median = allCorrelations[0];
62             }
63         }
64 //    cout << median[0][3] << '\t' << median[0][6] << endl;
65     }
66     catch(exception& e) {
67         m->errorOut(e, "CalcSparcc", "CalcSparcc");
68         exit(1);
69     }
70 }
71     
72 /**************************************************************************************************/
73
74 void CalcSparcc::addPseudoCount(vector<vector<float> >& sharedVector){
75     try {
76         for(int i=0;i<numGroups;i++){   //iterate across the groups
77             if (m->control_pressed) { return; }
78             for(int j=0;j<numOTUs;j++){
79                 sharedVector[i][j] += 1;
80             }
81         }
82     }
83     catch(exception& e) {
84         m->errorOut(e, "CalcSparcc", "addPseudoCount");
85         exit(1);
86     }
87
88 }
89
90 /**************************************************************************************************/
91
92 vector<float> CalcSparcc::getLogFractions(vector<vector<float> > sharedVector, string method){   //dirichlet by default
93     try {
94         vector<float> logSharedFractions(numGroups * numOTUs, 0);
95         
96         if(method == "dirichlet"){
97             vector<float> alphas(numGroups);
98             for(int i=0;i<numGroups;i++){   //iterate across the groups
99                 if (m->control_pressed) { return logSharedFractions; }
100                 alphas = RNG.randomDirichlet(sharedVector[i]);
101                 
102                 for(int j=0;j<numOTUs;j++){
103                     logSharedFractions[i * numOTUs + j] = alphas[j];
104                 }
105             }
106         }
107         else if(method == "relabund"){
108             for(int i=0;i<numGroups;i++){
109                 if (m->control_pressed) { return logSharedFractions; }
110                 float total = 0.0;
111                 for(int j=0;j<numOTUs;j++){
112                     total += sharedVector[i][j];
113                 }
114                 for(int j=0;j<numOTUs;j++){
115                     logSharedFractions[i * numOTUs + j] = sharedVector[i][j]/total;
116                 }
117             }
118         }
119         
120         for(int i=0;i<logSharedFractions.size();i++){
121             logSharedFractions[i] = log(logSharedFractions[i]);
122         }
123         
124         return logSharedFractions;
125     }
126     catch(exception& e) {
127         m->errorOut(e, "CalcSparcc", "addPseudoCount");
128         exit(1);
129     }
130
131 }
132
133 /**************************************************************************************************/
134
135 void CalcSparcc::getT_Matrix(vector<float> sharedFractions){
136     try {
137         tMatrix.resize(numOTUs * numOTUs, 0);
138         
139         vector<float> diff(numGroups);
140         
141         for(int j1=0;j1<numOTUs;j1++){
142             for(int j2=0;j2<j1;j2++){
143                 if (m->control_pressed) { return; }
144                 float mean = 0.0;
145                 for(int i=0;i<numGroups;i++){
146                     diff[i] = sharedFractions[i * numOTUs + j1] - sharedFractions[i * numOTUs + j2];
147                     mean += diff[i];
148                 }
149                 mean /= float(numGroups);
150                 
151                 float variance = 0.0;
152                 for(int i=0;i<numGroups;i++){
153                     variance += (diff[i] - mean) * (diff[i] - mean);
154                 }
155                 variance /= (float)(numGroups-1);
156                 
157                 tMatrix[j1 * numOTUs + j2] = variance;
158                 tMatrix[j2 * numOTUs + j1] = tMatrix[j1 * numOTUs + j2];
159             }
160         }
161     }
162     catch(exception& e) {
163         m->errorOut(e, "CalcSparcc", "getT_Matrix");
164         exit(1);
165     }
166
167 }
168
169 /**************************************************************************************************/
170
171 void CalcSparcc::getT_Vector(){
172     try {
173         tVector.assign(numOTUs, 0);
174         
175         for(int j1=0;j1<numOTUs;j1++){
176             if (m->control_pressed) { return; }
177             for(int j2=0;j2<numOTUs;j2++){
178                 tVector[j1] += tMatrix[j1 * numOTUs + j2];
179             }
180         }
181     }
182     catch(exception& e) {
183         m->errorOut(e, "CalcSparcc", "getT_Vector");
184         exit(1);
185     }
186 }
187
188 /**************************************************************************************************/
189
190 void CalcSparcc::getD_Matrix(){
191     try {
192         float d = numOTUs - 1.0;
193         
194         dMatrix.resize(numOTUs);
195         for(int i=0;i<numOTUs;i++){
196             if (m->control_pressed) { return; }
197             dMatrix[i].resize(numOTUs, 1);
198             dMatrix[i][i] = d;
199         }
200     }
201     catch(exception& e) {
202         m->errorOut(e, "CalcSparcc", "getD_Matrix");
203         exit(1);
204     }
205 }
206
207 /**************************************************************************************************/
208
209 vector<float> CalcSparcc::getBasisVariances(){
210     try {
211         LinearAlgebra LA;
212         
213         vector<float> variances = LA.solveEquations(dMatrix, tVector);
214         
215         for(int i=0;i<variances.size();i++){
216             if (m->control_pressed) { return variances; }
217             if(variances[i] < 0){   variances[i] = 1e-4;    }
218         }
219         
220         return variances;
221     }
222     catch(exception& e) {
223         m->errorOut(e, "CalcSparcc", "getBasisVariances");
224         exit(1);
225     }
226 }
227
228 /**************************************************************************************************/
229
230 vector<vector<float> > CalcSparcc::getBasisCorrelations(vector<float> basisVariance){
231     try {
232         vector<vector<float> > rho(numOTUs);
233         for(int i=0;i<numOTUs;i++){ rho[i].resize(numOTUs, 0);    }
234         
235         for(int i=0;i<numOTUs;i++){
236             float var_i = basisVariance[i];
237             float sqrt_var_i = sqrt(var_i);
238             
239             rho[i][i] = 1.00;
240             
241             for(int j=0;j<i;j++){
242                 if (m->control_pressed) { return rho; }
243                 float var_j = basisVariance[j];
244                 
245                 rho[i][j] = (var_i + var_j - tMatrix[i * numOTUs + j]) / (2.0 * sqrt_var_i * sqrt(var_j));
246                 if(rho[i][j] > 1.0)         {   rho[i][j] = 1.0;   }
247                 else if(rho[i][j] < -1.0)   {   rho[i][j] = -1.0;  }
248                 
249                 rho[j][i] = rho[i][j];
250                 
251             }
252         }
253         
254         return rho;
255     }
256     catch(exception& e) {
257         m->errorOut(e, "CalcSparcc", "getBasisCorrelations");
258         exit(1);
259     }
260 }
261
262 /**************************************************************************************************/
263
264 float CalcSparcc::getExcludedPairs(vector<vector<float> > rho, int& maxRow, int& maxColumn){
265     try {
266         float maxRho = 0;
267         maxRow = -1;
268         maxColumn = -1;
269         
270         for(int i=0;i<numOTUs;i++){
271             
272             for(int j=0;j<i;j++){
273                 if (m->control_pressed) { return maxRho; }
274                 float tester = abs(rho[i][j]);
275                 
276                 if(tester > maxRho && excluded[i][j] != 1){
277                     maxRho = tester;
278                     maxRow = i;
279                     maxColumn = j;
280                 }
281             }
282             
283         }
284         
285         return maxRho;
286     }
287     catch(exception& e) {
288         m->errorOut(e, "CalcSparcc", "getExcludedPairs");
289         exit(1);
290     }
291 }
292
293 /**************************************************************************************************/
294
295 void CalcSparcc::excludeValues(int excludeRow, int excludeColumn){
296     try { 
297         tVector[excludeRow] -= tMatrix[excludeRow * numOTUs + excludeColumn];
298         tVector[excludeColumn] -= tMatrix[excludeRow * numOTUs + excludeColumn];
299         
300         dMatrix[excludeRow][excludeColumn] = 0;
301         dMatrix[excludeColumn][excludeRow] = 0;
302         dMatrix[excludeRow][excludeRow]--;
303         dMatrix[excludeColumn][excludeColumn]--;
304         
305         excluded[excludeRow][excludeColumn] = 1;
306         excluded[excludeColumn][excludeRow] = 1;
307     }
308     catch(exception& e) {
309         m->errorOut(e, "CalcSparcc", "excludeValues");
310         exit(1);
311     }
312 }
313
314 /**************************************************************************************************/
315
316 void CalcSparcc::getMedian(vector<vector<vector<float> > > allCorrelations){
317     try {
318         int numSamples = (int)allCorrelations.size();
319         median.resize(numOTUs);
320         for(int i=0;i<numOTUs;i++){ median[i].assign(numOTUs, 1);   }
321         
322         vector<float> hold(numSamples);
323         
324         for(int i=0;i<numOTUs;i++){
325             for(int j=0;j<i;j++){
326                 if (m->control_pressed) { return; }
327                 
328                 for(int k=0;k<numSamples;k++){
329                     hold[k] = allCorrelations[k][i][j];
330                 }
331                 
332                 sort(hold.begin(), hold.end());
333                 median[i][j] = hold[int(numSamples * 0.5)];
334                 median[j][i] = median[i][j];
335             }
336         }
337     }
338     catch(exception& e) {
339         m->errorOut(e, "CalcSparcc", "getMedian");
340         exit(1);
341     }
342 }
343
344 /**************************************************************************************************/