]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
Install the latest version of EBSeq from Bioconductor and if fails, try to install...
[rsem.git] / Gibbs.cpp
1 #include<cstdio>
2 #include<cstring>
3 #include<cstdlib>
4 #include<cassert>
5 #include<fstream>
6 #include<sstream>
7 #include<vector>
8 #include<pthread.h>
9
10 #include "utils.h"
11 #include "my_assert.h"
12 #include "sampling.h"
13
14 #include "Model.h"
15 #include "SingleModel.h"
16 #include "SingleQModel.h"
17 #include "PairedEndModel.h"
18 #include "PairedEndQModel.h"
19
20 #include "Refs.h"
21 #include "GroupInfo.h"
22
23 using namespace std;
24
25 struct Params {
26         int no, nsamples;
27         FILE *fo;
28         engine_type *engine;
29         double *pme_c, *pve_c; //posterior mean and variance vectors on counts
30   double *pme_tpm, *pme_fpkm;
31 };
32
33
34 struct Item {
35         int sid;
36         double conprb;
37
38         Item(int sid, double conprb) {
39                 this->sid = sid;
40                 this->conprb = conprb;
41         }
42 };
43
44 int nThreads;
45
46 int model_type;
47 int m, M;
48 READ_INT_TYPE N0, N1;
49 HIT_INT_TYPE nHits;
50 double totc;
51 int BURNIN, NSAMPLES, GAP;
52 char imdName[STRLEN], statName[STRLEN];
53 char thetaF[STRLEN], ofgF[STRLEN], groupF[STRLEN], refF[STRLEN], modelF[STRLEN];
54 char cvsF[STRLEN];
55
56 Refs refs;
57 GroupInfo gi;
58
59 vector<HIT_INT_TYPE> s;
60 vector<Item> hits;
61
62 vector<double> eel;
63 double *mw;
64
65 vector<double> pme_c, pve_c; //global posterior mean and variance vectors on counts
66 vector<double> pme_tpm, pme_fpkm;
67
68 bool var_opt;
69 bool quiet;
70
71 Params *paramsArray;
72 pthread_t *threads;
73 pthread_attr_t attr;
74 int rc;
75
76 void load_data(char* reference_name, char* statName, char* imdName) {
77         ifstream fin;
78         string line;
79         int tmpVal;
80
81         //load reference file
82         sprintf(refF, "%s.seq", reference_name);
83         refs.loadRefs(refF, 1);
84         M = refs.getM();
85
86         //load groupF
87         sprintf(groupF, "%s.grp", reference_name);
88         gi.load(groupF);
89         m = gi.getm();
90
91         //load ofgF;
92         sprintf(ofgF, "%s.ofg", imdName);
93         fin.open(ofgF);
94         general_assert(fin.is_open(), "Cannot open " + cstrtos(ofgF) + "!");
95         fin>>tmpVal>>N0;
96         general_assert(tmpVal == M, "M in " + cstrtos(ofgF) + " is not consistent with " + cstrtos(refF) + "!");
97         getline(fin, line);
98
99         s.clear(); hits.clear();
100         s.push_back(0);
101         while (getline(fin, line)) {
102                 istringstream strin(line);
103                 int sid;
104                 double conprb;
105
106                 while (strin>>sid>>conprb) {
107                         hits.push_back(Item(sid, conprb));
108                 }
109                 s.push_back(hits.size());
110         }
111         fin.close();
112
113         N1 = s.size() - 1;
114         nHits = hits.size();
115
116         totc = N0 + N1 + (M + 1);
117
118         if (verbose) { printf("Loading Data is finished!\n"); }
119 }
120
121 template<class ModelType>
122 void calcExpectedEffectiveLengths(ModelType& model) {
123         int lb, ub, span;
124         double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = sigma_{j=1}^{i}pdf[i]*(lb+i)
125   
126         model.getGLD().copyTo(pdf, cdf, lb, ub, span);
127         clen = new double[span + 1];
128         clen[0] = 0.0;
129         for (int i = 1; i <= span; i++) {
130                 clen[i] = clen[i - 1] + pdf[i] * (lb + i);
131         }
132
133         eel.assign(M + 1, 0.0);
134         for (int i = 1; i <= M; i++) {
135                 int totLen = refs.getRef(i).getTotLen();
136                 int fullLen = refs.getRef(i).getFullLen();
137                 int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
138                 int pos2 = max(min(totLen, ub) - lb, 0);
139
140                 if (pos2 == 0) { eel[i] = 0.0; continue; }
141     
142                 eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
143                 assert(eel[i] >= 0);
144                 if (eel[i] < MINEEL) { eel[i] = 0.0; }
145         }
146   
147         delete[] pdf;
148         delete[] cdf;
149         delete[] clen;
150 }
151
152 template<class ModelType>
153 void init_model_related(char* modelF) {
154         ModelType model;
155         model.read(modelF);
156
157         calcExpectedEffectiveLengths<ModelType>(model);
158         memcpy(mw, model.getMW(), sizeof(double) * (M + 1)); // otherwise, after exiting this procedure, mw becomes undefined
159 }
160
161 // assign threads
162 void init() {
163         int quotient, left;
164         char outF[STRLEN];
165
166         quotient = NSAMPLES / nThreads;
167         left = NSAMPLES % nThreads;
168
169         sprintf(cvsF, "%s.countvectors", imdName);
170         paramsArray = new Params[nThreads];
171         threads = new pthread_t[nThreads];
172
173         for (int i = 0; i < nThreads; i++) {
174                 paramsArray[i].no = i;
175
176                 paramsArray[i].nsamples = quotient;
177                 if (i < left) paramsArray[i].nsamples++;
178
179                 sprintf(outF, "%s%d", cvsF, i);
180                 paramsArray[i].fo = fopen(outF, "w");
181
182                 paramsArray[i].engine = engineFactory::new_engine();
183                 paramsArray[i].pme_c = new double[M + 1];
184                 memset(paramsArray[i].pme_c, 0, sizeof(double) * (M + 1));
185                 paramsArray[i].pve_c = new double[M + 1];
186                 memset(paramsArray[i].pve_c, 0, sizeof(double) * (M + 1));
187                 paramsArray[i].pme_tpm = new double[M + 1];
188                 memset(paramsArray[i].pme_tpm, 0, sizeof(double) * (M + 1));
189                 paramsArray[i].pme_fpkm = new double[M + 1];
190                 memset(paramsArray[i].pme_fpkm, 0, sizeof(double) * (M + 1));
191         }
192
193         /* set thread attribute to be joinable */
194         pthread_attr_init(&attr);
195         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
196
197         if (verbose) { printf("Initialization finished!\n"); }
198 }
199
200 //sample theta from Dir(1)
201 void sampleTheta(engine_type& engine, vector<double>& theta) {
202         gamma_dist gm(1);
203         gamma_generator gmg(engine, gm);
204         double denom;
205
206         theta.assign(M + 1, 0);
207         denom = 0.0;
208         for (int i = 0; i <= M; i++) {
209                 theta[i] = gmg();
210                 denom += theta[i];
211         }
212         assert(denom > EPSILON);
213         for (int i = 0; i <= M; i++) theta[i] /= denom;
214 }
215
216 void writeCountVector(FILE* fo, vector<int>& counts) {
217         for (int i = 0; i < M; i++) {
218                 fprintf(fo, "%d ", counts[i]);
219         }
220         fprintf(fo, "%d\n", counts[M]);
221 }
222
223 void polishTheta(vector<double>& theta, const vector<double>& eel, const double* mw) {
224         double sum = 0.0;
225
226         /* The reason that for noise gene, mw value is 1 is :
227          * currently, all masked positions are for poly(A) sites, which in theory should be filtered out.
228          * So the theta0 does not containing reads from any masked position
229          */
230
231         for (int i = 0; i <= M; i++) {
232                 // i == 0, mw[i] == 1
233                 if (i > 0 && (mw[i] < EPSILON || eel[i] < EPSILON)) {
234                         theta[i] = 0.0;
235                         continue;
236                 }
237                 theta[i] = theta[i] / mw[i];
238                 sum += theta[i];
239         }
240         // currently is OK, since no transcript should be masked totally, only the poly(A) tail related part will be masked
241         general_assert(sum >= EPSILON, "No effective length is no less than" + ftos(MINEEL, 6) + " !");
242         for (int i = 0; i <= M; i++) theta[i] /= sum;
243 }
244
245 void calcExpressionValues(const vector<double>& theta, const vector<double>& eel, vector<double>& tpm, vector<double>& fpkm) {
246         double denom;
247         vector<double> frac;
248
249         //calculate fraction of count over all mappabile reads
250         denom = 0.0;
251         frac.assign(M + 1, 0.0);
252         for (int i = 1; i <= M; i++) 
253           if (eel[i] >= EPSILON) {
254             frac[i] = theta[i];
255             denom += frac[i];
256           }
257         general_assert(denom >= EPSILON, "No alignable reads?!");
258         for (int i = 1; i <= M; i++) frac[i] /= denom;
259   
260         //calculate FPKM
261         fpkm.assign(M + 1, 0.0);
262         for (int i = 1; i <= M; i++)
263                 if (eel[i] >= EPSILON) fpkm[i] = frac[i] * 1e9 / eel[i];
264
265         //calculate TPM
266         tpm.assign(M + 1, 0.0);
267         denom = 0.0;
268         for (int i = 1; i <= M; i++) denom += fpkm[i];
269         for (int i = 1; i <= M; i++) tpm[i] = fpkm[i] / denom * 1e6;  
270 }
271
272 void* Gibbs(void* arg) {
273         int CHAINLEN;
274         HIT_INT_TYPE len, fr, to;
275         Params *params = (Params*)arg;
276
277         vector<double> theta, tpm, fpkm;
278         vector<int> z, counts;
279         vector<double> arr;
280
281         uniform01 rg(*params->engine);
282
283         // generate initial state
284         sampleTheta(*params->engine, theta);
285
286         z.assign(N1, 0);
287
288         counts.assign(M + 1, 1); // 1 pseudo count
289         counts[0] += N0;
290
291         for (READ_INT_TYPE i = 0; i < N1; i++) {
292                 fr = s[i]; to = s[i + 1];
293                 len = to - fr;
294                 arr.assign(len, 0);
295                 for (HIT_INT_TYPE j = fr; j < to; j++) {
296                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
297                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
298                 }
299                 z[i] = hits[fr + sample(rg, arr, len)].sid;
300                 ++counts[z[i]];
301         }
302
303         // Gibbs sampling
304         CHAINLEN = 1 + (params->nsamples - 1) * GAP;
305         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
306
307                 for (READ_INT_TYPE i = 0; i < N1; i++) {
308                         --counts[z[i]];
309                         fr = s[i]; to = s[i + 1]; len = to - fr;
310                         arr.assign(len, 0);
311                         for (HIT_INT_TYPE j = fr; j < to; j++) {
312                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
313                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
314                         }
315                         z[i] = hits[fr + sample(rg, arr, len)].sid;
316                         ++counts[z[i]];
317                 }
318
319                 if (ROUND > BURNIN) {
320                         if ((ROUND - BURNIN - 1) % GAP == 0) {
321                                 writeCountVector(params->fo, counts);
322                                 for (int i = 0; i <= M; i++) theta[i] = counts[i] / totc;
323                                 polishTheta(theta, eel, mw);
324                                 calcExpressionValues(theta, eel, tpm, fpkm);
325                                 for (int i = 0; i <= M; i++) {
326                                         params->pme_c[i] += counts[i] - 1;
327                                         params->pve_c[i] += (counts[i] - 1) * (counts[i] - 1);
328                                         params->pme_tpm[i] += tpm[i];
329                                         params->pme_fpkm[i] += fpkm[i];
330                                 }
331                         }
332                 }
333
334                 if (verbose && ROUND % 100 == 0) { printf("Thread %d, ROUND %d is finished!\n", params->no, ROUND); }
335         }
336
337         return NULL;
338 }
339
340 void release() {
341 //      char inpF[STRLEN], command[STRLEN];
342         string line;
343
344         /* destroy attribute */
345         pthread_attr_destroy(&attr);
346         delete[] threads;
347
348         pme_c.assign(M + 1, 0);
349         pve_c.assign(M + 1, 0);
350         pme_tpm.assign(M + 1, 0);
351         pme_fpkm.assign(M + 1, 0);
352         for (int i = 0; i < nThreads; i++) {
353                 fclose(paramsArray[i].fo);
354                 delete paramsArray[i].engine;
355                 for (int j = 0; j <= M; j++) {
356                         pme_c[j] += paramsArray[i].pme_c[j];
357                         pve_c[j] += paramsArray[i].pve_c[j];
358                         pme_tpm[j] += paramsArray[i].pme_tpm[j];
359                         pme_fpkm[j] += paramsArray[i].pme_fpkm[j];
360                 }
361                 delete[] paramsArray[i].pme_c;
362                 delete[] paramsArray[i].pve_c;
363                 delete[] paramsArray[i].pme_tpm;
364                 delete[] paramsArray[i].pme_fpkm;
365         }
366         delete[] paramsArray;
367
368
369         for (int i = 0; i <= M; i++) {
370                 pme_c[i] /= NSAMPLES;
371                 pve_c[i] = (pve_c[i] - NSAMPLES * pme_c[i] * pme_c[i]) / (NSAMPLES - 1);
372                 pme_tpm[i] /= NSAMPLES;
373                 pme_fpkm[i] /= NSAMPLES;
374         }
375 }
376
377 void writeResults(char* imdName) {
378         char outF[STRLEN];
379         FILE *fo;
380
381         vector<double> isopct;
382         vector<double> gene_counts, gene_tpm, gene_fpkm;
383
384         //calculate IsoPct, etc.
385         isopct.assign(M + 1, 0.0);
386         gene_counts.assign(m, 0.0); gene_tpm.assign(m, 0.0); gene_fpkm.assign(m, 0.0);
387
388         for (int i = 0; i < m; i++) {
389                 int b = gi.spAt(i), e = gi.spAt(i + 1);
390                 for (int j = b; j < e; j++) {
391                         gene_counts[i] += pme_c[j];
392                         gene_tpm[i] += pme_tpm[j];
393                         gene_fpkm[i] += pme_fpkm[j];
394                 }
395                 if (gene_tpm[i] < EPSILON) continue;
396                 for (int j = b; j < e; j++)
397                         isopct[j] = pme_tpm[j] / gene_tpm[i];
398         }
399
400         //isoform level results
401         sprintf(outF, "%s.iso_res", imdName);
402         fo = fopen(outF, "a");
403         general_assert(fo != NULL, "Cannot open " + cstrtos(outF) + "!");
404
405         for (int i = 1; i <= M; i++)
406                 fprintf(fo, "%.2f%c", pme_c[i], (i < M ? '\t' : '\n'));
407         for (int i = 1; i <= M; i++)
408                 fprintf(fo, "%.2f%c", pme_tpm[i], (i < M ? '\t' : '\n'));
409         for (int i = 1; i <= M; i++)
410                 fprintf(fo, "%.2f%c", pme_fpkm[i], (i < M ? '\t' : '\n'));
411         for (int i = 1; i <= M; i++)
412                 fprintf(fo, "%.2f%c", isopct[i] * 1e2, (i < M ? '\t' : '\n'));
413         fclose(fo);
414
415         //gene level results
416         sprintf(outF, "%s.gene_res", imdName);
417         fo = fopen(outF, "a");
418         general_assert(fo != NULL, "Cannot open " + cstrtos(outF) + "!");
419
420         for (int i = 0; i < m; i++)
421                 fprintf(fo, "%.2f%c", gene_counts[i], (i < m - 1 ? '\t' : '\n'));
422         for (int i = 0; i < m; i++)
423                 fprintf(fo, "%.2f%c", gene_tpm[i], (i < m - 1 ? '\t' : '\n'));
424         for (int i = 0; i < m; i++)
425                 fprintf(fo, "%.2f%c", gene_fpkm[i], (i < m - 1 ? '\t' : '\n'));
426         fclose(fo);
427
428         if (verbose) { printf("Gibbs based expression values are written!\n"); }
429 }
430
431 int main(int argc, char* argv[]) {
432         if (argc < 7) {
433                 printf("Usage: rsem-run-gibbs reference_name imdName statName BURNIN NSAMPLES GAP [-p #Threads] [--var] [-q]\n");
434                 exit(-1);
435         }
436
437         strcpy(imdName, argv[2]);
438         strcpy(statName, argv[3]);
439
440         BURNIN = atoi(argv[4]);
441         NSAMPLES = atoi(argv[5]);
442         GAP = atoi(argv[6]);
443
444         nThreads = 1;
445         var_opt = false;
446         quiet = false;
447
448         for (int i = 7; i < argc; i++) {
449                 if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
450                 if (!strcmp(argv[i], "--var")) var_opt = true;
451                 if (!strcmp(argv[i], "-q")) quiet = true;
452         }
453         verbose = !quiet;
454
455         assert(NSAMPLES > 1); // Otherwise, we cannot calculate posterior variance
456
457         if (nThreads > NSAMPLES) {
458                 nThreads = NSAMPLES;
459                 printf("Warning: Number of samples is less than number of threads! Change the number of threads to %d!\n", nThreads);
460         }
461
462         load_data(argv[1], statName, imdName);
463
464         sprintf(modelF, "%s.model", statName);
465         FILE *fi = fopen(modelF, "r");
466         general_assert(fi != NULL, "Cannot open " + cstrtos(modelF) + "!");
467         assert(fscanf(fi, "%d", &model_type) == 1);
468         fclose(fi);
469
470         mw = new double[M + 1]; // make an extra copy
471
472         switch(model_type) {
473         case 0 : init_model_related<SingleModel>(modelF); break;
474         case 1 : init_model_related<SingleQModel>(modelF); break;
475         case 2 : init_model_related<PairedEndModel>(modelF); break;
476         case 3 : init_model_related<PairedEndQModel>(modelF); break;
477         }
478
479         if (verbose) printf("Gibbs started!\n");
480
481         init();
482         for (int i = 0; i < nThreads; i++) {
483                 rc = pthread_create(&threads[i], &attr, Gibbs, (void*)(&paramsArray[i]));
484                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0)!");
485         }
486         for (int i = 0; i < nThreads; i++) {
487                 rc = pthread_join(threads[i], NULL);
488                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0)!");
489         }
490         release();
491
492         if (verbose) printf("Gibbs finished!\n");
493         
494         writeResults(imdName);
495
496         if (var_opt) {
497                 char varF[STRLEN];
498
499                 sprintf(varF, "%s.var", statName);
500                 FILE *fo = fopen(varF, "w");
501                 general_assert(fo != NULL, "Cannot open " + cstrtos(varF) + "!");
502                 for (int i = 0; i < m; i++) {
503                         int b = gi.spAt(i), e = gi.spAt(i + 1), number_of_isoforms = e - b;
504                         for (int j = b; j < e; j++) {
505                                 fprintf(fo, "%s\t%d\t%.15g\t%.15g\n", refs.getRef(j).getName().c_str(), number_of_isoforms, pme_c[j], pve_c[j]);
506                         }
507                 }
508                 fclose(fo);
509         }
510
511         delete mw; // delete the copy
512
513         return 0;
514 }