]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
add --no-bam-output option to not generate output bam files
[rsem.git] / Gibbs.cpp
1 #include<cstdio>
2 #include<cstring>
3 #include<cstdlib>
4 #include<cassert>
5 #include<fstream>
6 #include<sstream>
7 #include<vector>
8 #include<pthread.h>
9
10 #include "utils.h"
11 #include "my_assert.h"
12 #include "sampling.h"
13
14 #include "Model.h"
15 #include "SingleModel.h"
16 #include "SingleQModel.h"
17 #include "PairedEndModel.h"
18 #include "PairedEndQModel.h"
19
20 #include "Refs.h"
21 #include "GroupInfo.h"
22
23 using namespace std;
24
25 struct Params {
26         int no, nsamples;
27         FILE *fo;
28         engine_type *engine;
29         double *pme_c, *pve_c; //posterior mean and variance vectors on counts
30         double *pme_theta;
31 };
32
33
34 struct Item {
35         int sid;
36         double conprb;
37
38         Item(int sid, double conprb) {
39                 this->sid = sid;
40                 this->conprb = conprb;
41         }
42 };
43
44 int nThreads;
45
46 int model_type;
47 int m, M, N0, N1, nHits;
48 double totc;
49 int BURNIN, NSAMPLES, GAP;
50 char imdName[STRLEN], statName[STRLEN];
51 char thetaF[STRLEN], ofgF[STRLEN], groupF[STRLEN], refF[STRLEN], modelF[STRLEN];
52 char cvsF[STRLEN];
53
54 Refs refs;
55 GroupInfo gi;
56
57 vector<int> s;
58 vector<Item> hits;
59
60 vector<double> theta;
61
62 vector<double> pme_c, pve_c; //global posterior mean and variance vectors on counts
63 vector<double> pme_theta, eel;
64
65 bool var_opt;
66 bool quiet;
67
68 Params *paramsArray;
69 pthread_t *threads;
70 pthread_attr_t attr;
71 void *status;
72 int rc;
73
74 void load_data(char* reference_name, char* statName, char* imdName) {
75         ifstream fin;
76         string line;
77         int tmpVal;
78
79         //load reference file
80         sprintf(refF, "%s.seq", reference_name);
81         refs.loadRefs(refF, 1);
82         M = refs.getM();
83
84         //load groupF
85         sprintf(groupF, "%s.grp", reference_name);
86         gi.load(groupF);
87         m = gi.getm();
88
89         //load thetaF
90         sprintf(thetaF, "%s.theta",statName);
91         fin.open(thetaF);
92         general_assert(fin.is_open(), "Cannot open " + cstrtos(thetaF) + "!");
93         fin>>tmpVal;
94         general_assert(tmpVal == M + 1, "Number of transcripts is not consistent in " + cstrtos(refF) + " and " + cstrtos(thetaF) + "!");
95         theta.assign(M + 1, 0);
96         for (int i = 0; i <= M; i++) fin>>theta[i];
97         fin.close();
98
99         //load ofgF;
100         sprintf(ofgF, "%s.ofg", imdName);
101         fin.open(ofgF);
102         general_assert(fin.is_open(), "Cannot open " + cstrtos(ofgF) + "!");
103         fin>>tmpVal>>N0;
104         general_assert(tmpVal == M, "M in " + cstrtos(ofgF) + " is not consistent with " + cstrtos(refF) + "!");
105         getline(fin, line);
106
107         s.clear(); hits.clear();
108         s.push_back(0);
109         while (getline(fin, line)) {
110                 istringstream strin(line);
111                 int sid;
112                 double conprb;
113
114                 while (strin>>sid>>conprb) {
115                         hits.push_back(Item(sid, conprb));
116                 }
117                 s.push_back(hits.size());
118         }
119         fin.close();
120
121         N1 = s.size() - 1;
122         nHits = hits.size();
123
124         totc = N0 + N1 + (M + 1);
125
126         if (verbose) { printf("Loading Data is finished!\n"); }
127 }
128
129 // assign threads
130 void init() {
131         int quotient, left;
132         char outF[STRLEN];
133
134         quotient = NSAMPLES / nThreads;
135         left = NSAMPLES % nThreads;
136
137         sprintf(cvsF, "%s.countvectors", imdName);
138         paramsArray = new Params[nThreads];
139         threads = new pthread_t[nThreads];
140
141         for (int i = 0; i < nThreads; i++) {
142                 paramsArray[i].no = i;
143
144                 paramsArray[i].nsamples = quotient;
145                 if (i < left) paramsArray[i].nsamples++;
146
147                 sprintf(outF, "%s%d", cvsF, i);
148                 paramsArray[i].fo = fopen(outF, "w");
149
150                 paramsArray[i].engine = engineFactory::new_engine();
151                 paramsArray[i].pme_c = new double[M + 1];
152                 memset(paramsArray[i].pme_c, 0, sizeof(double) * (M + 1));
153                 paramsArray[i].pve_c = new double[M + 1];
154                 memset(paramsArray[i].pve_c, 0, sizeof(double) * (M + 1));
155                 paramsArray[i].pme_theta = new double[M + 1];
156                 memset(paramsArray[i].pme_theta, 0, sizeof(double) * (M + 1));
157         }
158
159         /* set thread attribute to be joinable */
160         pthread_attr_init(&attr);
161         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
162
163         if (verbose) { printf("Initialization finished!"); }
164 }
165
166 //sample theta from Dir(1)
167 void sampleTheta(engine_type& engine, vector<double>& theta) {
168         gamma_dist gm(1);
169         gamma_generator gmg(engine, gm);
170         double denom;
171
172         theta.assign(M + 1, 0);
173         denom = 0.0;
174         for (int i = 0; i <= M; i++) {
175                 theta[i] = gmg();
176                 denom += theta[i];
177         }
178         assert(denom > EPSILON);
179         for (int i = 0; i <= M; i++) theta[i] /= denom;
180 }
181
182 void writeCountVector(FILE* fo, vector<int>& counts) {
183         for (int i = 0; i < M; i++) {
184                 fprintf(fo, "%d ", counts[i]);
185         }
186         fprintf(fo, "%d\n", counts[M]);
187 }
188
189 void* Gibbs(void* arg) {
190         int len, fr, to;
191         int CHAINLEN;
192         Params *params = (Params*)arg;
193
194         vector<double> theta;
195         vector<int> z, counts;
196         vector<double> arr;
197
198         uniform01 rg(*params->engine);
199
200         // generate initial state
201         sampleTheta(*params->engine, theta);
202
203         z.assign(N1, 0);
204
205         counts.assign(M + 1, 1); // 1 pseudo count
206         counts[0] += N0;
207
208         for (int i = 0; i < N1; i++) {
209                 fr = s[i]; to = s[i + 1];
210                 len = to - fr;
211                 arr.assign(len, 0);
212                 for (int j = fr; j < to; j++) {
213                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
214                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
215                 }
216                 z[i] = hits[fr + sample(rg, arr, len)].sid;
217                 ++counts[z[i]];
218         }
219
220         // Gibbs sampling
221         CHAINLEN = 1 + (params->nsamples - 1) * GAP;
222         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
223
224                 for (int i = 0; i < N1; i++) {
225                         --counts[z[i]];
226                         fr = s[i]; to = s[i + 1]; len = to - fr;
227                         arr.assign(len, 0);
228                         for (int j = fr; j < to; j++) {
229                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
230                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
231                         }
232                         z[i] = hits[fr + sample(rg, arr, len)].sid;
233                         ++counts[z[i]];
234                 }
235
236                 if (ROUND > BURNIN) {
237                         if ((ROUND - BURNIN - 1) % GAP == 0) {
238                                 writeCountVector(params->fo, counts);
239                                 for (int i = 0; i <= M; i++) {
240                                         params->pme_c[i] += counts[i] - 1;
241                                         params->pve_c[i] += (counts[i] - 1) * (counts[i] - 1);
242                                         params->pme_theta[i] += counts[i] / totc;
243                                 }
244                         }
245                 }
246
247                 if (verbose && ROUND % 100 == 0) { printf("Thread %d, ROUND %d is finished!\n", params->no, ROUND); }
248         }
249
250         return NULL;
251 }
252
253 void release() {
254 //      char inpF[STRLEN], command[STRLEN];
255         string line;
256
257         /* destroy attribute */
258         pthread_attr_destroy(&attr);
259         delete[] threads;
260
261         pme_c.assign(M + 1, 0);
262         pve_c.assign(M + 1, 0);
263         pme_theta.assign(M + 1, 0);
264         for (int i = 0; i < nThreads; i++) {
265                 fclose(paramsArray[i].fo);
266                 delete paramsArray[i].engine;
267                 for (int j = 0; j <= M; j++) {
268                         pme_c[j] += paramsArray[i].pme_c[j];
269                         pve_c[j] += paramsArray[i].pve_c[j];
270                         pme_theta[j] += paramsArray[i].pme_theta[j];
271                 }
272                 delete[] paramsArray[i].pme_c;
273                 delete[] paramsArray[i].pve_c;
274                 delete[] paramsArray[i].pme_theta;
275         }
276         delete[] paramsArray;
277
278
279         for (int i = 0; i <= M; i++) {
280                 pme_c[i] /= NSAMPLES;
281                 pve_c[i] = (pve_c[i] - NSAMPLES * pme_c[i] * pme_c[i]) / (NSAMPLES - 1);
282                 pme_theta[i] /= NSAMPLES;
283         }
284
285         /*
286         // combine files
287         FILE *fo = fopen(cvsF, "a");
288         for (int i = 1; i < nThreads; i++) {
289                 sprintf(inpF, "%s%d", cvsF, i);
290                 ifstream fin(inpF);
291                 while (getline(fin, line)) {
292                         fprintf(fo, "%s\n", line.c_str());
293                 }
294                 fin.close();
295                 sprintf(command, "rm -f %s", inpF);
296                 int status = system(command);
297                 general_assert(status == 0, "Fail to delete file " + cstrtos(inpF) + "!");
298         }
299         fclose(fo);
300         */
301 }
302
303 template<class ModelType>
304 void calcExpectedEffectiveLengths(ModelType& model) {
305         int lb, ub, span;
306         double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = \sigma_{j=1}^{i}pdf[i]*(lb+i)
307   
308         model.getGLD().copyTo(pdf, cdf, lb, ub, span);
309         clen = new double[span + 1];
310         clen[0] = 0.0;
311         for (int i = 1; i <= span; i++) {
312                 clen[i] = clen[i - 1] + pdf[i] * (lb + i);
313         }
314
315         eel.assign(M + 1, 0.0);
316         for (int i = 1; i <= M; i++) {
317                 int totLen = refs.getRef(i).getTotLen();
318                 int fullLen = refs.getRef(i).getFullLen();
319                 int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
320                 int pos2 = max(min(totLen, ub) - lb, 0);
321
322                 if (pos2 == 0) { eel[i] = 0.0; continue; }
323     
324                 eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
325                 assert(eel[i] >= 0);
326                 if (eel[i] < MINEEL) { eel[i] = 0.0; }
327         }
328   
329         delete[] pdf;
330         delete[] cdf;
331         delete[] clen;
332 }
333
334 template<class ModelType>
335 void writeEstimatedParameters(char* modelF, char* imdName) {
336         ModelType model;
337         double denom;
338         char outF[STRLEN];
339         FILE *fo;
340
341         model.read(modelF);
342
343         calcExpectedEffectiveLengths<ModelType>(model);
344
345         denom = pme_theta[0];
346         for (int i = 1; i <= M; i++)
347           if (eel[i] < EPSILON) pme_theta[i] = 0.0;
348           else denom += pme_theta[i];
349
350         general_assert(denom >= EPSILON, "No Expected Effective Length is no less than " + ftos(MINEEL, 6) + "?!");
351
352         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
353
354         denom = 0.0;
355         double *mw = model.getMW();
356         for (int i = 0; i <= M; i++) {
357           pme_theta[i] = (mw[i] < EPSILON ? 0.0 : pme_theta[i] / mw[i]);
358           denom += pme_theta[i];
359         }
360         assert(denom >= EPSILON);
361         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
362
363         //calculate tau values
364         double *tau = new double[M + 1];
365         memset(tau, 0, sizeof(double) * (M + 1));
366
367         denom = 0.0;
368         for (int i = 1; i <= M; i++) 
369           if (eel[i] > EPSILON) {
370             tau[i] = pme_theta[i] / eel[i];
371             denom += tau[i];
372           }
373
374         general_assert(denom >= EPSILON, "No alignable reads?!");
375
376         for (int i = 1; i <= M; i++) {
377                 tau[i] /= denom;
378         }
379
380         //isoform level results
381         sprintf(outF, "%s.iso_res", imdName);
382         fo = fopen(outF, "a");
383         general_assert(fo != NULL, "Cannot open " + cstrtos(outF) + "!");
384
385         for (int i = 1; i <= M; i++)
386                 fprintf(fo, "%.2f%c", pme_c[i], (i < M ? '\t' : '\n'));
387         for (int i = 1; i <= M; i++)
388                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
389
390         fclose(fo);
391
392         //gene level results
393         sprintf(outF, "%s.gene_res", imdName);
394         fo = fopen(outF, "a");
395         general_assert(fo != NULL, "Cannot open " + cstrtos(outF) + "!");
396
397         for (int i = 0; i < m; i++) {
398                 double sumC = 0.0; //  sum of pme counts
399                 int b = gi.spAt(i), e = gi.spAt(i + 1);
400                 for (int j = b; j < e; j++) {
401                         sumC += pme_c[j];
402                 }
403                 fprintf(fo, "%.15g%c", sumC, (i < m - 1 ? '\t' : '\n'));
404         }
405         for (int i = 0; i < m; i++) {
406                 double sumT = 0.0; //  sum of tau values
407                 int b = gi.spAt(i), e = gi.spAt(i + 1);
408                 for (int j = b; j < e; j++) {
409                         sumT += tau[j];
410                 }
411                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
412         }
413         fclose(fo);
414
415         delete[] tau;
416
417         if (verbose) { printf("Gibbs based expression values are written!\n"); }
418 }
419
420 int main(int argc, char* argv[]) {
421         if (argc < 7) {
422                 printf("Usage: rsem-run-gibbs-multi reference_name sample_name sampleToken BURNIN NSAMPLES GAP [-p #Threads] [--var] [-q]\n");
423                 exit(-1);
424         }
425
426         BURNIN = atoi(argv[4]);
427         NSAMPLES = atoi(argv[5]);
428         GAP = atoi(argv[6]);
429         sprintf(imdName, "%s.temp/%s", argv[2], argv[3]);
430         sprintf(statName, "%s.stat/%s", argv[2], argv[3]);
431         load_data(argv[1], statName, imdName);
432
433         nThreads = 1;
434         var_opt = false;
435         quiet = false;
436
437         for (int i = 7; i < argc; i++) {
438                 if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
439                 if (!strcmp(argv[i], "--var")) var_opt = true;
440                 if (!strcmp(argv[i], "-q")) quiet = true;
441         }
442         verbose = !quiet;
443
444         assert(NSAMPLES > 1); // Otherwise, we cannot calculate posterior variance
445
446         if (nThreads > NSAMPLES) {
447                 nThreads = NSAMPLES;
448                 printf("Warning: Number of samples is less than number of threads! Change the number of threads to %d!\n", nThreads);
449         }
450
451         if (verbose) printf("Gibbs started!\n");
452
453         init();
454         for (int i = 0; i < nThreads; i++) {
455                 rc = pthread_create(&threads[i], &attr, Gibbs, (void*)(&paramsArray[i]));
456                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0)!");
457         }
458         for (int i = 0; i < nThreads; i++) {
459                 rc = pthread_join(threads[i], &status);
460                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0)!");
461         }
462         release();
463
464         if (verbose) printf("Gibbs finished!\n");
465
466         sprintf(modelF, "%s.model", statName);
467         FILE *fi = fopen(modelF, "r");
468         general_assert(fi != NULL, "Cannot open " + cstrtos(modelF) + "!");
469         assert(fscanf(fi, "%d", &model_type) == 1);
470         fclose(fi);
471
472         switch(model_type) {
473         case 0 : writeEstimatedParameters<SingleModel>(modelF, imdName); break;
474         case 1 : writeEstimatedParameters<SingleQModel>(modelF, imdName); break;
475         case 2 : writeEstimatedParameters<PairedEndModel>(modelF, imdName); break;
476         case 3 : writeEstimatedParameters<PairedEndQModel>(modelF, imdName); break;
477         }
478
479         if (var_opt) {
480                 char varF[STRLEN];
481
482                 sprintf(varF, "%s.var", statName);
483                 FILE *fo = fopen(varF, "w");
484                 general_assert(fo != NULL, "Cannot open " + cstrtos(varF) + "!");
485                 for (int i = 0; i < m; i++) {
486                         int b = gi.spAt(i), e = gi.spAt(i + 1), number_of_isoforms = e - b;
487                         for (int j = b; j < e; j++) {
488                                 fprintf(fo, "%s\t%d\t%.15g\t%.15g\n", refs.getRef(j).getName().c_str(), number_of_isoforms, pme_c[j], pve_c[j]);
489                         }
490                 }
491                 fclose(fo);
492         }
493
494         return 0;
495 }