]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
Modified WHAT_IS_NEW
[rsem.git] / Gibbs.cpp
1 #include<cstdio>
2 #include<cstring>
3 #include<cstdlib>
4 #include<cassert>
5 #include<fstream>
6 #include<sstream>
7 #include<vector>
8 #include<pthread.h>
9
10 #include "utils.h"
11 #include "my_assert.h"
12 #include "sampling.h"
13
14 #include "Model.h"
15 #include "SingleModel.h"
16 #include "SingleQModel.h"
17 #include "PairedEndModel.h"
18 #include "PairedEndQModel.h"
19
20 #include "Refs.h"
21
22 #include "GroupInfo.h"
23 #include "WriteResults.h"
24
25 using namespace std;
26
27 struct Params {
28   int no, nsamples;
29   FILE *fo;
30   engine_type *engine;
31   double *pme_c, *pve_c; //posterior mean and variance vectors on counts
32   double *pme_tpm, *pme_fpkm;
33   
34   double *pve_c_genes, *pve_c_trans;
35 };
36
// One alignment of a read: the id of the sequence it hits and the
// conditional probability read from the .ofg file for that hit.
struct Item {
	int sid;
	double conprb;

	Item(int sid, double conprb) : sid(sid), conprb(conprb) {}
};
46
47 int nThreads;
48
49 int model_type;
50 int M;
51 READ_INT_TYPE N0, N1;
52 HIT_INT_TYPE nHits;
53 double totc;
54 int BURNIN, NSAMPLES, GAP;
55 char refName[STRLEN], imdName[STRLEN], statName[STRLEN];
56 char thetaF[STRLEN], ofgF[STRLEN], refF[STRLEN], modelF[STRLEN];
57 char cvsF[STRLEN];
58
59 Refs refs;
60
61 vector<HIT_INT_TYPE> s;
62 vector<Item> hits;
63
64 vector<double> eel;
65 double *mw;
66
67 vector<int> pseudo_counts;
68 vector<double> pme_c, pve_c; //global posterior mean and variance vectors on counts
69 vector<double> pme_tpm, pme_fpkm;
70
71 bool quiet;
72
73 Params *paramsArray;
74 pthread_t *threads;
75 pthread_attr_t attr;
76 int rc;
77
78 bool hasSeed;
79 seedType seed;
80
81 int m;
82 char groupF[STRLEN];
83 GroupInfo gi;
84
85 bool alleleS;
86 int m_trans;
87 GroupInfo gt, ta;
88 vector<double> pve_c_genes, pve_c_trans;
89
90 void load_data(char* refName, char* statName, char* imdName) {
91         ifstream fin;
92         string line;
93         int tmpVal;
94
95         //load reference file
96         sprintf(refF, "%s.seq", refName);
97         refs.loadRefs(refF, 1);
98         M = refs.getM();
99
100         //load ofgF;
101         sprintf(ofgF, "%s.ofg", imdName);
102         fin.open(ofgF);
103         general_assert(fin.is_open(), "Cannot open " + cstrtos(ofgF) + "!");
104         fin>>tmpVal>>N0;
105         general_assert(tmpVal == M, "M in " + cstrtos(ofgF) + " is not consistent with " + cstrtos(refF) + "!");
106         getline(fin, line);
107
108         s.clear(); hits.clear();
109         s.push_back(0);
110         while (getline(fin, line)) {
111                 istringstream strin(line);
112                 int sid;
113                 double conprb;
114
115                 while (strin>>sid>>conprb) {
116                         hits.push_back(Item(sid, conprb));
117                 }
118                 s.push_back(hits.size());
119         }
120         fin.close();
121
122         N1 = s.size() - 1;
123         nHits = hits.size();
124
125         totc = N0 + N1 + (M + 1);
126
127         if (verbose) { printf("Loading data is finished!\n"); }
128 }
129
130 void load_group_info(char* refName) {
131   // Load group info
132   sprintf(groupF, "%s.grp", refName);
133   gi.load(groupF);
134   m = gi.getm();
135   
136   alleleS = isAlleleSpecific(refName, &gt, &ta); // if allele-specific 
137   m_trans = (alleleS ? ta.getm() : 0);
138
139   if (verbose) { printf("Loading group information is finished!\n"); }
140 }
141
142 // Load imdName.omit and initialize the pseudo count vector.
143 void load_omit_info(const char* imdName) {
144   char omitF[STRLEN];
145   sprintf(omitF, "%s.omit", imdName);
146   FILE *fi = fopen(omitF, "r");
147   pseudo_counts.assign(M + 1, 1);
148   int tid;
149   while (fscanf(fi, "%d", &tid) == 1) pseudo_counts[tid] = 0;
150   fclose(fi);
151 }
152
153 template<class ModelType>
154 void init_model_related(char* modelF) {
155         ModelType model;
156         model.read(modelF);
157
158         calcExpectedEffectiveLengths<ModelType>(M, refs, model, eel);
159         memcpy(mw, model.getMW(), sizeof(double) * (M + 1)); // otherwise, after exiting this procedure, mw becomes undefined
160 }
161
162 // assign threads
163 void init() {
164         int quotient, left;
165         char outF[STRLEN];
166
167         quotient = NSAMPLES / nThreads;
168         left = NSAMPLES % nThreads;
169
170         sprintf(cvsF, "%s.countvectors", imdName);
171         paramsArray = new Params[nThreads];
172         threads = new pthread_t[nThreads];
173
174         hasSeed ? engineFactory::init(seed) : engineFactory::init();
175         for (int i = 0; i < nThreads; i++) {
176                 paramsArray[i].no = i;
177
178                 paramsArray[i].nsamples = quotient;
179                 if (i < left) paramsArray[i].nsamples++;
180
181                 sprintf(outF, "%s%d", cvsF, i);
182                 paramsArray[i].fo = fopen(outF, "w");
183
184                 paramsArray[i].engine = engineFactory::new_engine();
185                 paramsArray[i].pme_c = new double[M + 1];
186                 memset(paramsArray[i].pme_c, 0, sizeof(double) * (M + 1));
187                 paramsArray[i].pve_c = new double[M + 1];
188                 memset(paramsArray[i].pve_c, 0, sizeof(double) * (M + 1));
189                 paramsArray[i].pme_tpm = new double[M + 1];
190                 memset(paramsArray[i].pme_tpm, 0, sizeof(double) * (M + 1));
191                 paramsArray[i].pme_fpkm = new double[M + 1];
192                 memset(paramsArray[i].pme_fpkm, 0, sizeof(double) * (M + 1));
193
194                 paramsArray[i].pve_c_genes = new double[m];
195                 memset(paramsArray[i].pve_c_genes, 0, sizeof(double) * m);
196                 
197                 paramsArray[i].pve_c_trans = NULL;
198                 if (alleleS) {
199                   paramsArray[i].pve_c_trans = new double[m_trans];
200                   memset(paramsArray[i].pve_c_trans, 0, sizeof(double) * m_trans);
201                 }
202         }
203         engineFactory::finish();
204
205         /* set thread attribute to be joinable */
206         pthread_attr_init(&attr);
207         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
208
209         if (verbose) { printf("Initialization finished!\n"); }
210 }
211
212 //sample theta from Dir(1)
213 void sampleTheta(engine_type& engine, vector<double>& theta) {
214         gamma_dist gm(1);
215         gamma_generator gmg(engine, gm);
216         double denom;
217
218         theta.assign(M + 1, 0);
219         denom = 0.0;
220         for (int i = 0; i <= M; i++) {
221                 theta[i] = (pseudo_counts[i] > 0 ? gmg() : 0.0);
222                 denom += theta[i];
223         }
224         assert(denom > EPSILON);
225         for (int i = 0; i <= M; i++) theta[i] /= denom;
226 }
227
228 void writeCountVector(FILE* fo, vector<int>& counts) {
229         for (int i = 0; i < M; i++) {
230                 fprintf(fo, "%d ", counts[i]);
231         }
232         fprintf(fo, "%d\n", counts[M]);
233 }
234
235 void* Gibbs(void* arg) {
236         int CHAINLEN;
237         HIT_INT_TYPE len, fr, to;
238         Params *params = (Params*)arg;
239
240         vector<double> theta, tpm, fpkm;
241         vector<int> z, counts(pseudo_counts);
242         vector<double> arr;
243
244         uniform_01_generator rg(*params->engine, uniform_01_dist());
245
246         // generate initial state
247         sampleTheta(*params->engine, theta);
248
249         z.assign(N1, 0);
250         counts[0] += N0;
251
252         for (READ_INT_TYPE i = 0; i < N1; i++) {
253                 fr = s[i]; to = s[i + 1];
254                 len = to - fr;
255                 arr.assign(len, 0);
256                 for (HIT_INT_TYPE j = fr; j < to; j++) {
257                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
258                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
259                 }
260                 z[i] = hits[fr + sample(rg, arr, len)].sid;
261                 ++counts[z[i]];
262         }
263
264         // Gibbs sampling
265         CHAINLEN = 1 + (params->nsamples - 1) * GAP;
266         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
267
268                 for (READ_INT_TYPE i = 0; i < N1; i++) {
269                         --counts[z[i]];
270                         fr = s[i]; to = s[i + 1]; len = to - fr;
271                         arr.assign(len, 0);
272                         for (HIT_INT_TYPE j = fr; j < to; j++) {
273                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
274                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
275                         }
276                         z[i] = hits[fr + sample(rg, arr, len)].sid;
277                         ++counts[z[i]];
278                 }
279
280                 if (ROUND > BURNIN) {
281                         if ((ROUND - BURNIN - 1) % GAP == 0) {
282                                 writeCountVector(params->fo, counts);
283                                 for (int i = 0; i <= M; i++) theta[i] = counts[i] / totc;
284                                 polishTheta(M, theta, eel, mw);
285                                 calcExpressionValues(M, theta, eel, tpm, fpkm);
286                                 for (int i = 0; i <= M; i++) {
287                                         params->pme_c[i] += counts[i] - pseudo_counts[i];
288                                         params->pve_c[i] += double(counts[i] - pseudo_counts[i]) * (counts[i] - pseudo_counts[i]);
289                                         params->pme_tpm[i] += tpm[i];
290                                         params->pme_fpkm[i] += fpkm[i];
291                                 }
292
293                                 for (int i = 0; i < m; i++) {
294                                   int b = gi.spAt(i), e = gi.spAt(i + 1);
295                                   double count = 0.0;
296                                   for (int j = b; j < e; j++) count += counts[j] - pseudo_counts[j];
297                                   params->pve_c_genes[i] += count * count;
298                                 }
299
300                                 if (alleleS)
301                                   for (int i = 0; i < m_trans; i++) {
302                                     int b = ta.spAt(i), e = ta.spAt(i + 1);
303                                     double count = 0.0;
304                                     for (int j = b; j < e; j++) count += counts[j] - pseudo_counts[j];
305                                     params->pve_c_trans[i] += count * count;
306                                   }
307                         }
308                 }
309
310                 if (verbose && ROUND % 100 == 0) { printf("Thread %d, ROUND %d is finished!\n", params->no, ROUND); }
311         }
312
313         return NULL;
314 }
315
316 void release() {
317 //      char inpF[STRLEN], command[STRLEN];
318         string line;
319
320         /* destroy attribute */
321         pthread_attr_destroy(&attr);
322         delete[] threads;
323
324         pme_c.assign(M + 1, 0);
325         pve_c.assign(M + 1, 0);
326         pme_tpm.assign(M + 1, 0);
327         pme_fpkm.assign(M + 1, 0);
328
329         pve_c_genes.assign(m, 0);
330         pve_c_trans.clear();
331         if (alleleS) pve_c_trans.assign(m_trans, 0);
332
333         for (int i = 0; i < nThreads; i++) {
334                 fclose(paramsArray[i].fo);
335                 delete paramsArray[i].engine;
336                 for (int j = 0; j <= M; j++) {
337                         pme_c[j] += paramsArray[i].pme_c[j];
338                         pve_c[j] += paramsArray[i].pve_c[j];
339                         pme_tpm[j] += paramsArray[i].pme_tpm[j];
340                         pme_fpkm[j] += paramsArray[i].pme_fpkm[j];
341                 }
342
343                 for (int j = 0; j < m; j++) 
344                   pve_c_genes[j] += paramsArray[i].pve_c_genes[j];
345                 
346                 if (alleleS) 
347                   for (int j = 0; j < m_trans; j++) 
348                     pve_c_trans[j] += paramsArray[i].pve_c_trans[j];
349
350                 delete[] paramsArray[i].pme_c;
351                 delete[] paramsArray[i].pve_c;
352                 delete[] paramsArray[i].pme_tpm;
353                 delete[] paramsArray[i].pme_fpkm;
354
355                 delete[] paramsArray[i].pve_c_genes;
356                 if (alleleS) delete[] paramsArray[i].pve_c_trans;
357         }
358         delete[] paramsArray;
359
360         for (int i = 0; i <= M; i++) {
361                 pme_c[i] /= NSAMPLES;
362                 pve_c[i] = (pve_c[i] - double(NSAMPLES) * pme_c[i] * pme_c[i]) / double(NSAMPLES - 1);
363                 if (pve_c[i] < 0.0) pve_c[i] = 0.0;
364                 pme_tpm[i] /= NSAMPLES;
365                 pme_fpkm[i] /= NSAMPLES;
366         }
367
368         for (int i = 0; i < m; i++) {
369           int b = gi.spAt(i), e = gi.spAt(i + 1);
370           double pme_c_gene = 0.0;
371           for (int j = b; j < e; j++) pme_c_gene += pme_c[j];
372           pve_c_genes[i] = (pve_c_genes[i] - double(NSAMPLES) * pme_c_gene * pme_c_gene) / double(NSAMPLES - 1);
373           if (pve_c_genes[i] < 0.0) pve_c_genes[i] = 0.0;
374         }
375
376         if (alleleS) 
377           for (int i = 0; i < m_trans; i++) {
378             int b = ta.spAt(i), e = ta.spAt(i + 1);
379             double pme_c_tran = 0.0;
380             for (int j = b; j < e; j++) pme_c_tran += pme_c[j];
381             pve_c_trans[i] = (pve_c_trans[i] - double(NSAMPLES) * pme_c_tran * pme_c_tran) / double(NSAMPLES - 1);
382             if (pve_c_trans[i] < 0.0) pve_c_trans[i] = 0.0;
383           }
384 }
385
386 int main(int argc, char* argv[]) {
387         if (argc < 7) {
388                 printf("Usage: rsem-run-gibbs reference_name imdName statName BURNIN NSAMPLES GAP [-p #Threads] [--seed seed] [-q]\n");
389                 exit(-1);
390         }
391
392         strcpy(refName, argv[1]);
393         strcpy(imdName, argv[2]);
394         strcpy(statName, argv[3]);
395
396         BURNIN = atoi(argv[4]);
397         NSAMPLES = atoi(argv[5]);
398         GAP = atoi(argv[6]);
399
400         nThreads = 1;
401         hasSeed = false;
402         quiet = false;
403
404         for (int i = 7; i < argc; i++) {
405                 if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
406                 if (!strcmp(argv[i], "--seed")) {
407                   hasSeed = true;
408                   int len = strlen(argv[i + 1]);
409                   seed = 0;
410                   for (int k = 0; k < len; k++) seed = seed * 10 + (argv[i + 1][k] - '0');
411                 }
412                 if (!strcmp(argv[i], "-q")) quiet = true;
413         }
414         verbose = !quiet;
415
416         assert(NSAMPLES > 1); // Otherwise, we cannot calculate posterior variance
417
418         if (nThreads > NSAMPLES) {
419                 nThreads = NSAMPLES;
420                 printf("Warning: Number of samples is less than number of threads! Change the number of threads to %d!\n", nThreads);
421         }
422
423         load_data(refName, statName, imdName);
424         load_group_info(refName);
425         load_omit_info(imdName);
426
427         sprintf(modelF, "%s.model", statName);
428         FILE *fi = fopen(modelF, "r");
429         general_assert(fi != NULL, "Cannot open " + cstrtos(modelF) + "!");
430         assert(fscanf(fi, "%d", &model_type) == 1);
431         fclose(fi);
432
433         mw = new double[M + 1]; // make an extra copy
434
435         switch(model_type) {
436         case 0 : init_model_related<SingleModel>(modelF); break;
437         case 1 : init_model_related<SingleQModel>(modelF); break;
438         case 2 : init_model_related<PairedEndModel>(modelF); break;
439         case 3 : init_model_related<PairedEndQModel>(modelF); break;
440         }
441
442         if (verbose) printf("Gibbs started!\n");
443
444         init();
445         for (int i = 0; i < nThreads; i++) {
446                 rc = pthread_create(&threads[i], &attr, Gibbs, (void*)(&paramsArray[i]));
447                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0)!");
448         }
449         for (int i = 0; i < nThreads; i++) {
450                 rc = pthread_join(threads[i], NULL);
451                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0)!");
452         }
453         release();
454
455         if (verbose) printf("Gibbs finished!\n");
456         
457         writeResultsGibbs(M, m, m_trans, gi, gt, ta, alleleS, imdName, pme_c, pme_fpkm, pme_tpm, pve_c, pve_c_genes, pve_c_trans);
458
459         delete mw; // delete the copy
460
461         return 0;
462 }