]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
rewrote parallelization of calcCI.cpp
[rsem.git] / Gibbs.cpp
1 #include<cstdio>
2 #include<cstring>
3 #include<cstdlib>
4 #include<cassert>
5 #include<fstream>
6 #include<sstream>
7 #include<vector>
8 #include<pthread.h>
9
10 #include "utils.h"
11 #include "my_assert.h"
12 #include "sampling.h"
13
14 #include "Model.h"
15 #include "SingleModel.h"
16 #include "SingleQModel.h"
17 #include "PairedEndModel.h"
18 #include "PairedEndQModel.h"
19
20 #include "Refs.h"
21 #include "GroupInfo.h"
22
23 using namespace std;
24
25 struct Params {
26         int no, nsamples;
27         FILE *fo;
28         engine_type *engine;
29         double *pme_c, *pve_c; //posterior mean and variance vectors on counts
30         double *pme_theta;
31 };
32
33
34 struct Item {
35         int sid;
36         double conprb;
37
38         Item(int sid, double conprb) {
39                 this->sid = sid;
40                 this->conprb = conprb;
41         }
42 };
43
44 int nThreads;
45
46 int model_type;
47 int m, M, N0, N1, nHits;
48 double totc;
49 int BURNIN, NSAMPLES, GAP;
50 char imdName[STRLEN], statName[STRLEN];
51 char thetaF[STRLEN], ofgF[STRLEN], groupF[STRLEN], refF[STRLEN], modelF[STRLEN];
52 char cvsF[STRLEN];
53
54 Refs refs;
55 GroupInfo gi;
56
57 vector<int> s;
58 vector<Item> hits;
59
60 vector<double> theta;
61
62 vector<double> pme_c, pve_c; //global posterior mean and variance vectors on counts
63 vector<double> pme_theta, eel;
64
65 bool var_opt;
66 bool quiet;
67
68 Params *paramsArray;
69 pthread_t *threads;
70 pthread_attr_t attr;
71 void *status;
72 int rc;
73
74 void load_data(char* reference_name, char* statName, char* imdName) {
75         ifstream fin;
76         string line;
77         int tmpVal;
78
79         //load reference file
80         sprintf(refF, "%s.seq", reference_name);
81         refs.loadRefs(refF, 1);
82         M = refs.getM();
83
84         //load groupF
85         sprintf(groupF, "%s.grp", reference_name);
86         gi.load(groupF);
87         m = gi.getm();
88
89         //load thetaF
90         sprintf(thetaF, "%s.theta",statName);
91         fin.open(thetaF);
92         general_assert(fin.is_open(), "Cannot open " + cstrtos(thetaF) + "!");
93         fin>>tmpVal;
94         general_assert(tmpVal == M + 1, "Number of transcripts is not consistent in " + cstrtos(refF) + " and " + cstrtos(thetaF) + "!");
95         theta.assign(M + 1, 0);
96         for (int i = 0; i <= M; i++) fin>>theta[i];
97         fin.close();
98
99         //load ofgF;
100         sprintf(ofgF, "%s.ofg", imdName);
101         fin.open(ofgF);
102         general_assert(fin.is_open(), "Cannot open " + cstrtos(ofgF) + "!");
103         fin>>tmpVal>>N0;
104         general_assert(tmpVal == M, "M in " + cstrtos(ofgF) + " is not consistent with " + cstrtos(refF) + "!");
105         getline(fin, line);
106
107         s.clear(); hits.clear();
108         s.push_back(0);
109         while (getline(fin, line)) {
110                 istringstream strin(line);
111                 int sid;
112                 double conprb;
113
114                 while (strin>>sid>>conprb) {
115                         hits.push_back(Item(sid, conprb));
116                 }
117                 s.push_back(hits.size());
118         }
119         fin.close();
120
121         N1 = s.size() - 1;
122         nHits = hits.size();
123
124         totc = N0 + N1 + (M + 1);
125
126         if (verbose) { printf("Loading Data is finished!\n"); }
127 }
128
129 // assign threads
130 void init() {
131         int quotient, left;
132         char outF[STRLEN];
133         char splitF[STRLEN];
134
135         quotient = NSAMPLES / nThreads;
136         left = NSAMPLES % nThreads;
137
138         sprintf(cvsF, "%s.countvectors", imdName);
139         paramsArray = new Params[nThreads];
140         threads = new pthread_t[nThreads];
141
142         for (int i = 0; i < nThreads; i++) {
143                 paramsArray[i].no = i;
144
145                 paramsArray[i].nsamples = quotient;
146                 if (i < left) paramsArray[i].nsamples++;
147
148                 sprintf(outF, "%s%d", cvsF, i);
149                 paramsArray[i].fo = fopen(outF, "w");
150
151                 paramsArray[i].engine = engineFactory::new_engine();
152                 paramsArray[i].pme_c = new double[M + 1];
153                 memset(paramsArray[i].pme_c, 0, sizeof(double) * (M + 1));
154                 paramsArray[i].pve_c = new double[M + 1];
155                 memset(paramsArray[i].pve_c, 0, sizeof(double) * (M + 1));
156                 paramsArray[i].pme_theta = new double[M + 1];
157                 memset(paramsArray[i].pme_theta, 0, sizeof(double) * (M + 1));
158         }
159
160         // output task splitting information
161         sprintf(splitF, "%s.split", imdName);
162         FILE *fo = fopen(splitF, "w");
163         fprintf(fo, "%d", nThreads);
164         for (int i = 0; i < nThreads; i++) fprintf(fo, " %d", paramsArray[i].nsamples);
165         fprintf(fo, "\n");
166         fclose(fo);
167
168         /* set thread attribute to be joinable */
169         pthread_attr_init(&attr);
170         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
171
172         if (verbose) { printf("Initialization finished!"); }
173 }
174
175 //sample theta from Dir(1)
176 void sampleTheta(engine_type& engine, vector<double>& theta) {
177         gamma_dist gm(1);
178         gamma_generator gmg(engine, gm);
179         double denom;
180
181         theta.assign(M + 1, 0);
182         denom = 0.0;
183         for (int i = 0; i <= M; i++) {
184                 theta[i] = gmg();
185                 denom += theta[i];
186         }
187         assert(denom > EPSILON);
188         for (int i = 0; i <= M; i++) theta[i] /= denom;
189 }
190
191 void writeCountVector(FILE* fo, vector<int>& counts) {
192         for (int i = 0; i < M; i++) {
193                 fprintf(fo, "%d ", counts[i]);
194         }
195         fprintf(fo, "%d\n", counts[M]);
196 }
197
198 void* Gibbs(void* arg) {
199         int len, fr, to;
200         int CHAINLEN;
201         Params *params = (Params*)arg;
202
203         vector<double> theta;
204         vector<int> z, counts;
205         vector<double> arr;
206
207         uniform01 rg(*params->engine);
208
209         // generate initial state
210         sampleTheta(*params->engine, theta);
211
212         z.assign(N1, 0);
213
214         counts.assign(M + 1, 1); // 1 pseudo count
215         counts[0] += N0;
216
217         for (int i = 0; i < N1; i++) {
218                 fr = s[i]; to = s[i + 1];
219                 len = to - fr;
220                 arr.assign(len, 0);
221                 for (int j = fr; j < to; j++) {
222                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
223                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
224                 }
225                 z[i] = hits[fr + sample(rg, arr, len)].sid;
226                 ++counts[z[i]];
227         }
228
229         // Gibbs sampling
230         CHAINLEN = 1 + (params->nsamples - 1) * GAP;
231         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
232
233                 for (int i = 0; i < N1; i++) {
234                         --counts[z[i]];
235                         fr = s[i]; to = s[i + 1]; len = to - fr;
236                         arr.assign(len, 0);
237                         for (int j = fr; j < to; j++) {
238                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
239                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
240                         }
241                         z[i] = hits[fr + sample(rg, arr, len)].sid;
242                         ++counts[z[i]];
243                 }
244
245                 if (ROUND > BURNIN) {
246                         if ((ROUND - BURNIN - 1) % GAP == 0) {
247                                 writeCountVector(params->fo, counts);
248                                 for (int i = 0; i <= M; i++) {
249                                         params->pme_c[i] += counts[i] - 1;
250                                         params->pve_c[i] += (counts[i] - 1) * (counts[i] - 1);
251                                         params->pme_theta[i] += counts[i] / totc;
252                                 }
253                         }
254                 }
255
256                 if (verbose && ROUND % 100 == 0) { printf("Thread %d, ROUND %d is finished!\n", params->no, ROUND); }
257         }
258
259         return NULL;
260 }
261
262 void release() {
263 //      char inpF[STRLEN], command[STRLEN];
264         string line;
265
266         /* destroy attribute */
267         pthread_attr_destroy(&attr);
268         delete[] threads;
269
270         pme_c.assign(M + 1, 0);
271         pve_c.assign(M + 1, 0);
272         pme_theta.assign(M + 1, 0);
273         for (int i = 0; i < nThreads; i++) {
274                 fclose(paramsArray[i].fo);
275                 delete paramsArray[i].engine;
276                 for (int j = 0; j <= M; j++) {
277                         pme_c[j] += paramsArray[i].pme_c[j];
278                         pve_c[j] += paramsArray[i].pve_c[j];
279                         pme_theta[j] += paramsArray[i].pme_theta[j];
280                 }
281                 delete[] paramsArray[i].pme_c;
282                 delete[] paramsArray[i].pve_c;
283                 delete[] paramsArray[i].pme_theta;
284         }
285         delete[] paramsArray;
286
287
288         for (int i = 0; i <= M; i++) {
289                 pme_c[i] /= NSAMPLES;
290                 pve_c[i] = (pve_c[i] - NSAMPLES * pme_c[i] * pme_c[i]) / (NSAMPLES - 1);
291                 pme_theta[i] /= NSAMPLES;
292         }
293
294         /*
295         // combine files
296         FILE *fo = fopen(cvsF, "a");
297         for (int i = 1; i < nThreads; i++) {
298                 sprintf(inpF, "%s%d", cvsF, i);
299                 ifstream fin(inpF);
300                 while (getline(fin, line)) {
301                         fprintf(fo, "%s\n", line.c_str());
302                 }
303                 fin.close();
304                 sprintf(command, "rm -f %s", inpF);
305                 int status = system(command);
306                 general_assert(status == 0, "Fail to delete file " + cstrtos(inpF) + "!");
307         }
308         fclose(fo);
309         */
310 }
311
312 template<class ModelType>
313 void calcExpectedEffectiveLengths(ModelType& model) {
314         int lb, ub, span;
315         double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = \sigma_{j=1}^{i}pdf[i]*(lb+i)
316   
317         model.getGLD().copyTo(pdf, cdf, lb, ub, span);
318         clen = new double[span + 1];
319         clen[0] = 0.0;
320         for (int i = 1; i <= span; i++) {
321                 clen[i] = clen[i - 1] + pdf[i] * (lb + i);
322         }
323
324         eel.assign(M + 1, 0.0);
325         for (int i = 1; i <= M; i++) {
326                 int totLen = refs.getRef(i).getTotLen();
327                 int fullLen = refs.getRef(i).getFullLen();
328                 int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
329                 int pos2 = max(min(totLen, ub) - lb, 0);
330
331                 if (pos2 == 0) { eel[i] = 0.0; continue; }
332     
333                 eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
334                 assert(eel[i] >= 0);
335                 if (eel[i] < MINEEL) { eel[i] = 0.0; }
336         }
337   
338         delete[] pdf;
339         delete[] cdf;
340         delete[] clen;
341 }
342
343 template<class ModelType>
344 void writeEstimatedParameters(char* modelF, char* imdName) {
345         ModelType model;
346         double denom;
347         char outF[STRLEN];
348         FILE *fo;
349
350         model.read(modelF);
351
352         calcExpectedEffectiveLengths<ModelType>(model);
353
354         denom = pme_theta[0];
355         for (int i = 1; i <= M; i++)
356           if (eel[i] < EPSILON) pme_theta[i] = 0.0;
357           else denom += pme_theta[i];
358
359         general_assert(denom >= EPSILON, "No Expected Effective Length is no less than " + ftos(MINEEL, 6) + "?!");
360
361         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
362
363         denom = 0.0;
364         double *mw = model.getMW();
365         for (int i = 0; i <= M; i++) {
366           pme_theta[i] = (mw[i] < EPSILON ? 0.0 : pme_theta[i] / mw[i]);
367           denom += pme_theta[i];
368         }
369         assert(denom >= EPSILON);
370         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
371
372         //calculate tau values
373         double *tau = new double[M + 1];
374         memset(tau, 0, sizeof(double) * (M + 1));
375
376         denom = 0.0;
377         for (int i = 1; i <= M; i++) 
378           if (eel[i] > EPSILON) {
379             tau[i] = pme_theta[i] / eel[i];
380             denom += tau[i];
381           }
382
383         general_assert(denom >= EPSILON, "No alignable reads?!");
384
385         for (int i = 1; i <= M; i++) {
386                 tau[i] /= denom;
387         }
388
389         //isoform level results
390         sprintf(outF, "%s.iso_res", imdName);
391         fo = fopen(outF, "a");
392         general_assert(fo != NULL, "Cannot open " + cstrtos(outF) + "!");
393
394         for (int i = 1; i <= M; i++)
395                 fprintf(fo, "%.2f%c", pme_c[i], (i < M ? '\t' : '\n'));
396         for (int i = 1; i <= M; i++)
397                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
398
399         fclose(fo);
400
401         //gene level results
402         sprintf(outF, "%s.gene_res", imdName);
403         fo = fopen(outF, "a");
404         general_assert(fo != NULL, "Cannot open " + cstrtos(outF) + "!");
405
406         for (int i = 0; i < m; i++) {
407                 double sumC = 0.0; //  sum of pme counts
408                 int b = gi.spAt(i), e = gi.spAt(i + 1);
409                 for (int j = b; j < e; j++) {
410                         sumC += pme_c[j];
411                 }
412                 fprintf(fo, "%.15g%c", sumC, (i < m - 1 ? '\t' : '\n'));
413         }
414         for (int i = 0; i < m; i++) {
415                 double sumT = 0.0; //  sum of tau values
416                 int b = gi.spAt(i), e = gi.spAt(i + 1);
417                 for (int j = b; j < e; j++) {
418                         sumT += tau[j];
419                 }
420                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
421         }
422         fclose(fo);
423
424         delete[] tau;
425
426         if (verbose) { printf("Gibbs based expression values are written!\n"); }
427 }
428
429 int main(int argc, char* argv[]) {
430         if (argc < 7) {
431                 printf("Usage: rsem-run-gibbs-multi reference_name sample_name sampleToken BURNIN NSAMPLES GAP [-p #Threads] [--var] [-q]\n");
432                 exit(-1);
433         }
434
435         BURNIN = atoi(argv[4]);
436         NSAMPLES = atoi(argv[5]);
437         GAP = atoi(argv[6]);
438         sprintf(imdName, "%s.temp/%s", argv[2], argv[3]);
439         sprintf(statName, "%s.stat/%s", argv[2], argv[3]);
440         load_data(argv[1], statName, imdName);
441
442         nThreads = 1;
443         var_opt = false;
444         quiet = false;
445
446         for (int i = 7; i < argc; i++) {
447                 if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
448                 if (!strcmp(argv[i], "--var")) var_opt = true;
449                 if (!strcmp(argv[i], "-q")) quiet = true;
450         }
451         verbose = !quiet;
452
453         assert(NSAMPLES > 1); // Otherwise, we cannot calculate posterior variance
454
455         if (nThreads > NSAMPLES) {
456                 nThreads = NSAMPLES;
457                 printf("Warning: Number of samples is less than number of threads! Change the number of threads to %d!\n", nThreads);
458         }
459
460         if (verbose) printf("Gibbs started!\n");
461
462         init();
463         for (int i = 0; i < nThreads; i++) {
464                 rc = pthread_create(&threads[i], &attr, Gibbs, (void*)(&paramsArray[i]));
465                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0)!");
466         }
467         for (int i = 0; i < nThreads; i++) {
468                 rc = pthread_join(threads[i], &status);
469                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0)!");
470         }
471         release();
472
473         if (verbose) printf("Gibbs finished!\n");
474
475         sprintf(modelF, "%s.model", statName);
476         FILE *fi = fopen(modelF, "r");
477         general_assert(fi != NULL, "Cannot open " + cstrtos(modelF) + "!");
478         assert(fscanf(fi, "%d", &model_type) == 1);
479         fclose(fi);
480
481         switch(model_type) {
482         case 0 : writeEstimatedParameters<SingleModel>(modelF, imdName); break;
483         case 1 : writeEstimatedParameters<SingleQModel>(modelF, imdName); break;
484         case 2 : writeEstimatedParameters<PairedEndModel>(modelF, imdName); break;
485         case 3 : writeEstimatedParameters<PairedEndQModel>(modelF, imdName); break;
486         }
487
488         if (var_opt) {
489                 char varF[STRLEN];
490
491                 sprintf(varF, "%s.var", statName);
492                 FILE *fo = fopen(varF, "w");
493                 general_assert(fo != NULL, "Cannot open " + cstrtos(varF) + "!");
494                 for (int i = 0; i < m; i++) {
495                         int b = gi.spAt(i), e = gi.spAt(i + 1), number_of_isoforms = e - b;
496                         for (int j = b; j < e; j++) {
497                                 fprintf(fo, "%s\t%d\t%.15g\t%.15g\n", refs.getRef(j).getName().c_str(), number_of_isoforms, pme_c[j], pve_c[j]);
498                         }
499                 }
500                 fclose(fo);
501         }
502
503         return 0;
504 }