5 // Created by Patrick Schloss on 10/31/12.
6 // Copyright (c) 2012 University of Michigan. All rights reserved.
9 #include "calcsparcc.h"
10 #include "linearalgebra.h"
12 /**************************************************************************************************/
14 CalcSparcc::CalcSparcc(vector<vector<float> > sharedVector, int maxIterations, int numSamplings, string method){
16 m = MothurOut::getInstance();
17 numOTUs = (int)sharedVector[0].size();
18 numGroups = (int)sharedVector.size();
19 normalizationMethod = method;
20 int numOTUs = (int)sharedVector[0].size();
22 addPseudoCount(sharedVector);
24 vector<vector<vector<float> > > allCorrelations(numSamplings);
26 // float cycClockStart = clock();
27 // unsigned long long cycTimeStart = time(NULL);
29 for(int i=0;i<numSamplings;i++){
30 if (m->control_pressed) { break; }
31 vector<float> logFractions = getLogFractions(sharedVector, method);
32 getT_Matrix(logFractions); //this step is slow...
35 vector<float> basisVariances = getBasisVariances(); //this step is slow...
36 vector<vector<float> > correlation = getBasisCorrelations(basisVariances);
38 excluded.resize(numOTUs);
39 for(int j=0;j<numOTUs;j++){ excluded[j].assign(numOTUs, 0); }
43 int excludeColumn = -1;
46 while(maxRho > 0.10 && iter < maxIterations){
47 maxRho = getExcludedPairs(correlation, excludeRow, excludeColumn);
48 excludeValues(excludeRow, excludeColumn);
49 vector<float> excludedBasisVariances = getBasisVariances();
50 correlation = getBasisCorrelations(excludedBasisVariances);
53 allCorrelations[i] = correlation;
56 if (!m->control_pressed) {
58 getMedian(allCorrelations);
61 median = allCorrelations[0];
64 // cout << median[0][3] << '\t' << median[0][6] << endl;
67 m->errorOut(e, "CalcSparcc", "CalcSparcc");
72 /**************************************************************************************************/
74 void CalcSparcc::addPseudoCount(vector<vector<float> >& sharedVector){
76 for(int i=0;i<numGroups;i++){ //iterate across the groups
77 if (m->control_pressed) { return; }
78 for(int j=0;j<numOTUs;j++){
79 sharedVector[i][j] += 1;
84 m->errorOut(e, "CalcSparcc", "addPseudoCount");
90 /**************************************************************************************************/
92 vector<float> CalcSparcc::getLogFractions(vector<vector<float> > sharedVector, string method){ //dirichlet by default
94 vector<float> logSharedFractions(numGroups * numOTUs, 0);
96 if(method == "dirichlet"){
97 vector<float> alphas(numGroups);
98 for(int i=0;i<numGroups;i++){ //iterate across the groups
99 if (m->control_pressed) { return logSharedFractions; }
100 alphas = RNG.randomDirichlet(sharedVector[i]);
102 for(int j=0;j<numOTUs;j++){
103 logSharedFractions[i * numOTUs + j] = alphas[j];
107 else if(method == "relabund"){
108 for(int i=0;i<numGroups;i++){
109 if (m->control_pressed) { return logSharedFractions; }
111 for(int j=0;j<numOTUs;j++){
112 total += sharedVector[i][j];
114 for(int j=0;j<numOTUs;j++){
115 logSharedFractions[i * numOTUs + j] = sharedVector[i][j]/total;
120 for(int i=0;i<logSharedFractions.size();i++){
121 logSharedFractions[i] = log(logSharedFractions[i]);
124 return logSharedFractions;
126 catch(exception& e) {
127 m->errorOut(e, "CalcSparcc", "addPseudoCount");
133 /**************************************************************************************************/
135 void CalcSparcc::getT_Matrix(vector<float> sharedFractions){
137 tMatrix.resize(numOTUs * numOTUs, 0);
139 vector<float> diff(numGroups);
141 for(int j1=0;j1<numOTUs;j1++){
142 for(int j2=0;j2<j1;j2++){
143 if (m->control_pressed) { return; }
145 for(int i=0;i<numGroups;i++){
146 diff[i] = sharedFractions[i * numOTUs + j1] - sharedFractions[i * numOTUs + j2];
149 mean /= float(numGroups);
151 float variance = 0.0;
152 for(int i=0;i<numGroups;i++){
153 variance += (diff[i] - mean) * (diff[i] - mean);
155 variance /= (float)(numGroups-1);
157 tMatrix[j1 * numOTUs + j2] = variance;
158 tMatrix[j2 * numOTUs + j1] = tMatrix[j1 * numOTUs + j2];
162 catch(exception& e) {
163 m->errorOut(e, "CalcSparcc", "getT_Matrix");
169 /**************************************************************************************************/
171 void CalcSparcc::getT_Vector(){
173 tVector.assign(numOTUs, 0);
175 for(int j1=0;j1<numOTUs;j1++){
176 if (m->control_pressed) { return; }
177 for(int j2=0;j2<numOTUs;j2++){
178 tVector[j1] += tMatrix[j1 * numOTUs + j2];
182 catch(exception& e) {
183 m->errorOut(e, "CalcSparcc", "getT_Vector");
188 /**************************************************************************************************/
190 void CalcSparcc::getD_Matrix(){
192 float d = numOTUs - 1.0;
194 dMatrix.resize(numOTUs);
195 for(int i=0;i<numOTUs;i++){
196 if (m->control_pressed) { return; }
197 dMatrix[i].resize(numOTUs, 1);
201 catch(exception& e) {
202 m->errorOut(e, "CalcSparcc", "getD_Matrix");
207 /**************************************************************************************************/
209 vector<float> CalcSparcc::getBasisVariances(){
213 vector<float> variances = LA.solveEquations(dMatrix, tVector);
215 for(int i=0;i<variances.size();i++){
216 if (m->control_pressed) { return variances; }
217 if(variances[i] < 0){ variances[i] = 1e-4; }
222 catch(exception& e) {
223 m->errorOut(e, "CalcSparcc", "getBasisVariances");
228 /**************************************************************************************************/
230 vector<vector<float> > CalcSparcc::getBasisCorrelations(vector<float> basisVariance){
232 vector<vector<float> > rho(numOTUs);
233 for(int i=0;i<numOTUs;i++){ rho[i].resize(numOTUs, 0); }
235 for(int i=0;i<numOTUs;i++){
236 float var_i = basisVariance[i];
237 float sqrt_var_i = sqrt(var_i);
241 for(int j=0;j<i;j++){
242 if (m->control_pressed) { return rho; }
243 float var_j = basisVariance[j];
245 rho[i][j] = (var_i + var_j - tMatrix[i * numOTUs + j]) / (2.0 * sqrt_var_i * sqrt(var_j));
246 if(rho[i][j] > 1.0) { rho[i][j] = 1.0; }
247 else if(rho[i][j] < -1.0) { rho[i][j] = -1.0; }
249 rho[j][i] = rho[i][j];
256 catch(exception& e) {
257 m->errorOut(e, "CalcSparcc", "getBasisCorrelations");
262 /**************************************************************************************************/
264 float CalcSparcc::getExcludedPairs(vector<vector<float> > rho, int& maxRow, int& maxColumn){
270 for(int i=0;i<numOTUs;i++){
272 for(int j=0;j<i;j++){
273 if (m->control_pressed) { return maxRho; }
274 float tester = abs(rho[i][j]);
276 if(tester > maxRho && excluded[i][j] != 1){
287 catch(exception& e) {
288 m->errorOut(e, "CalcSparcc", "getExcludedPairs");
293 /**************************************************************************************************/
295 void CalcSparcc::excludeValues(int excludeRow, int excludeColumn){
297 tVector[excludeRow] -= tMatrix[excludeRow * numOTUs + excludeColumn];
298 tVector[excludeColumn] -= tMatrix[excludeRow * numOTUs + excludeColumn];
300 dMatrix[excludeRow][excludeColumn] = 0;
301 dMatrix[excludeColumn][excludeRow] = 0;
302 dMatrix[excludeRow][excludeRow]--;
303 dMatrix[excludeColumn][excludeColumn]--;
305 excluded[excludeRow][excludeColumn] = 1;
306 excluded[excludeColumn][excludeRow] = 1;
308 catch(exception& e) {
309 m->errorOut(e, "CalcSparcc", "excludeValues");
314 /**************************************************************************************************/
316 void CalcSparcc::getMedian(vector<vector<vector<float> > > allCorrelations){
318 int numSamples = (int)allCorrelations.size();
319 median.resize(numOTUs);
320 for(int i=0;i<numOTUs;i++){ median[i].assign(numOTUs, 1); }
322 vector<float> hold(numSamples);
324 for(int i=0;i<numOTUs;i++){
325 for(int j=0;j<i;j++){
326 if (m->control_pressed) { return; }
328 for(int k=0;k<numSamples;k++){
329 hold[k] = allCorrelations[k][i][j];
332 sort(hold.begin(), hold.end());
333 median[i][j] = hold[int(numSamples * 0.5)];
334 median[j][i] = median[i][j];
338 catch(exception& e) {
339 m->errorOut(e, "CalcSparcc", "getMedian");
344 /**************************************************************************************************/