]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
Version with Gibbs parallelized
[rsem.git] / Gibbs.cpp
1 #include<ctime>
2 #include<cstdio>
3 #include<cstring>
4 #include<cstdlib>
5 #include<cassert>
6 #include<fstream>
7 #include<sstream>
8 #include<vector>
9 #include<set>
10 #include<pthread.h>
11
12 #include "boost/random.hpp"
13
14 #include "utils.h"
15
16 #include "Model.h"
17 #include "SingleModel.h"
18 #include "SingleQModel.h"
19 #include "PairedEndModel.h"
20 #include "PairedEndQModel.h"
21
22 #include "Refs.h"
23 #include "GroupInfo.h"
24
25 using namespace std;
26
27 typedef unsigned int seedType;
28 typedef boost::mt19937 engine_type;
29 typedef boost::gamma_distribution<> distribution_type;
30 typedef boost::variate_generator<engine_type&, distribution_type> generator_type;
31 typedef boost::uniform_01<engine_type> uniform_generator;
32
33 struct Params {
34         int no, nsamples;
35         FILE *fo;
36         boost::mt19937 *rng;
37         double *pme_c, *pme_theta;
38 };
39
40 struct Item {
41         int sid;
42         double conprb;
43
44         Item(int sid, double conprb) {
45                 this->sid = sid;
46                 this->conprb = conprb;
47         }
48 };
49
50 int nThreads;
51
52 int model_type;
53 int m, M, N0, N1, nHits;
54 double totc;
55 int BURNIN, NSAMPLES, GAP;
56 char imdName[STRLEN], statName[STRLEN];
57 char thetaF[STRLEN], ofgF[STRLEN], groupF[STRLEN], refF[STRLEN], modelF[STRLEN];
58 char cvsF[STRLEN];
59
60 Refs refs;
61 GroupInfo gi;
62
63 vector<double> theta, pme_theta, pme_c, eel;
64
65 vector<int> s;
66 vector<Item> hits;
67 vector<int> counts;
68
69 bool quiet;
70
71 engine_type engine(time(NULL));
72 Params *params;
73 pthread_t *threads;
74 pthread_attr_t attr;
75 void *status;
76 int rc;
77
78 void load_data(char* reference_name, char* statName, char* imdName) {
79         ifstream fin;
80         string line;
81         int tmpVal;
82
83         //load reference file
84         sprintf(refF, "%s.seq", reference_name);
85         refs.loadRefs(refF, 1);
86         M = refs.getM();
87
88         //load groupF
89         sprintf(groupF, "%s.grp", reference_name);
90         gi.load(groupF);
91         m = gi.getm();
92
93         //load thetaF
94         sprintf(thetaF, "%s.theta",statName);
95         fin.open(thetaF);
96         if (!fin.is_open()) {
97                 fprintf(stderr, "Cannot open %s!\n", thetaF);
98                 exit(-1);
99         }
100         fin>>tmpVal;
101         if (tmpVal != M + 1) {
102                 fprintf(stderr, "Number of transcripts is not consistent in %s and %s!\n", refF, thetaF);
103                 exit(-1);
104         }
105         theta.clear(); theta.resize(M + 1);
106         for (int i = 0; i <= M; i++) fin>>theta[i];
107         fin.close();
108
109         //load ofgF;
110         sprintf(ofgF, "%s.ofg", imdName);
111         fin.open(ofgF);
112         if (!fin.is_open()) {
113                 fprintf(stderr, "Cannot open %s!\n", ofgF);
114                 exit(-1);
115         }
116         fin>>tmpVal>>N0;
117         if (tmpVal != M) {
118                 fprintf(stderr, "M in %s is not consistent with %s!\n", ofgF, refF);
119                 exit(-1);
120         }
121         getline(fin, line);
122
123         s.clear(); hits.clear();
124         s.push_back(0);
125         while (getline(fin, line)) {
126                 istringstream strin(line);
127                 int sid;
128                 double conprb;
129
130                 while (strin>>sid>>conprb) {
131                         hits.push_back(Item(sid, conprb));
132                 }
133                 s.push_back(hits.size());
134         }
135         fin.close();
136
137         N1 = s.size() - 1;
138         nHits = hits.size();
139
140         totc = N0 + N1 + (M + 1);
141
142         if (verbose) { printf("Loading Data is finished!\n"); }
143 }
144
145 // assign threads
146 void init() {
147         int quotient, left;
148         seedType seed;
149         set<seedType> seedSet;
150         set<seedType>::iterator iter;
151         char outF[STRLEN];
152
153         quotient = NSAMPLES / nThreads;
154         left = NSAMPLES % nThreads;
155
156         sprintf(cvsF, "%s.countvectors", imdName);
157         seedSet.clear();
158         params = new Params[nThreads];
159         threads = new pthread_t[nThreads];
160         for (int i = 0; i < nThreads; i++) {
161                 params[i].no = i;
162
163                 params[i].nsamples = quotient;
164                 if (i < left) params[i].nsamples++;
165
166                 if (i == 0) {
167                         params[i].fo = fopen(cvsF, "w");
168                         fprintf(params[i].fo, "%d %d\n", NSAMPLES, M + 1);
169                 }
170                 else {
171                         sprintf(outF, "%s%d", cvsF, i);
172                         params[i].fo = fopen(outF, "w");
173                 }
174
175                 do {
176                         seed = engine();
177                         iter = seedSet.find(seed);
178                 } while (iter != seedSet.end());
179                 params[i].rng = new engine_type(seed);
180                 seedSet.insert(seed);
181
182                 params[i].pme_c = new double[M + 1];
183                 memset(params[i].pme_c, 0, sizeof(double) * (M + 1));
184                 params[i].pme_theta = new double[M + 1];
185                 memset(params[i].pme_theta, 0, sizeof(double) * (M + 1));
186         }
187
188         /* set thread attribute to be joinable */
189         pthread_attr_init(&attr);
190         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
191
192 }
193
194 void sampleTheta(engine_type& engine, vector<double>& theta) {
195         distribution_type gm(1);
196         generator_type gmg(engine, gm);
197         double denom;
198
199         theta.clear();
200         theta.resize(M + 1);
201         denom = 0.0;
202         for (int i = 0; i <= M; i++) {
203                 theta[i] = gmg();
204                 denom += theta[i];
205         }
206         assert(denom > EPSILON);
207         for (int i = 0; i <= M; i++) theta[i] /= denom;
208 }
209
210 // arr should be cumulative!
211 // interval : [,)
212 // random number should be in [0, arr[len - 1])
213 // If by chance arr[len - 1] == 0.0, one possibility is to sample uniformly from 0...len-1
214 int sample(uniform_generator& rg, vector<double>& arr, int len) {
215   int l, r, mid;
216   double prb = rg() * arr[len - 1];
217
218   l = 0; r = len - 1;
219   while (l <= r) {
220     mid = (l + r) / 2;
221     if (arr[mid] <= prb) l = mid + 1;
222     else r = mid - 1;
223   }
224
225   if (l >= len) { printf("%d %lf %lf\n", len, arr[len - 1], prb); }
226   assert(l < len);
227
228   return l;
229 }
230
231 void writeForSee(FILE* fo, vector<int>& counts) {
232   for (int i = 1; i < M; i++) fprintf(fo, "%d ", counts[i]);
233   fprintf(fo, "%d\n", counts[M]);
234 }
235
236 void writeCountVector(FILE* fo, vector<int>& counts) {
237         for (int i = 0; i < M; i++) {
238                 fprintf(fo, "%d ", counts[i]);
239         }
240         fprintf(fo, "%d\n", counts[M]);
241 }
242
243 void* Gibbs(void* arg) {
244         int len, fr, to;
245         int CHAINLEN;
246         Params *params = (Params*)arg;
247
248         vector<double> theta;
249         vector<int> z, counts;
250         vector<double> arr;
251
252         uniform_generator rg(*params->rng);
253
254         sampleTheta(*params->rng, theta);
255
256
257         // generate initial state
258         arr.clear();
259
260         z.clear();
261         z.resize(N1);
262
263         counts.clear();
264         counts.resize(M + 1, 1); // 1 pseudo count
265         counts[0] += N0;
266
267         for (int i = 0; i < N1; i++) {
268                 fr = s[i]; to = s[i + 1];
269                 len = to - fr;
270                 arr.resize(len);
271                 for (int j = fr; j < to; j++) {
272                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
273                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
274                 }
275                 z[i] = hits[fr + sample(rg, arr, len)].sid;
276                 ++counts[z[i]];
277         }
278
279         // Gibbs sampling
280     char seeF[STRLEN];
281         sprintf(seeF, "%s.see", statName);
282         FILE *fsee = fopen(seeF, "w");
283
284         CHAINLEN = 1 + (params->nsamples - 1) * GAP;
285         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN  ; ROUND++) {
286
287                 for (int i = 0; i < N1; i++) {
288                         --counts[z[i]];
289                         fr = s[i]; to = s[i + 1]; len = to - fr;
290                         arr.resize(len);
291                         for (int j = fr; j < to; j++) {
292                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
293                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
294                         }
295                         z[i] = hits[fr + sample(rg, arr, len)].sid;
296                         ++counts[z[i]];
297                 }
298
299                 writeForSee(fsee, counts);
300
301                 if (ROUND > BURNIN) {
302                         if ((ROUND - BURNIN - 1) % GAP == 0) writeCountVector(params->fo, counts);
303                         for (int i = 0; i <= M; i++) {
304                                 params->pme_c[i] += counts[i] - 1;
305                                 params->pme_theta[i] += counts[i] / totc;
306                         }
307                 }
308
309                 if (verbose & ROUND % 10 == 0) { printf("Thread %d, ROUND %d is finished!\n", params->no, ROUND); }
310         }
311         fclose(fsee);
312
313         return NULL;
314 }
315
316 void release() {
317         char inpF[STRLEN], command[STRLEN];
318         string line;
319
320         /* destroy attribute */
321         pthread_attr_destroy(&attr);
322         delete[] threads;
323
324         pme_c.clear(); pme_c.resize(M + 1, 0.0);
325         pme_theta.clear(); pme_theta.resize(M + 1, 0.0);
326         for (int i = 0; i < nThreads; i++) {
327                 fclose(params[i].fo);
328                 delete params[i].rng;
329                 for (int j = 0; j <= M; j++) {
330                         pme_c[j] += params[i].pme_c[j];
331                         pme_theta[j] += params[i].pme_theta[j];
332                 }
333                 delete[] params[i].pme_c;
334                 delete[] params[i].pme_theta;
335         }
336         delete[] params;
337
338         for (int i = 0; i <= M; i++) {
339                 pme_c[i] /= NSAMPLES;
340                 pme_theta[i] /= NSAMPLES;
341         }
342
343         // combine files
344         FILE *fo = fopen(cvsF, "a");
345         for (int i = 1; i < nThreads; i++) {
346                 sprintf(inpF, "%s%d", cvsF, i);
347                 ifstream fin(inpF);
348                 while (getline(fin, line)) {
349                         fprintf(fo, "%s\n", line.c_str());
350                 }
351                 fin.close();
352                 sprintf(command, "rm -f %s", inpF);
353                 int status = system(command);
354                 if (status != 0) { fprintf(stderr, "Fail to delete file %s!\n", inpF); exit(-1); }
355         }
356         fclose(fo);
357 }
358
359 template<class ModelType>
360 void calcExpectedEffectiveLengths(ModelType& model) {
361   int lb, ub, span;
362   double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = sigma_{j=1}^{i}pdf[i]*(lb+i)
363   
364   model.getGLD().copyTo(pdf, cdf, lb, ub, span);
365   clen = new double[span + 1];
366   clen[0] = 0.0;
367   for (int i = 1; i <= span; i++) {
368     clen[i] = clen[i - 1] + pdf[i] * (lb + i);
369   }
370
371   eel.clear();
372   eel.resize(M + 1, 0.0);
373   for (int i = 1; i <= M; i++) {
374     int totLen = refs.getRef(i).getTotLen();
375     int fullLen = refs.getRef(i).getFullLen();
376     int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
377     int pos2 = max(min(totLen, ub) - lb, 0);
378
379     if (pos2 == 0) { eel[i] = 0.0; continue; }
380     
381     eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
382     assert(eel[i] >= 0);
383     if (eel[i] < MINEEL) { eel[i] = 0.0; }
384   }
385   
386   delete[] pdf;
387   delete[] cdf;
388   delete[] clen;
389 }
390
391 template<class ModelType>
392 void writeEstimatedParameters(char* modelF, char* imdName) {
393         ModelType model;
394         double denom;
395         char outF[STRLEN];
396         FILE *fo;
397
398         model.read(modelF);
399
400         calcExpectedEffectiveLengths<ModelType>(model);
401
402         denom = pme_theta[0];
403         for (int i = 1; i <= M; i++)
404           if (eel[i] < EPSILON) pme_theta[i] = 0.0;
405           else denom += pme_theta[i];
406         if (denom <= 0) { fprintf(stderr, "No Expected Effective Length is no less than %.6g?!\n", MINEEL); exit(-1); }
407         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
408
409         denom = 0.0;
410         double *mw = model.getMW();
411         for (int i = 0; i <= M; i++) {
412           pme_theta[i] = (mw[i] < EPSILON ? 0.0 : pme_theta[i] / mw[i]);
413           denom += pme_theta[i];
414         }
415         assert(denom >= EPSILON);
416         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
417
418         //calculate tau values
419         double *tau = new double[M + 1];
420         memset(tau, 0, sizeof(double) * (M + 1));
421
422         denom = 0.0;
423         for (int i = 1; i <= M; i++) 
424           if (eel[i] > EPSILON) {
425             tau[i] = pme_theta[i] / eel[i];
426             denom += tau[i];
427           }
428         if (denom <= 0) { fprintf(stderr, "No alignable reads?!\n"); exit(-1); }
429         //assert(denom > 0);
430         for (int i = 1; i <= M; i++) {
431                 tau[i] /= denom;
432         }
433
434         //isoform level results
435         sprintf(outF, "%s.iso_res", imdName);
436         fo = fopen(outF, "a");
437         if (fo == NULL) { fprintf(stderr, "Cannot open %s!\n", outF); exit(-1); }
438         for (int i = 1; i <= M; i++)
439                 fprintf(fo, "%.2f%c", pme_c[i], (i < M ? '\t' : '\n'));
440         for (int i = 1; i <= M; i++)
441                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
442         fclose(fo);
443
444         //gene level results
445         sprintf(outF, "%s.gene_res", imdName);
446         fo = fopen(outF, "a");
447         if (fo == NULL) { fprintf(stderr, "Cannot open %s!\n", outF); exit(-1); }
448         for (int i = 0; i < m; i++) {
449                 double sumC = 0.0; //  sum of pme counts
450                 int b = gi.spAt(i), e = gi.spAt(i + 1);
451                 for (int j = b; j < e; j++) {
452                         sumC += pme_c[j];
453                 }
454                 fprintf(fo, "%.15g%c", sumC, (i < m - 1 ? '\t' : '\n'));
455         }
456         for (int i = 0; i < m; i++) {
457                 double sumT = 0.0; //  sum of tau values
458                 int b = gi.spAt(i), e = gi.spAt(i + 1);
459                 for (int j = b; j < e; j++) {
460                         sumT += tau[j];
461                 }
462                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
463         }
464         fclose(fo);
465
466         delete[] tau;
467
468         if (verbose) { printf("Gibbs based expression values are written!\n"); }
469 }
470
471
472 int main(int argc, char* argv[]) {
473         if (argc < 7) {
474                 printf("Usage: rsem-run-gibbs reference_name sample_name sampleToken BURNIN NSAMPLES GAP [-p #Threads] [-q]\n");
475                 exit(-1);
476         }
477
478         BURNIN = atoi(argv[4]);
479         NSAMPLES = atoi(argv[5]);
480         GAP = atoi(argv[6]);
481         sprintf(imdName, "%s.temp/%s", argv[2], argv[3]);
482         sprintf(statName, "%s.stat/%s", argv[2], argv[3]);
483         load_data(argv[1], statName, imdName);
484
485         nThreads = 1;
486         quiet = false;
487         for (int i = 7; i < argc; i++) {
488                 if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
489                 if (!strcmp(argv[i], "-q")) quiet = true;
490         }
491         verbose = !quiet;
492
493         if (nThreads > NSAMPLES) {
494                 nThreads = NSAMPLES;
495                 printf("Warning: Number of samples is less than number of threads! Change the number of threads to %d!\n", nThreads);
496         }
497
498         if (verbose) printf("Gibbs started!\n");
499         init();
500         for (int i = 0; i < nThreads; i++) {
501                 rc = pthread_create(&threads[i], &attr, Gibbs, (void*)(&params[i]));
502                 if (rc != 0) { fprintf(stderr, "Cannot create thread %d! (numbered from 0)\n", i); }
503         }
504         for (int i = 0; i < nThreads; i++) {
505                 rc = pthread_join(threads[i], &status);
506                 if (rc != 0) { fprintf(stderr, "Cannot join thread %d! (numbered from 0)\n", i); }
507         }
508         release();
509         if (verbose) printf("Gibbs finished!\n");
510
511         sprintf(modelF, "%s.model", statName);
512         FILE *fi = fopen(modelF, "r");
513         if (fi == NULL) { fprintf(stderr, "Cannot open %s!\n", modelF); exit(-1); }
514         fscanf(fi, "%d", &model_type);
515         fclose(fi);
516
517         switch(model_type) {
518         case 0 : writeEstimatedParameters<SingleModel>(modelF, imdName); break;
519         case 1 : writeEstimatedParameters<SingleQModel>(modelF, imdName); break;
520         case 2 : writeEstimatedParameters<PairedEndModel>(modelF, imdName); break;
521         case 3 : writeEstimatedParameters<PairedEndQModel>(modelF, imdName); break;
522         }
523
524         return 0;
525 }