]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
Fixed a bug for calling pthread_join
[rsem.git] / Gibbs.cpp
1 #include<cstdio>
2 #include<cstring>
3 #include<cstdlib>
4 #include<cassert>
5 #include<fstream>
6 #include<sstream>
7 #include<vector>
8 #include<pthread.h>
9
10 #include "utils.h"
11 #include "my_assert.h"
12 #include "sampling.h"
13
14 #include "Model.h"
15 #include "SingleModel.h"
16 #include "SingleQModel.h"
17 #include "PairedEndModel.h"
18 #include "PairedEndQModel.h"
19
20 #include "Refs.h"
21 #include "GroupInfo.h"
22
23 using namespace std;
24
25 struct Params {
26         int no, nsamples;
27         FILE *fo;
28         engine_type *engine;
29         double *pme_c, *pve_c; //posterior mean and variance vectors on counts
30         double *pme_theta;
31 };
32
33
34 struct Item {
35         int sid;
36         double conprb;
37
38         Item(int sid, double conprb) {
39                 this->sid = sid;
40                 this->conprb = conprb;
41         }
42 };
43
44 int nThreads;
45
46 int model_type;
47 int m, M;
48 READ_INT_TYPE N0, N1;
49 HIT_INT_TYPE nHits;
50 double totc;
51 int BURNIN, NSAMPLES, GAP;
52 char imdName[STRLEN], statName[STRLEN];
53 char thetaF[STRLEN], ofgF[STRLEN], groupF[STRLEN], refF[STRLEN], modelF[STRLEN];
54 char cvsF[STRLEN];
55
56 Refs refs;
57 GroupInfo gi;
58
59 vector<HIT_INT_TYPE> s;
60 vector<Item> hits;
61
62 vector<double> theta;
63
64 vector<double> pme_c, pve_c; //global posterior mean and variance vectors on counts
65 vector<double> pme_theta, eel;
66
67 bool var_opt;
68 bool quiet;
69
70 Params *paramsArray;
71 pthread_t *threads;
72 pthread_attr_t attr;
73 int rc;
74
75 void load_data(char* reference_name, char* statName, char* imdName) {
76         ifstream fin;
77         string line;
78         int tmpVal;
79
80         //load reference file
81         sprintf(refF, "%s.seq", reference_name);
82         refs.loadRefs(refF, 1);
83         M = refs.getM();
84
85         //load groupF
86         sprintf(groupF, "%s.grp", reference_name);
87         gi.load(groupF);
88         m = gi.getm();
89
90         //load thetaF
91         sprintf(thetaF, "%s.theta",statName);
92         fin.open(thetaF);
93         general_assert(fin.is_open(), "Cannot open " + cstrtos(thetaF) + "!");
94         fin>>tmpVal;
95         general_assert(tmpVal == M + 1, "Number of transcripts is not consistent in " + cstrtos(refF) + " and " + cstrtos(thetaF) + "!");
96         theta.assign(M + 1, 0);
97         for (int i = 0; i <= M; i++) fin>>theta[i];
98         fin.close();
99
100         //load ofgF;
101         sprintf(ofgF, "%s.ofg", imdName);
102         fin.open(ofgF);
103         general_assert(fin.is_open(), "Cannot open " + cstrtos(ofgF) + "!");
104         fin>>tmpVal>>N0;
105         general_assert(tmpVal == M, "M in " + cstrtos(ofgF) + " is not consistent with " + cstrtos(refF) + "!");
106         getline(fin, line);
107
108         s.clear(); hits.clear();
109         s.push_back(0);
110         while (getline(fin, line)) {
111                 istringstream strin(line);
112                 int sid;
113                 double conprb;
114
115                 while (strin>>sid>>conprb) {
116                         hits.push_back(Item(sid, conprb));
117                 }
118                 s.push_back(hits.size());
119         }
120         fin.close();
121
122         N1 = s.size() - 1;
123         nHits = hits.size();
124
125         totc = N0 + N1 + (M + 1);
126
127         if (verbose) { printf("Loading Data is finished!\n"); }
128 }
129
130 // assign threads
131 void init() {
132         int quotient, left;
133         char outF[STRLEN];
134
135         quotient = NSAMPLES / nThreads;
136         left = NSAMPLES % nThreads;
137
138         sprintf(cvsF, "%s.countvectors", imdName);
139         paramsArray = new Params[nThreads];
140         threads = new pthread_t[nThreads];
141
142         for (int i = 0; i < nThreads; i++) {
143                 paramsArray[i].no = i;
144
145                 paramsArray[i].nsamples = quotient;
146                 if (i < left) paramsArray[i].nsamples++;
147
148                 sprintf(outF, "%s%d", cvsF, i);
149                 paramsArray[i].fo = fopen(outF, "w");
150
151                 paramsArray[i].engine = engineFactory::new_engine();
152                 paramsArray[i].pme_c = new double[M + 1];
153                 memset(paramsArray[i].pme_c, 0, sizeof(double) * (M + 1));
154                 paramsArray[i].pve_c = new double[M + 1];
155                 memset(paramsArray[i].pve_c, 0, sizeof(double) * (M + 1));
156                 paramsArray[i].pme_theta = new double[M + 1];
157                 memset(paramsArray[i].pme_theta, 0, sizeof(double) * (M + 1));
158         }
159
160         /* set thread attribute to be joinable */
161         pthread_attr_init(&attr);
162         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
163
164         if (verbose) { printf("Initialization finished!\n"); }
165 }
166
167 //sample theta from Dir(1)
168 void sampleTheta(engine_type& engine, vector<double>& theta) {
169         gamma_dist gm(1);
170         gamma_generator gmg(engine, gm);
171         double denom;
172
173         theta.assign(M + 1, 0);
174         denom = 0.0;
175         for (int i = 0; i <= M; i++) {
176                 theta[i] = gmg();
177                 denom += theta[i];
178         }
179         assert(denom > EPSILON);
180         for (int i = 0; i <= M; i++) theta[i] /= denom;
181 }
182
183 void writeCountVector(FILE* fo, vector<int>& counts) {
184         for (int i = 0; i < M; i++) {
185                 fprintf(fo, "%d ", counts[i]);
186         }
187         fprintf(fo, "%d\n", counts[M]);
188 }
189
190 void* Gibbs(void* arg) {
191         int CHAINLEN;
192         HIT_INT_TYPE len, fr, to;
193         Params *params = (Params*)arg;
194
195         vector<double> theta;
196         vector<int> z, counts;
197         vector<double> arr;
198
199         uniform01 rg(*params->engine);
200
201         // generate initial state
202         sampleTheta(*params->engine, theta);
203
204         z.assign(N1, 0);
205
206         counts.assign(M + 1, 1); // 1 pseudo count
207         counts[0] += N0;
208
209         for (READ_INT_TYPE i = 0; i < N1; i++) {
210                 fr = s[i]; to = s[i + 1];
211                 len = to - fr;
212                 arr.assign(len, 0);
213                 for (HIT_INT_TYPE j = fr; j < to; j++) {
214                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
215                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
216                 }
217                 z[i] = hits[fr + sample(rg, arr, len)].sid;
218                 ++counts[z[i]];
219         }
220
221         // Gibbs sampling
222         CHAINLEN = 1 + (params->nsamples - 1) * GAP;
223         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
224
225                 for (READ_INT_TYPE i = 0; i < N1; i++) {
226                         --counts[z[i]];
227                         fr = s[i]; to = s[i + 1]; len = to - fr;
228                         arr.assign(len, 0);
229                         for (HIT_INT_TYPE j = fr; j < to; j++) {
230                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
231                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
232                         }
233                         z[i] = hits[fr + sample(rg, arr, len)].sid;
234                         ++counts[z[i]];
235                 }
236
237                 if (ROUND > BURNIN) {
238                         if ((ROUND - BURNIN - 1) % GAP == 0) {
239                                 writeCountVector(params->fo, counts);
240                                 for (int i = 0; i <= M; i++) {
241                                         params->pme_c[i] += counts[i] - 1;
242                                         params->pve_c[i] += (counts[i] - 1) * (counts[i] - 1);
243                                         params->pme_theta[i] += counts[i] / totc;
244                                 }
245                         }
246                 }
247
248                 if (verbose && ROUND % 100 == 0) { printf("Thread %d, ROUND %d is finished!\n", params->no, ROUND); }
249         }
250
251         return NULL;
252 }
253
254 void release() {
255 //      char inpF[STRLEN], command[STRLEN];
256         string line;
257
258         /* destroy attribute */
259         pthread_attr_destroy(&attr);
260         delete[] threads;
261
262         pme_c.assign(M + 1, 0);
263         pve_c.assign(M + 1, 0);
264         pme_theta.assign(M + 1, 0);
265         for (int i = 0; i < nThreads; i++) {
266                 fclose(paramsArray[i].fo);
267                 delete paramsArray[i].engine;
268                 for (int j = 0; j <= M; j++) {
269                         pme_c[j] += paramsArray[i].pme_c[j];
270                         pve_c[j] += paramsArray[i].pve_c[j];
271                         pme_theta[j] += paramsArray[i].pme_theta[j];
272                 }
273                 delete[] paramsArray[i].pme_c;
274                 delete[] paramsArray[i].pve_c;
275                 delete[] paramsArray[i].pme_theta;
276         }
277         delete[] paramsArray;
278
279
280         for (int i = 0; i <= M; i++) {
281                 pme_c[i] /= NSAMPLES;
282                 pve_c[i] = (pve_c[i] - NSAMPLES * pme_c[i] * pme_c[i]) / (NSAMPLES - 1);
283                 pme_theta[i] /= NSAMPLES;
284         }
285 }
286
287 template<class ModelType>
288 void calcExpectedEffectiveLengths(ModelType& model) {
289         int lb, ub, span;
290         double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = \sigma_{j=1}^{i}pdf[i]*(lb+i)
291   
292         model.getGLD().copyTo(pdf, cdf, lb, ub, span);
293         clen = new double[span + 1];
294         clen[0] = 0.0;
295         for (int i = 1; i <= span; i++) {
296                 clen[i] = clen[i - 1] + pdf[i] * (lb + i);
297         }
298
299         eel.assign(M + 1, 0.0);
300         for (int i = 1; i <= M; i++) {
301                 int totLen = refs.getRef(i).getTotLen();
302                 int fullLen = refs.getRef(i).getFullLen();
303                 int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
304                 int pos2 = max(min(totLen, ub) - lb, 0);
305
306                 if (pos2 == 0) { eel[i] = 0.0; continue; }
307     
308                 eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
309                 assert(eel[i] >= 0);
310                 if (eel[i] < MINEEL) { eel[i] = 0.0; }
311         }
312   
313         delete[] pdf;
314         delete[] cdf;
315         delete[] clen;
316 }
317
318 template<class ModelType>
319 void writeEstimatedParameters(char* modelF, char* imdName) {
320         ModelType model;
321         double denom;
322         char outF[STRLEN];
323         FILE *fo;
324
325         model.read(modelF);
326
327         calcExpectedEffectiveLengths<ModelType>(model);
328
329         denom = pme_theta[0];
330         for (int i = 1; i <= M; i++)
331           if (eel[i] < EPSILON) pme_theta[i] = 0.0;
332           else denom += pme_theta[i];
333
334         general_assert(denom >= EPSILON, "No Expected Effective Length is no less than " + ftos(MINEEL, 6) + "?!");
335
336         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
337
338         denom = 0.0;
339         double *mw = model.getMW();
340         for (int i = 0; i <= M; i++) {
341           pme_theta[i] = (mw[i] < EPSILON ? 0.0 : pme_theta[i] / mw[i]);
342           denom += pme_theta[i];
343         }
344         assert(denom >= EPSILON);
345         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
346
347         //calculate tau values
348         double *tau = new double[M + 1];
349         memset(tau, 0, sizeof(double) * (M + 1));
350
351         denom = 0.0;
352         for (int i = 1; i <= M; i++) 
353           if (eel[i] > EPSILON) {
354             tau[i] = pme_theta[i] / eel[i];
355             denom += tau[i];
356           }
357
358         general_assert(denom >= EPSILON, "No alignable reads?!");
359
360         for (int i = 1; i <= M; i++) {
361                 tau[i] /= denom;
362         }
363
364         //isoform level results
365         sprintf(outF, "%s.iso_res", imdName);
366         fo = fopen(outF, "a");
367         general_assert(fo != NULL, "Cannot open " + cstrtos(outF) + "!");
368
369         for (int i = 1; i <= M; i++)
370                 fprintf(fo, "%.2f%c", pme_c[i], (i < M ? '\t' : '\n'));
371         for (int i = 1; i <= M; i++)
372                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
373
374         fclose(fo);
375
376         //gene level results
377         sprintf(outF, "%s.gene_res", imdName);
378         fo = fopen(outF, "a");
379         general_assert(fo != NULL, "Cannot open " + cstrtos(outF) + "!");
380
381         for (int i = 0; i < m; i++) {
382                 double sumC = 0.0; //  sum of pme counts
383                 int b = gi.spAt(i), e = gi.spAt(i + 1);
384                 for (int j = b; j < e; j++) {
385                         sumC += pme_c[j];
386                 }
387                 fprintf(fo, "%.15g%c", sumC, (i < m - 1 ? '\t' : '\n'));
388         }
389         for (int i = 0; i < m; i++) {
390                 double sumT = 0.0; //  sum of tau values
391                 int b = gi.spAt(i), e = gi.spAt(i + 1);
392                 for (int j = b; j < e; j++) {
393                         sumT += tau[j];
394                 }
395                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
396         }
397         fclose(fo);
398
399         delete[] tau;
400
401         if (verbose) { printf("Gibbs based expression values are written!\n"); }
402 }
403
404 int main(int argc, char* argv[]) {
405         if (argc < 7) {
406                 printf("Usage: rsem-run-gibbs reference_name imdName statName BURNIN NSAMPLES GAP [-p #Threads] [--var] [-q]\n");
407                 exit(-1);
408         }
409
410         strcpy(imdName, argv[2]);
411         strcpy(statName, argv[3]);
412
413         BURNIN = atoi(argv[4]);
414         NSAMPLES = atoi(argv[5]);
415         GAP = atoi(argv[6]);
416
417         load_data(argv[1], statName, imdName);
418
419         nThreads = 1;
420         var_opt = false;
421         quiet = false;
422
423         for (int i = 7; i < argc; i++) {
424                 if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
425                 if (!strcmp(argv[i], "--var")) var_opt = true;
426                 if (!strcmp(argv[i], "-q")) quiet = true;
427         }
428         verbose = !quiet;
429
430         assert(NSAMPLES > 1); // Otherwise, we cannot calculate posterior variance
431
432         if (nThreads > NSAMPLES) {
433                 nThreads = NSAMPLES;
434                 printf("Warning: Number of samples is less than number of threads! Change the number of threads to %d!\n", nThreads);
435         }
436
437         if (verbose) printf("Gibbs started!\n");
438
439         init();
440         for (int i = 0; i < nThreads; i++) {
441                 rc = pthread_create(&threads[i], &attr, Gibbs, (void*)(&paramsArray[i]));
442                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0)!");
443         }
444         for (int i = 0; i < nThreads; i++) {
445                 rc = pthread_join(threads[i], NULL);
446                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0)!");
447         }
448         release();
449
450         if (verbose) printf("Gibbs finished!\n");
451
452         sprintf(modelF, "%s.model", statName);
453         FILE *fi = fopen(modelF, "r");
454         general_assert(fi != NULL, "Cannot open " + cstrtos(modelF) + "!");
455         assert(fscanf(fi, "%d", &model_type) == 1);
456         fclose(fi);
457
458         switch(model_type) {
459         case 0 : writeEstimatedParameters<SingleModel>(modelF, imdName); break;
460         case 1 : writeEstimatedParameters<SingleQModel>(modelF, imdName); break;
461         case 2 : writeEstimatedParameters<PairedEndModel>(modelF, imdName); break;
462         case 3 : writeEstimatedParameters<PairedEndQModel>(modelF, imdName); break;
463         }
464
465         if (var_opt) {
466                 char varF[STRLEN];
467
468                 sprintf(varF, "%s.var", statName);
469                 FILE *fo = fopen(varF, "w");
470                 general_assert(fo != NULL, "Cannot open " + cstrtos(varF) + "!");
471                 for (int i = 0; i < m; i++) {
472                         int b = gi.spAt(i), e = gi.spAt(i + 1), number_of_isoforms = e - b;
473                         for (int j = b; j < e; j++) {
474                                 fprintf(fo, "%s\t%d\t%.15g\t%.15g\n", refs.getRef(j).getName().c_str(), number_of_isoforms, pme_c[j], pve_c[j]);
475                         }
476                 }
477                 fclose(fo);
478         }
479
480         return 0;
481 }