]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
7e26d864d9258c97ee693379c88330636ef67c5c
[rsem.git] / Gibbs.cpp
1 #include<cstdio>
2 #include<cstring>
3 #include<cstdlib>
4 #include<cassert>
5 #include<fstream>
6 #include<sstream>
7 #include<vector>
8 #include<pthread.h>
9
10 #include "utils.h"
11 #include "my_assert.h"
12 #include "sampling.h"
13
14 #include "Model.h"
15 #include "SingleModel.h"
16 #include "SingleQModel.h"
17 #include "PairedEndModel.h"
18 #include "PairedEndQModel.h"
19
20 #include "Refs.h"
21
22 #include "GroupInfo.h"
23 #include "WriteResults.h"
24
25 using namespace std;
26
27 struct Params {
28         int no, nsamples;
29         FILE *fo;
30         engine_type *engine;
31         double *pme_c, *pve_c; //posterior mean and variance vectors on counts
32   double *pme_tpm, *pme_fpkm;
33 };
34
35
// One alignment candidate of a read: the id it points at and its conditional probability.
struct Item {
	int sid;
	double conprb;

	Item(int sid, double conprb) : sid(sid), conprb(conprb) {}
};
45
46 int nThreads;
47
48 int model_type;
49 int M;
50 READ_INT_TYPE N0, N1;
51 HIT_INT_TYPE nHits;
52 double totc;
53 int BURNIN, NSAMPLES, GAP;
54 char refName[STRLEN], imdName[STRLEN], statName[STRLEN];
55 char thetaF[STRLEN], ofgF[STRLEN], refF[STRLEN], modelF[STRLEN];
56 char cvsF[STRLEN];
57
58 Refs refs;
59
60 vector<HIT_INT_TYPE> s;
61 vector<Item> hits;
62
63 vector<double> eel;
64 double *mw;
65
66 vector<double> pme_c, pve_c; //global posterior mean and variance vectors on counts
67 vector<double> pme_tpm, pme_fpkm;
68
69 bool var_opt;
70 bool quiet;
71
72 Params *paramsArray;
73 pthread_t *threads;
74 pthread_attr_t attr;
75 int rc;
76
77 bool hasSeed;
78 seedType seed;
79
80 void load_data(char* refName, char* statName, char* imdName) {
81         ifstream fin;
82         string line;
83         int tmpVal;
84
85         //load reference file
86         sprintf(refF, "%s.seq", refName);
87         refs.loadRefs(refF, 1);
88         M = refs.getM();
89
90         //load ofgF;
91         sprintf(ofgF, "%s.ofg", imdName);
92         fin.open(ofgF);
93         general_assert(fin.is_open(), "Cannot open " + cstrtos(ofgF) + "!");
94         fin>>tmpVal>>N0;
95         general_assert(tmpVal == M, "M in " + cstrtos(ofgF) + " is not consistent with " + cstrtos(refF) + "!");
96         getline(fin, line);
97
98         s.clear(); hits.clear();
99         s.push_back(0);
100         while (getline(fin, line)) {
101                 istringstream strin(line);
102                 int sid;
103                 double conprb;
104
105                 while (strin>>sid>>conprb) {
106                         hits.push_back(Item(sid, conprb));
107                 }
108                 s.push_back(hits.size());
109         }
110         fin.close();
111
112         N1 = s.size() - 1;
113         nHits = hits.size();
114
115         totc = N0 + N1 + (M + 1);
116
117         if (verbose) { printf("Loading Data is finished!\n"); }
118 }
119
120 template<class ModelType>
121 void init_model_related(char* modelF) {
122         ModelType model;
123         model.read(modelF);
124
125         calcExpectedEffectiveLengths<ModelType>(M, refs, model, eel);
126         memcpy(mw, model.getMW(), sizeof(double) * (M + 1)); // otherwise, after exiting this procedure, mw becomes undefined
127 }
128
129 // assign threads
130 void init() {
131         int quotient, left;
132         char outF[STRLEN];
133
134         quotient = NSAMPLES / nThreads;
135         left = NSAMPLES % nThreads;
136
137         sprintf(cvsF, "%s.countvectors", imdName);
138         paramsArray = new Params[nThreads];
139         threads = new pthread_t[nThreads];
140
141         hasSeed ? engineFactory::init(seed) : engineFactory::init();
142         for (int i = 0; i < nThreads; i++) {
143                 paramsArray[i].no = i;
144
145                 paramsArray[i].nsamples = quotient;
146                 if (i < left) paramsArray[i].nsamples++;
147
148                 sprintf(outF, "%s%d", cvsF, i);
149                 paramsArray[i].fo = fopen(outF, "w");
150
151                 paramsArray[i].engine = engineFactory::new_engine();
152                 paramsArray[i].pme_c = new double[M + 1];
153                 memset(paramsArray[i].pme_c, 0, sizeof(double) * (M + 1));
154                 paramsArray[i].pve_c = new double[M + 1];
155                 memset(paramsArray[i].pve_c, 0, sizeof(double) * (M + 1));
156                 paramsArray[i].pme_tpm = new double[M + 1];
157                 memset(paramsArray[i].pme_tpm, 0, sizeof(double) * (M + 1));
158                 paramsArray[i].pme_fpkm = new double[M + 1];
159                 memset(paramsArray[i].pme_fpkm, 0, sizeof(double) * (M + 1));
160         }
161         engineFactory::finish();
162
163         /* set thread attribute to be joinable */
164         pthread_attr_init(&attr);
165         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
166
167         if (verbose) { printf("Initialization finished!\n"); }
168 }
169
170 //sample theta from Dir(1)
171 void sampleTheta(engine_type& engine, vector<double>& theta) {
172         gamma_dist gm(1);
173         gamma_generator gmg(engine, gm);
174         double denom;
175
176         theta.assign(M + 1, 0);
177         denom = 0.0;
178         for (int i = 0; i <= M; i++) {
179                 theta[i] = gmg();
180                 denom += theta[i];
181         }
182         assert(denom > EPSILON);
183         for (int i = 0; i <= M; i++) theta[i] /= denom;
184 }
185
186 void writeCountVector(FILE* fo, vector<int>& counts) {
187         for (int i = 0; i < M; i++) {
188                 fprintf(fo, "%d ", counts[i]);
189         }
190         fprintf(fo, "%d\n", counts[M]);
191 }
192
193 void* Gibbs(void* arg) {
194         int CHAINLEN;
195         HIT_INT_TYPE len, fr, to;
196         Params *params = (Params*)arg;
197
198         vector<double> theta, tpm, fpkm;
199         vector<int> z, counts;
200         vector<double> arr;
201
202         uniform01 rg(*params->engine);
203
204         // generate initial state
205         sampleTheta(*params->engine, theta);
206
207         z.assign(N1, 0);
208
209         counts.assign(M + 1, 1); // 1 pseudo count
210         counts[0] += N0;
211
212         for (READ_INT_TYPE i = 0; i < N1; i++) {
213                 fr = s[i]; to = s[i + 1];
214                 len = to - fr;
215                 arr.assign(len, 0);
216                 for (HIT_INT_TYPE j = fr; j < to; j++) {
217                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
218                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
219                 }
220                 z[i] = hits[fr + sample(rg, arr, len)].sid;
221                 ++counts[z[i]];
222         }
223
224         // Gibbs sampling
225         CHAINLEN = 1 + (params->nsamples - 1) * GAP;
226         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
227
228                 for (READ_INT_TYPE i = 0; i < N1; i++) {
229                         --counts[z[i]];
230                         fr = s[i]; to = s[i + 1]; len = to - fr;
231                         arr.assign(len, 0);
232                         for (HIT_INT_TYPE j = fr; j < to; j++) {
233                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
234                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
235                         }
236                         z[i] = hits[fr + sample(rg, arr, len)].sid;
237                         ++counts[z[i]];
238                 }
239
240                 if (ROUND > BURNIN) {
241                         if ((ROUND - BURNIN - 1) % GAP == 0) {
242                                 writeCountVector(params->fo, counts);
243                                 for (int i = 0; i <= M; i++) theta[i] = counts[i] / totc;
244                                 polishTheta(M, theta, eel, mw);
245                                 calcExpressionValues(M, theta, eel, tpm, fpkm);
246                                 for (int i = 0; i <= M; i++) {
247                                         params->pme_c[i] += counts[i] - 1;
248                                         params->pve_c[i] += (counts[i] - 1) * (counts[i] - 1);
249                                         params->pme_tpm[i] += tpm[i];
250                                         params->pme_fpkm[i] += fpkm[i];
251                                 }
252                         }
253                 }
254
255                 if (verbose && ROUND % 100 == 0) { printf("Thread %d, ROUND %d is finished!\n", params->no, ROUND); }
256         }
257
258         return NULL;
259 }
260
261 void release() {
262 //      char inpF[STRLEN], command[STRLEN];
263         string line;
264
265         /* destroy attribute */
266         pthread_attr_destroy(&attr);
267         delete[] threads;
268
269         pme_c.assign(M + 1, 0);
270         pve_c.assign(M + 1, 0);
271         pme_tpm.assign(M + 1, 0);
272         pme_fpkm.assign(M + 1, 0);
273         for (int i = 0; i < nThreads; i++) {
274                 fclose(paramsArray[i].fo);
275                 delete paramsArray[i].engine;
276                 for (int j = 0; j <= M; j++) {
277                         pme_c[j] += paramsArray[i].pme_c[j];
278                         pve_c[j] += paramsArray[i].pve_c[j];
279                         pme_tpm[j] += paramsArray[i].pme_tpm[j];
280                         pme_fpkm[j] += paramsArray[i].pme_fpkm[j];
281                 }
282                 delete[] paramsArray[i].pme_c;
283                 delete[] paramsArray[i].pve_c;
284                 delete[] paramsArray[i].pme_tpm;
285                 delete[] paramsArray[i].pme_fpkm;
286         }
287         delete[] paramsArray;
288
289
290         for (int i = 0; i <= M; i++) {
291                 pme_c[i] /= NSAMPLES;
292                 pve_c[i] = (pve_c[i] - NSAMPLES * pme_c[i] * pme_c[i]) / (NSAMPLES - 1);
293                 pme_tpm[i] /= NSAMPLES;
294                 pme_fpkm[i] /= NSAMPLES;
295         }
296 }
297
298 int main(int argc, char* argv[]) {
299         if (argc < 7) {
300                 printf("Usage: rsem-run-gibbs reference_name imdName statName BURNIN NSAMPLES GAP [-p #Threads] [--var] [--seed seed] [-q]\n");
301                 exit(-1);
302         }
303
304         strcpy(refName, argv[1]);
305         strcpy(imdName, argv[2]);
306         strcpy(statName, argv[3]);
307
308         BURNIN = atoi(argv[4]);
309         NSAMPLES = atoi(argv[5]);
310         GAP = atoi(argv[6]);
311
312         nThreads = 1;
313         var_opt = false;
314         hasSeed = false;
315         quiet = false;
316
317         for (int i = 7; i < argc; i++) {
318                 if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
319                 if (!strcmp(argv[i], "--var")) var_opt = true;
320                 if (!strcmp(argv[i], "--seed")) {
321                   hasSeed = true;
322                   int len = strlen(argv[i + 1]);
323                   seed = 0;
324                   for (int k = 0; k < len; k++) seed = seed * 10 + (argv[i + 1][k] - '0');
325                 }
326                 if (!strcmp(argv[i], "-q")) quiet = true;
327         }
328         verbose = !quiet;
329
330         assert(NSAMPLES > 1); // Otherwise, we cannot calculate posterior variance
331
332         if (nThreads > NSAMPLES) {
333                 nThreads = NSAMPLES;
334                 printf("Warning: Number of samples is less than number of threads! Change the number of threads to %d!\n", nThreads);
335         }
336
337         load_data(refName, statName, imdName);
338
339         sprintf(modelF, "%s.model", statName);
340         FILE *fi = fopen(modelF, "r");
341         general_assert(fi != NULL, "Cannot open " + cstrtos(modelF) + "!");
342         assert(fscanf(fi, "%d", &model_type) == 1);
343         fclose(fi);
344
345         mw = new double[M + 1]; // make an extra copy
346
347         switch(model_type) {
348         case 0 : init_model_related<SingleModel>(modelF); break;
349         case 1 : init_model_related<SingleQModel>(modelF); break;
350         case 2 : init_model_related<PairedEndModel>(modelF); break;
351         case 3 : init_model_related<PairedEndQModel>(modelF); break;
352         }
353
354         if (verbose) printf("Gibbs started!\n");
355
356         init();
357         for (int i = 0; i < nThreads; i++) {
358                 rc = pthread_create(&threads[i], &attr, Gibbs, (void*)(&paramsArray[i]));
359                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0)!");
360         }
361         for (int i = 0; i < nThreads; i++) {
362                 rc = pthread_join(threads[i], NULL);
363                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0)!");
364         }
365         release();
366
367         if (verbose) printf("Gibbs finished!\n");
368         
369         writeResultsGibbs(M, refName, imdName, pme_c, pme_fpkm, pme_tpm);
370
371         if (var_opt) {
372                 char varF[STRLEN];
373
374                 // Load group info
375                 int m;
376                 GroupInfo gi;
377                 char groupF[STRLEN];
378                 sprintf(groupF, "%s.grp", refName);
379                 gi.load(groupF);
380                 m = gi.getm();
381                 
382                 sprintf(varF, "%s.var", statName);
383                 FILE *fo = fopen(varF, "w");
384                 general_assert(fo != NULL, "Cannot open " + cstrtos(varF) + "!");
385                 for (int i = 0; i < m; i++) {
386                         int b = gi.spAt(i), e = gi.spAt(i + 1), number_of_isoforms = e - b;
387                         for (int j = b; j < e; j++) {
388                                 fprintf(fo, "%s\t%d\t%.15g\t%.15g\n", refs.getRef(j).getName().c_str(), number_of_isoforms, pme_c[j], pve_c[j]);
389                         }
390                 }
391                 fclose(fo);
392         }
393         
394         delete mw; // delete the copy
395
396         return 0;
397 }