]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
Modified the use of uniform_01
[rsem.git] / Gibbs.cpp
1 #include<cstdio>
2 #include<cstring>
3 #include<cstdlib>
4 #include<cassert>
5 #include<fstream>
6 #include<sstream>
7 #include<vector>
8 #include<pthread.h>
9
10 #include "utils.h"
11 #include "my_assert.h"
12 #include "sampling.h"
13
14 #include "Model.h"
15 #include "SingleModel.h"
16 #include "SingleQModel.h"
17 #include "PairedEndModel.h"
18 #include "PairedEndQModel.h"
19
20 #include "Refs.h"
21
22 #include "GroupInfo.h"
23 #include "WriteResults.h"
24
25 using namespace std;
26
27 struct Params {
28   int no, nsamples;
29   FILE *fo;
30   engine_type *engine;
31   double *pme_c, *pve_c; //posterior mean and variance vectors on counts
32   double *pme_tpm, *pme_fpkm;
33   
34   double *pve_c_genes, *pve_c_trans;
35 };
36
// One candidate alignment of a read: the id of the sequence it hits and the
// conditional probability stored with it in the .ofg file.
struct Item {
        int sid;        // sequence id; used as an index into theta / counts
        double conprb;  // weight multiplied with theta[sid] or counts[sid] when sampling

        Item(int sid, double conprb) : sid(sid), conprb(conprb) {}
};
46
47 int nThreads;
48
49 int model_type;
50 int M;
51 READ_INT_TYPE N0, N1;
52 HIT_INT_TYPE nHits;
53 double totc;
54 int BURNIN, NSAMPLES, GAP;
55 char refName[STRLEN], imdName[STRLEN], statName[STRLEN];
56 char thetaF[STRLEN], ofgF[STRLEN], refF[STRLEN], modelF[STRLEN];
57 char cvsF[STRLEN];
58
59 Refs refs;
60
61 vector<HIT_INT_TYPE> s;
62 vector<Item> hits;
63
64 vector<double> eel;
65 double *mw;
66
67 vector<double> pme_c, pve_c; //global posterior mean and variance vectors on counts
68 vector<double> pme_tpm, pme_fpkm;
69
70 bool quiet;
71
72 Params *paramsArray;
73 pthread_t *threads;
74 pthread_attr_t attr;
75 int rc;
76
77 bool hasSeed;
78 seedType seed;
79
80 int m;
81 char groupF[STRLEN];
82 GroupInfo gi;
83
84 bool alleleS;
85 int m_trans;
86 GroupInfo gt, ta;
87 vector<double> pve_c_genes, pve_c_trans;
88
89 void load_group_info(char* refName) {
90   // Load group info
91   sprintf(groupF, "%s.grp", refName);
92   gi.load(groupF);
93   m = gi.getm();
94   
95   alleleS = isAlleleSpecific(refName, &gt, &ta); // if allele-specific 
96   m_trans = (alleleS ? ta.getm() : 0);
97
98   if (verbose) { printf("Loading group information is finished!\n"); }
99 }
100
101 void load_data(char* refName, char* statName, char* imdName) {
102         ifstream fin;
103         string line;
104         int tmpVal;
105
106         //load reference file
107         sprintf(refF, "%s.seq", refName);
108         refs.loadRefs(refF, 1);
109         M = refs.getM();
110
111         //load ofgF;
112         sprintf(ofgF, "%s.ofg", imdName);
113         fin.open(ofgF);
114         general_assert(fin.is_open(), "Cannot open " + cstrtos(ofgF) + "!");
115         fin>>tmpVal>>N0;
116         general_assert(tmpVal == M, "M in " + cstrtos(ofgF) + " is not consistent with " + cstrtos(refF) + "!");
117         getline(fin, line);
118
119         s.clear(); hits.clear();
120         s.push_back(0);
121         while (getline(fin, line)) {
122                 istringstream strin(line);
123                 int sid;
124                 double conprb;
125
126                 while (strin>>sid>>conprb) {
127                         hits.push_back(Item(sid, conprb));
128                 }
129                 s.push_back(hits.size());
130         }
131         fin.close();
132
133         N1 = s.size() - 1;
134         nHits = hits.size();
135
136         totc = N0 + N1 + (M + 1);
137
138         if (verbose) { printf("Loading data is finished!\n"); }
139 }
140
141 template<class ModelType>
142 void init_model_related(char* modelF) {
143         ModelType model;
144         model.read(modelF);
145
146         calcExpectedEffectiveLengths<ModelType>(M, refs, model, eel);
147         memcpy(mw, model.getMW(), sizeof(double) * (M + 1)); // otherwise, after exiting this procedure, mw becomes undefined
148 }
149
150 // assign threads
151 void init() {
152         int quotient, left;
153         char outF[STRLEN];
154
155         quotient = NSAMPLES / nThreads;
156         left = NSAMPLES % nThreads;
157
158         sprintf(cvsF, "%s.countvectors", imdName);
159         paramsArray = new Params[nThreads];
160         threads = new pthread_t[nThreads];
161
162         hasSeed ? engineFactory::init(seed) : engineFactory::init();
163         for (int i = 0; i < nThreads; i++) {
164                 paramsArray[i].no = i;
165
166                 paramsArray[i].nsamples = quotient;
167                 if (i < left) paramsArray[i].nsamples++;
168
169                 sprintf(outF, "%s%d", cvsF, i);
170                 paramsArray[i].fo = fopen(outF, "w");
171
172                 paramsArray[i].engine = engineFactory::new_engine();
173                 paramsArray[i].pme_c = new double[M + 1];
174                 memset(paramsArray[i].pme_c, 0, sizeof(double) * (M + 1));
175                 paramsArray[i].pve_c = new double[M + 1];
176                 memset(paramsArray[i].pve_c, 0, sizeof(double) * (M + 1));
177                 paramsArray[i].pme_tpm = new double[M + 1];
178                 memset(paramsArray[i].pme_tpm, 0, sizeof(double) * (M + 1));
179                 paramsArray[i].pme_fpkm = new double[M + 1];
180                 memset(paramsArray[i].pme_fpkm, 0, sizeof(double) * (M + 1));
181
182                 paramsArray[i].pve_c_genes = new double[m];
183                 memset(paramsArray[i].pve_c_genes, 0, sizeof(double) * m);
184                 
185                 paramsArray[i].pve_c_trans = NULL;
186                 if (alleleS) {
187                   paramsArray[i].pve_c_trans = new double[m_trans];
188                   memset(paramsArray[i].pve_c_trans, 0, sizeof(double) * m_trans);
189                 }
190         }
191         engineFactory::finish();
192
193         /* set thread attribute to be joinable */
194         pthread_attr_init(&attr);
195         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
196
197         if (verbose) { printf("Initialization finished!\n"); }
198 }
199
200 //sample theta from Dir(1)
201 void sampleTheta(engine_type& engine, vector<double>& theta) {
202         gamma_dist gm(1);
203         gamma_generator gmg(engine, gm);
204         double denom;
205
206         theta.assign(M + 1, 0);
207         denom = 0.0;
208         for (int i = 0; i <= M; i++) {
209                 theta[i] = gmg();
210                 denom += theta[i];
211         }
212         assert(denom > EPSILON);
213         for (int i = 0; i <= M; i++) theta[i] /= denom;
214 }
215
216 void writeCountVector(FILE* fo, vector<int>& counts) {
217         for (int i = 0; i < M; i++) {
218                 fprintf(fo, "%d ", counts[i]);
219         }
220         fprintf(fo, "%d\n", counts[M]);
221 }
222
223 void* Gibbs(void* arg) {
224         int CHAINLEN;
225         HIT_INT_TYPE len, fr, to;
226         Params *params = (Params*)arg;
227
228         vector<double> theta, tpm, fpkm;
229         vector<int> z, counts;
230         vector<double> arr;
231
232         uniform_01_generator rg(*params->engine, uniform_01_dist());
233
234         // generate initial state
235         sampleTheta(*params->engine, theta);
236
237         z.assign(N1, 0);
238
239         counts.assign(M + 1, 1); // 1 pseudo count
240         counts[0] += N0;
241
242         for (READ_INT_TYPE i = 0; i < N1; i++) {
243                 fr = s[i]; to = s[i + 1];
244                 len = to - fr;
245                 arr.assign(len, 0);
246                 for (HIT_INT_TYPE j = fr; j < to; j++) {
247                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
248                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
249                 }
250                 z[i] = hits[fr + sample(rg, arr, len)].sid;
251                 ++counts[z[i]];
252         }
253
254         // Gibbs sampling
255         CHAINLEN = 1 + (params->nsamples - 1) * GAP;
256         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
257
258                 for (READ_INT_TYPE i = 0; i < N1; i++) {
259                         --counts[z[i]];
260                         fr = s[i]; to = s[i + 1]; len = to - fr;
261                         arr.assign(len, 0);
262                         for (HIT_INT_TYPE j = fr; j < to; j++) {
263                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
264                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
265                         }
266                         z[i] = hits[fr + sample(rg, arr, len)].sid;
267                         ++counts[z[i]];
268                 }
269
270                 if (ROUND > BURNIN) {
271                         if ((ROUND - BURNIN - 1) % GAP == 0) {
272                                 writeCountVector(params->fo, counts);
273                                 for (int i = 0; i <= M; i++) theta[i] = counts[i] / totc;
274                                 polishTheta(M, theta, eel, mw);
275                                 calcExpressionValues(M, theta, eel, tpm, fpkm);
276                                 for (int i = 0; i <= M; i++) {
277                                         params->pme_c[i] += counts[i] - 1;
278                                         params->pve_c[i] += double(counts[i] - 1) * (counts[i] - 1);
279                                         params->pme_tpm[i] += tpm[i];
280                                         params->pme_fpkm[i] += fpkm[i];
281                                 }
282
283                                 for (int i = 0; i < m; i++) {
284                                   int b = gi.spAt(i), e = gi.spAt(i + 1);
285                                   double count = 0.0;
286                                   for (int j = b; j < e; j++) count += counts[j] - 1;
287                                   params->pve_c_genes[i] += count * count;
288                                 }
289
290                                 if (alleleS)
291                                   for (int i = 0; i < m_trans; i++) {
292                                     int b = ta.spAt(i), e = ta.spAt(i + 1);
293                                     double count = 0.0;
294                                     for (int j = b; j < e; j++) count += counts[j] - 1;
295                                     params->pve_c_trans[i] += count * count;
296                                   }
297                         }
298                 }
299
300                 if (verbose && ROUND % 100 == 0) { printf("Thread %d, ROUND %d is finished!\n", params->no, ROUND); }
301         }
302
303         return NULL;
304 }
305
306 void release() {
307 //      char inpF[STRLEN], command[STRLEN];
308         string line;
309
310         /* destroy attribute */
311         pthread_attr_destroy(&attr);
312         delete[] threads;
313
314         pme_c.assign(M + 1, 0);
315         pve_c.assign(M + 1, 0);
316         pme_tpm.assign(M + 1, 0);
317         pme_fpkm.assign(M + 1, 0);
318
319         pve_c_genes.assign(m, 0);
320         pve_c_trans.clear();
321         if (alleleS) pve_c_trans.assign(m_trans, 0);
322
323         for (int i = 0; i < nThreads; i++) {
324                 fclose(paramsArray[i].fo);
325                 delete paramsArray[i].engine;
326                 for (int j = 0; j <= M; j++) {
327                         pme_c[j] += paramsArray[i].pme_c[j];
328                         pve_c[j] += paramsArray[i].pve_c[j];
329                         pme_tpm[j] += paramsArray[i].pme_tpm[j];
330                         pme_fpkm[j] += paramsArray[i].pme_fpkm[j];
331                 }
332
333                 for (int j = 0; j < m; j++) 
334                   pve_c_genes[j] += paramsArray[i].pve_c_genes[j];
335                 
336                 if (alleleS) 
337                   for (int j = 0; j < m_trans; j++) 
338                     pve_c_trans[j] += paramsArray[i].pve_c_trans[j];
339
340                 delete[] paramsArray[i].pme_c;
341                 delete[] paramsArray[i].pve_c;
342                 delete[] paramsArray[i].pme_tpm;
343                 delete[] paramsArray[i].pme_fpkm;
344
345                 delete[] paramsArray[i].pve_c_genes;
346                 if (alleleS) delete[] paramsArray[i].pve_c_trans;
347         }
348         delete[] paramsArray;
349
350         for (int i = 0; i <= M; i++) {
351                 pme_c[i] /= NSAMPLES;
352                 pve_c[i] = (pve_c[i] - double(NSAMPLES) * pme_c[i] * pme_c[i]) / double(NSAMPLES - 1);
353                 if (pve_c[i] < 0.0) pve_c[i] = 0.0;
354                 pme_tpm[i] /= NSAMPLES;
355                 pme_fpkm[i] /= NSAMPLES;
356         }
357
358         for (int i = 0; i < m; i++) {
359           int b = gi.spAt(i), e = gi.spAt(i + 1);
360           double pme_c_gene = 0.0;
361           for (int j = b; j < e; j++) pme_c_gene += pme_c[j];
362           pve_c_genes[i] = (pve_c_genes[i] - double(NSAMPLES) * pme_c_gene * pme_c_gene) / double(NSAMPLES - 1);
363           if (pve_c_genes[i] < 0.0) pve_c_genes[i] = 0.0;
364         }
365
366         if (alleleS) 
367           for (int i = 0; i < m_trans; i++) {
368             int b = ta.spAt(i), e = ta.spAt(i + 1);
369             double pme_c_tran = 0.0;
370             for (int j = b; j < e; j++) pme_c_tran += pme_c[j];
371             pve_c_trans[i] = (pve_c_trans[i] - double(NSAMPLES) * pme_c_tran * pme_c_tran) / double(NSAMPLES - 1);
372             if (pve_c_trans[i] < 0.0) pve_c_trans[i] = 0.0;
373           }
374 }
375
376 int main(int argc, char* argv[]) {
377         if (argc < 7) {
378                 printf("Usage: rsem-run-gibbs reference_name imdName statName BURNIN NSAMPLES GAP [-p #Threads] [--seed seed] [-q]\n");
379                 exit(-1);
380         }
381
382         strcpy(refName, argv[1]);
383         strcpy(imdName, argv[2]);
384         strcpy(statName, argv[3]);
385
386         BURNIN = atoi(argv[4]);
387         NSAMPLES = atoi(argv[5]);
388         GAP = atoi(argv[6]);
389
390         nThreads = 1;
391         hasSeed = false;
392         quiet = false;
393
394         for (int i = 7; i < argc; i++) {
395                 if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
396                 if (!strcmp(argv[i], "--seed")) {
397                   hasSeed = true;
398                   int len = strlen(argv[i + 1]);
399                   seed = 0;
400                   for (int k = 0; k < len; k++) seed = seed * 10 + (argv[i + 1][k] - '0');
401                 }
402                 if (!strcmp(argv[i], "-q")) quiet = true;
403         }
404         verbose = !quiet;
405
406         assert(NSAMPLES > 1); // Otherwise, we cannot calculate posterior variance
407
408         if (nThreads > NSAMPLES) {
409                 nThreads = NSAMPLES;
410                 printf("Warning: Number of samples is less than number of threads! Change the number of threads to %d!\n", nThreads);
411         }
412
413         load_group_info(refName);
414         load_data(refName, statName, imdName);
415
416         sprintf(modelF, "%s.model", statName);
417         FILE *fi = fopen(modelF, "r");
418         general_assert(fi != NULL, "Cannot open " + cstrtos(modelF) + "!");
419         assert(fscanf(fi, "%d", &model_type) == 1);
420         fclose(fi);
421
422         mw = new double[M + 1]; // make an extra copy
423
424         switch(model_type) {
425         case 0 : init_model_related<SingleModel>(modelF); break;
426         case 1 : init_model_related<SingleQModel>(modelF); break;
427         case 2 : init_model_related<PairedEndModel>(modelF); break;
428         case 3 : init_model_related<PairedEndQModel>(modelF); break;
429         }
430
431         if (verbose) printf("Gibbs started!\n");
432
433         init();
434         for (int i = 0; i < nThreads; i++) {
435                 rc = pthread_create(&threads[i], &attr, Gibbs, (void*)(&paramsArray[i]));
436                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0)!");
437         }
438         for (int i = 0; i < nThreads; i++) {
439                 rc = pthread_join(threads[i], NULL);
440                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0)!");
441         }
442         release();
443
444         if (verbose) printf("Gibbs finished!\n");
445         
446         writeResultsGibbs(M, m, m_trans, gi, gt, ta, alleleS, imdName, pme_c, pme_fpkm, pme_tpm, pve_c, pve_c_genes, pve_c_trans);
447
448         delete mw; // delete the copy
449
450         return 0;
451 }