]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
Allowed > 2^31 hits
[rsem.git] / Gibbs.cpp
1 #include<cstdio>
2 #include<cstring>
3 #include<cstdlib>
4 #include<cassert>
5 #include<fstream>
6 #include<sstream>
7 #include<vector>
8 #include<pthread.h>
9
10 #include "utils.h"
11 #include "my_assert.h"
12 #include "sampling.h"
13
14 #include "Model.h"
15 #include "SingleModel.h"
16 #include "SingleQModel.h"
17 #include "PairedEndModel.h"
18 #include "PairedEndQModel.h"
19
20 #include "Refs.h"
21 #include "GroupInfo.h"
22
23 using namespace std;
24
25 struct Params {
26         int no, nsamples;
27         FILE *fo;
28         engine_type *engine;
29         double *pme_c, *pve_c; //posterior mean and variance vectors on counts
30         double *pme_theta;
31 };
32
33
34 struct Item {
35         int sid;
36         double conprb;
37
38         Item(int sid, double conprb) {
39                 this->sid = sid;
40                 this->conprb = conprb;
41         }
42 };
43
44 int nThreads;
45
46 int model_type;
47 int m, M;
48 READ_INT_TYPE N0, N1;
49 HIT_INT_TYPE nHits;
50 double totc;
51 int BURNIN, NSAMPLES, GAP;
52 char imdName[STRLEN], statName[STRLEN];
53 char thetaF[STRLEN], ofgF[STRLEN], groupF[STRLEN], refF[STRLEN], modelF[STRLEN];
54 char cvsF[STRLEN];
55
56 Refs refs;
57 GroupInfo gi;
58
59 vector<HIT_INT_TYPE> s;
60 vector<Item> hits;
61
62 vector<double> theta;
63
64 vector<double> pme_c, pve_c; //global posterior mean and variance vectors on counts
65 vector<double> pme_theta, eel;
66
67 bool var_opt;
68 bool quiet;
69
70 Params *paramsArray;
71 pthread_t *threads;
72 pthread_attr_t attr;
73 void *status;
74 int rc;
75
76 void load_data(char* reference_name, char* statName, char* imdName) {
77         ifstream fin;
78         string line;
79         int tmpVal;
80
81         //load reference file
82         sprintf(refF, "%s.seq", reference_name);
83         refs.loadRefs(refF, 1);
84         M = refs.getM();
85
86         //load groupF
87         sprintf(groupF, "%s.grp", reference_name);
88         gi.load(groupF);
89         m = gi.getm();
90
91         //load thetaF
92         sprintf(thetaF, "%s.theta",statName);
93         fin.open(thetaF);
94         general_assert(fin.is_open(), "Cannot open " + cstrtos(thetaF) + "!");
95         fin>>tmpVal;
96         general_assert(tmpVal == M + 1, "Number of transcripts is not consistent in " + cstrtos(refF) + " and " + cstrtos(thetaF) + "!");
97         theta.assign(M + 1, 0);
98         for (int i = 0; i <= M; i++) fin>>theta[i];
99         fin.close();
100
101         //load ofgF;
102         sprintf(ofgF, "%s.ofg", imdName);
103         fin.open(ofgF);
104         general_assert(fin.is_open(), "Cannot open " + cstrtos(ofgF) + "!");
105         fin>>tmpVal>>N0;
106         general_assert(tmpVal == M, "M in " + cstrtos(ofgF) + " is not consistent with " + cstrtos(refF) + "!");
107         getline(fin, line);
108
109         s.clear(); hits.clear();
110         s.push_back(0);
111         while (getline(fin, line)) {
112                 istringstream strin(line);
113                 int sid;
114                 double conprb;
115
116                 while (strin>>sid>>conprb) {
117                         hits.push_back(Item(sid, conprb));
118                 }
119                 s.push_back(hits.size());
120         }
121         fin.close();
122
123         N1 = s.size() - 1;
124         nHits = hits.size();
125
126         totc = N0 + N1 + (M + 1);
127
128         if (verbose) { printf("Loading Data is finished!\n"); }
129 }
130
131 // assign threads
132 void init() {
133         int quotient, left;
134         char outF[STRLEN];
135
136         quotient = NSAMPLES / nThreads;
137         left = NSAMPLES % nThreads;
138
139         sprintf(cvsF, "%s.countvectors", imdName);
140         paramsArray = new Params[nThreads];
141         threads = new pthread_t[nThreads];
142
143         for (int i = 0; i < nThreads; i++) {
144                 paramsArray[i].no = i;
145
146                 paramsArray[i].nsamples = quotient;
147                 if (i < left) paramsArray[i].nsamples++;
148
149                 sprintf(outF, "%s%d", cvsF, i);
150                 paramsArray[i].fo = fopen(outF, "w");
151
152                 paramsArray[i].engine = engineFactory::new_engine();
153                 paramsArray[i].pme_c = new double[M + 1];
154                 memset(paramsArray[i].pme_c, 0, sizeof(double) * (M + 1));
155                 paramsArray[i].pve_c = new double[M + 1];
156                 memset(paramsArray[i].pve_c, 0, sizeof(double) * (M + 1));
157                 paramsArray[i].pme_theta = new double[M + 1];
158                 memset(paramsArray[i].pme_theta, 0, sizeof(double) * (M + 1));
159         }
160
161         /* set thread attribute to be joinable */
162         pthread_attr_init(&attr);
163         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
164
165         if (verbose) { printf("Initialization finished!\n"); }
166 }
167
168 //sample theta from Dir(1)
169 void sampleTheta(engine_type& engine, vector<double>& theta) {
170         gamma_dist gm(1);
171         gamma_generator gmg(engine, gm);
172         double denom;
173
174         theta.assign(M + 1, 0);
175         denom = 0.0;
176         for (int i = 0; i <= M; i++) {
177                 theta[i] = gmg();
178                 denom += theta[i];
179         }
180         assert(denom > EPSILON);
181         for (int i = 0; i <= M; i++) theta[i] /= denom;
182 }
183
184 void writeCountVector(FILE* fo, vector<int>& counts) {
185         for (int i = 0; i < M; i++) {
186                 fprintf(fo, "%d ", counts[i]);
187         }
188         fprintf(fo, "%d\n", counts[M]);
189 }
190
191 void* Gibbs(void* arg) {
192         int CHAINLEN;
193         HIT_INT_TYPE len, fr, to;
194         Params *params = (Params*)arg;
195
196         vector<double> theta;
197         vector<int> z, counts;
198         vector<double> arr;
199
200         uniform01 rg(*params->engine);
201
202         // generate initial state
203         sampleTheta(*params->engine, theta);
204
205         z.assign(N1, 0);
206
207         counts.assign(M + 1, 1); // 1 pseudo count
208         counts[0] += N0;
209
210         for (READ_INT_TYPE i = 0; i < N1; i++) {
211                 fr = s[i]; to = s[i + 1];
212                 len = to - fr;
213                 arr.assign(len, 0);
214                 for (HIT_INT_TYPE j = fr; j < to; j++) {
215                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
216                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
217                 }
218                 z[i] = hits[fr + sample(rg, arr, len)].sid;
219                 ++counts[z[i]];
220         }
221
222         // Gibbs sampling
223         CHAINLEN = 1 + (params->nsamples - 1) * GAP;
224         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
225
226                 for (READ_INT_TYPE i = 0; i < N1; i++) {
227                         --counts[z[i]];
228                         fr = s[i]; to = s[i + 1]; len = to - fr;
229                         arr.assign(len, 0);
230                         for (HIT_INT_TYPE j = fr; j < to; j++) {
231                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
232                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
233                         }
234                         z[i] = hits[fr + sample(rg, arr, len)].sid;
235                         ++counts[z[i]];
236                 }
237
238                 if (ROUND > BURNIN) {
239                         if ((ROUND - BURNIN - 1) % GAP == 0) {
240                                 writeCountVector(params->fo, counts);
241                                 for (int i = 0; i <= M; i++) {
242                                         params->pme_c[i] += counts[i] - 1;
243                                         params->pve_c[i] += (counts[i] - 1) * (counts[i] - 1);
244                                         params->pme_theta[i] += counts[i] / totc;
245                                 }
246                         }
247                 }
248
249                 if (verbose && ROUND % 100 == 0) { printf("Thread %d, ROUND %d is finished!\n", params->no, ROUND); }
250         }
251
252         return NULL;
253 }
254
255 void release() {
256 //      char inpF[STRLEN], command[STRLEN];
257         string line;
258
259         /* destroy attribute */
260         pthread_attr_destroy(&attr);
261         delete[] threads;
262
263         pme_c.assign(M + 1, 0);
264         pve_c.assign(M + 1, 0);
265         pme_theta.assign(M + 1, 0);
266         for (int i = 0; i < nThreads; i++) {
267                 fclose(paramsArray[i].fo);
268                 delete paramsArray[i].engine;
269                 for (int j = 0; j <= M; j++) {
270                         pme_c[j] += paramsArray[i].pme_c[j];
271                         pve_c[j] += paramsArray[i].pve_c[j];
272                         pme_theta[j] += paramsArray[i].pme_theta[j];
273                 }
274                 delete[] paramsArray[i].pme_c;
275                 delete[] paramsArray[i].pve_c;
276                 delete[] paramsArray[i].pme_theta;
277         }
278         delete[] paramsArray;
279
280
281         for (int i = 0; i <= M; i++) {
282                 pme_c[i] /= NSAMPLES;
283                 pve_c[i] = (pve_c[i] - NSAMPLES * pme_c[i] * pme_c[i]) / (NSAMPLES - 1);
284                 pme_theta[i] /= NSAMPLES;
285         }
286 }
287
288 template<class ModelType>
289 void calcExpectedEffectiveLengths(ModelType& model) {
290         int lb, ub, span;
291         double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = \sigma_{j=1}^{i}pdf[i]*(lb+i)
292   
293         model.getGLD().copyTo(pdf, cdf, lb, ub, span);
294         clen = new double[span + 1];
295         clen[0] = 0.0;
296         for (int i = 1; i <= span; i++) {
297                 clen[i] = clen[i - 1] + pdf[i] * (lb + i);
298         }
299
300         eel.assign(M + 1, 0.0);
301         for (int i = 1; i <= M; i++) {
302                 int totLen = refs.getRef(i).getTotLen();
303                 int fullLen = refs.getRef(i).getFullLen();
304                 int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
305                 int pos2 = max(min(totLen, ub) - lb, 0);
306
307                 if (pos2 == 0) { eel[i] = 0.0; continue; }
308     
309                 eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
310                 assert(eel[i] >= 0);
311                 if (eel[i] < MINEEL) { eel[i] = 0.0; }
312         }
313   
314         delete[] pdf;
315         delete[] cdf;
316         delete[] clen;
317 }
318
319 template<class ModelType>
320 void writeEstimatedParameters(char* modelF, char* imdName) {
321         ModelType model;
322         double denom;
323         char outF[STRLEN];
324         FILE *fo;
325
326         model.read(modelF);
327
328         calcExpectedEffectiveLengths<ModelType>(model);
329
330         denom = pme_theta[0];
331         for (int i = 1; i <= M; i++)
332           if (eel[i] < EPSILON) pme_theta[i] = 0.0;
333           else denom += pme_theta[i];
334
335         general_assert(denom >= EPSILON, "No Expected Effective Length is no less than " + ftos(MINEEL, 6) + "?!");
336
337         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
338
339         denom = 0.0;
340         double *mw = model.getMW();
341         for (int i = 0; i <= M; i++) {
342           pme_theta[i] = (mw[i] < EPSILON ? 0.0 : pme_theta[i] / mw[i]);
343           denom += pme_theta[i];
344         }
345         assert(denom >= EPSILON);
346         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
347
348         //calculate tau values
349         double *tau = new double[M + 1];
350         memset(tau, 0, sizeof(double) * (M + 1));
351
352         denom = 0.0;
353         for (int i = 1; i <= M; i++) 
354           if (eel[i] > EPSILON) {
355             tau[i] = pme_theta[i] / eel[i];
356             denom += tau[i];
357           }
358
359         general_assert(denom >= EPSILON, "No alignable reads?!");
360
361         for (int i = 1; i <= M; i++) {
362                 tau[i] /= denom;
363         }
364
365         //isoform level results
366         sprintf(outF, "%s.iso_res", imdName);
367         fo = fopen(outF, "a");
368         general_assert(fo != NULL, "Cannot open " + cstrtos(outF) + "!");
369
370         for (int i = 1; i <= M; i++)
371                 fprintf(fo, "%.2f%c", pme_c[i], (i < M ? '\t' : '\n'));
372         for (int i = 1; i <= M; i++)
373                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
374
375         fclose(fo);
376
377         //gene level results
378         sprintf(outF, "%s.gene_res", imdName);
379         fo = fopen(outF, "a");
380         general_assert(fo != NULL, "Cannot open " + cstrtos(outF) + "!");
381
382         for (int i = 0; i < m; i++) {
383                 double sumC = 0.0; //  sum of pme counts
384                 int b = gi.spAt(i), e = gi.spAt(i + 1);
385                 for (int j = b; j < e; j++) {
386                         sumC += pme_c[j];
387                 }
388                 fprintf(fo, "%.15g%c", sumC, (i < m - 1 ? '\t' : '\n'));
389         }
390         for (int i = 0; i < m; i++) {
391                 double sumT = 0.0; //  sum of tau values
392                 int b = gi.spAt(i), e = gi.spAt(i + 1);
393                 for (int j = b; j < e; j++) {
394                         sumT += tau[j];
395                 }
396                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
397         }
398         fclose(fo);
399
400         delete[] tau;
401
402         if (verbose) { printf("Gibbs based expression values are written!\n"); }
403 }
404
405 int main(int argc, char* argv[]) {
406         if (argc < 7) {
407                 printf("Usage: rsem-run-gibbs-multi reference_name sample_name sampleToken BURNIN NSAMPLES GAP [-p #Threads] [--var] [-q]\n");
408                 exit(-1);
409         }
410
411         BURNIN = atoi(argv[4]);
412         NSAMPLES = atoi(argv[5]);
413         GAP = atoi(argv[6]);
414         sprintf(imdName, "%s.temp/%s", argv[2], argv[3]);
415         sprintf(statName, "%s.stat/%s", argv[2], argv[3]);
416         load_data(argv[1], statName, imdName);
417
418         nThreads = 1;
419         var_opt = false;
420         quiet = false;
421
422         for (int i = 7; i < argc; i++) {
423                 if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
424                 if (!strcmp(argv[i], "--var")) var_opt = true;
425                 if (!strcmp(argv[i], "-q")) quiet = true;
426         }
427         verbose = !quiet;
428
429         assert(NSAMPLES > 1); // Otherwise, we cannot calculate posterior variance
430
431         if (nThreads > NSAMPLES) {
432                 nThreads = NSAMPLES;
433                 printf("Warning: Number of samples is less than number of threads! Change the number of threads to %d!\n", nThreads);
434         }
435
436         if (verbose) printf("Gibbs started!\n");
437
438         init();
439         for (int i = 0; i < nThreads; i++) {
440                 rc = pthread_create(&threads[i], &attr, Gibbs, (void*)(&paramsArray[i]));
441                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0)!");
442         }
443         for (int i = 0; i < nThreads; i++) {
444                 rc = pthread_join(threads[i], &status);
445                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0)!");
446         }
447         release();
448
449         if (verbose) printf("Gibbs finished!\n");
450
451         sprintf(modelF, "%s.model", statName);
452         FILE *fi = fopen(modelF, "r");
453         general_assert(fi != NULL, "Cannot open " + cstrtos(modelF) + "!");
454         assert(fscanf(fi, "%d", &model_type) == 1);
455         fclose(fi);
456
457         switch(model_type) {
458         case 0 : writeEstimatedParameters<SingleModel>(modelF, imdName); break;
459         case 1 : writeEstimatedParameters<SingleQModel>(modelF, imdName); break;
460         case 2 : writeEstimatedParameters<PairedEndModel>(modelF, imdName); break;
461         case 3 : writeEstimatedParameters<PairedEndQModel>(modelF, imdName); break;
462         }
463
464         if (var_opt) {
465                 char varF[STRLEN];
466
467                 sprintf(varF, "%s.var", statName);
468                 FILE *fo = fopen(varF, "w");
469                 general_assert(fo != NULL, "Cannot open " + cstrtos(varF) + "!");
470                 for (int i = 0; i < m; i++) {
471                         int b = gi.spAt(i), e = gi.spAt(i + 1), number_of_isoforms = e - b;
472                         for (int j = b; j < e; j++) {
473                                 fprintf(fo, "%s\t%d\t%.15g\t%.15g\n", refs.getRef(j).getName().c_str(), number_of_isoforms, pme_c[j], pve_c[j]);
474                         }
475                 }
476                 fclose(fo);
477         }
478
479         return 0;
480 }