]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
Modified the acknowledgement section of README.md
[rsem.git] / Gibbs.cpp
1 #include<cstdio>
2 #include<cstring>
3 #include<cstdlib>
4 #include<cassert>
5 #include<fstream>
6 #include<sstream>
7 #include<vector>
8 #include<pthread.h>
9
10 #include "utils.h"
11 #include "my_assert.h"
12 #include "sampling.h"
13
14 #include "Model.h"
15 #include "SingleModel.h"
16 #include "SingleQModel.h"
17 #include "PairedEndModel.h"
18 #include "PairedEndQModel.h"
19
20 #include "Refs.h"
21
22 #include "GroupInfo.h"
23 #include "WriteResults.h"
24
25 using namespace std;
26
27 struct Params {
28         int no, nsamples;
29         FILE *fo;
30         engine_type *engine;
31         double *pme_c, *pve_c; //posterior mean and variance vectors on counts
32   double *pme_tpm, *pme_fpkm;
33 };
34
35
// One alignment hit of a read, as stored in the .ofg file: the id of the
// sequence it hits and the value read next to it (conprb).
struct Item {
	int sid;
	double conprb;

	Item(int sid, double conprb) : sid(sid), conprb(conprb) {}
};
45
46 int nThreads;
47
48 int model_type;
49 int M;
50 READ_INT_TYPE N0, N1;
51 HIT_INT_TYPE nHits;
52 double totc;
53 int BURNIN, NSAMPLES, GAP;
54 char refName[STRLEN], imdName[STRLEN], statName[STRLEN];
55 char thetaF[STRLEN], ofgF[STRLEN], refF[STRLEN], modelF[STRLEN];
56 char cvsF[STRLEN];
57
58 Refs refs;
59
60 vector<HIT_INT_TYPE> s;
61 vector<Item> hits;
62
63 vector<double> eel;
64 double *mw;
65
66 vector<double> pme_c, pve_c; //global posterior mean and variance vectors on counts
67 vector<double> pme_tpm, pme_fpkm;
68
69 bool var_opt;
70 bool quiet;
71
72 Params *paramsArray;
73 pthread_t *threads;
74 pthread_attr_t attr;
75 int rc;
76
77 void load_data(char* refName, char* statName, char* imdName) {
78         ifstream fin;
79         string line;
80         int tmpVal;
81
82         //load reference file
83         sprintf(refF, "%s.seq", refName);
84         refs.loadRefs(refF, 1);
85         M = refs.getM();
86
87         //load ofgF;
88         sprintf(ofgF, "%s.ofg", imdName);
89         fin.open(ofgF);
90         general_assert(fin.is_open(), "Cannot open " + cstrtos(ofgF) + "!");
91         fin>>tmpVal>>N0;
92         general_assert(tmpVal == M, "M in " + cstrtos(ofgF) + " is not consistent with " + cstrtos(refF) + "!");
93         getline(fin, line);
94
95         s.clear(); hits.clear();
96         s.push_back(0);
97         while (getline(fin, line)) {
98                 istringstream strin(line);
99                 int sid;
100                 double conprb;
101
102                 while (strin>>sid>>conprb) {
103                         hits.push_back(Item(sid, conprb));
104                 }
105                 s.push_back(hits.size());
106         }
107         fin.close();
108
109         N1 = s.size() - 1;
110         nHits = hits.size();
111
112         totc = N0 + N1 + (M + 1);
113
114         if (verbose) { printf("Loading Data is finished!\n"); }
115 }
116
117 template<class ModelType>
118 void init_model_related(char* modelF) {
119         ModelType model;
120         model.read(modelF);
121
122         calcExpectedEffectiveLengths<ModelType>(M, refs, model, eel);
123         memcpy(mw, model.getMW(), sizeof(double) * (M + 1)); // otherwise, after exiting this procedure, mw becomes undefined
124 }
125
126 // assign threads
127 void init() {
128         int quotient, left;
129         char outF[STRLEN];
130
131         quotient = NSAMPLES / nThreads;
132         left = NSAMPLES % nThreads;
133
134         sprintf(cvsF, "%s.countvectors", imdName);
135         paramsArray = new Params[nThreads];
136         threads = new pthread_t[nThreads];
137
138         for (int i = 0; i < nThreads; i++) {
139                 paramsArray[i].no = i;
140
141                 paramsArray[i].nsamples = quotient;
142                 if (i < left) paramsArray[i].nsamples++;
143
144                 sprintf(outF, "%s%d", cvsF, i);
145                 paramsArray[i].fo = fopen(outF, "w");
146
147                 paramsArray[i].engine = engineFactory::new_engine();
148                 paramsArray[i].pme_c = new double[M + 1];
149                 memset(paramsArray[i].pme_c, 0, sizeof(double) * (M + 1));
150                 paramsArray[i].pve_c = new double[M + 1];
151                 memset(paramsArray[i].pve_c, 0, sizeof(double) * (M + 1));
152                 paramsArray[i].pme_tpm = new double[M + 1];
153                 memset(paramsArray[i].pme_tpm, 0, sizeof(double) * (M + 1));
154                 paramsArray[i].pme_fpkm = new double[M + 1];
155                 memset(paramsArray[i].pme_fpkm, 0, sizeof(double) * (M + 1));
156         }
157
158         /* set thread attribute to be joinable */
159         pthread_attr_init(&attr);
160         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
161
162         if (verbose) { printf("Initialization finished!\n"); }
163 }
164
165 //sample theta from Dir(1)
166 void sampleTheta(engine_type& engine, vector<double>& theta) {
167         gamma_dist gm(1);
168         gamma_generator gmg(engine, gm);
169         double denom;
170
171         theta.assign(M + 1, 0);
172         denom = 0.0;
173         for (int i = 0; i <= M; i++) {
174                 theta[i] = gmg();
175                 denom += theta[i];
176         }
177         assert(denom > EPSILON);
178         for (int i = 0; i <= M; i++) theta[i] /= denom;
179 }
180
181 void writeCountVector(FILE* fo, vector<int>& counts) {
182         for (int i = 0; i < M; i++) {
183                 fprintf(fo, "%d ", counts[i]);
184         }
185         fprintf(fo, "%d\n", counts[M]);
186 }
187
188 void* Gibbs(void* arg) {
189         int CHAINLEN;
190         HIT_INT_TYPE len, fr, to;
191         Params *params = (Params*)arg;
192
193         vector<double> theta, tpm, fpkm;
194         vector<int> z, counts;
195         vector<double> arr;
196
197         uniform01 rg(*params->engine);
198
199         // generate initial state
200         sampleTheta(*params->engine, theta);
201
202         z.assign(N1, 0);
203
204         counts.assign(M + 1, 1); // 1 pseudo count
205         counts[0] += N0;
206
207         for (READ_INT_TYPE i = 0; i < N1; i++) {
208                 fr = s[i]; to = s[i + 1];
209                 len = to - fr;
210                 arr.assign(len, 0);
211                 for (HIT_INT_TYPE j = fr; j < to; j++) {
212                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
213                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
214                 }
215                 z[i] = hits[fr + sample(rg, arr, len)].sid;
216                 ++counts[z[i]];
217         }
218
219         // Gibbs sampling
220         CHAINLEN = 1 + (params->nsamples - 1) * GAP;
221         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
222
223                 for (READ_INT_TYPE i = 0; i < N1; i++) {
224                         --counts[z[i]];
225                         fr = s[i]; to = s[i + 1]; len = to - fr;
226                         arr.assign(len, 0);
227                         for (HIT_INT_TYPE j = fr; j < to; j++) {
228                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
229                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
230                         }
231                         z[i] = hits[fr + sample(rg, arr, len)].sid;
232                         ++counts[z[i]];
233                 }
234
235                 if (ROUND > BURNIN) {
236                         if ((ROUND - BURNIN - 1) % GAP == 0) {
237                                 writeCountVector(params->fo, counts);
238                                 for (int i = 0; i <= M; i++) theta[i] = counts[i] / totc;
239                                 polishTheta(M, theta, eel, mw);
240                                 calcExpressionValues(M, theta, eel, tpm, fpkm);
241                                 for (int i = 0; i <= M; i++) {
242                                         params->pme_c[i] += counts[i] - 1;
243                                         params->pve_c[i] += (counts[i] - 1) * (counts[i] - 1);
244                                         params->pme_tpm[i] += tpm[i];
245                                         params->pme_fpkm[i] += fpkm[i];
246                                 }
247                         }
248                 }
249
250                 if (verbose && ROUND % 100 == 0) { printf("Thread %d, ROUND %d is finished!\n", params->no, ROUND); }
251         }
252
253         return NULL;
254 }
255
256 void release() {
257 //      char inpF[STRLEN], command[STRLEN];
258         string line;
259
260         /* destroy attribute */
261         pthread_attr_destroy(&attr);
262         delete[] threads;
263
264         pme_c.assign(M + 1, 0);
265         pve_c.assign(M + 1, 0);
266         pme_tpm.assign(M + 1, 0);
267         pme_fpkm.assign(M + 1, 0);
268         for (int i = 0; i < nThreads; i++) {
269                 fclose(paramsArray[i].fo);
270                 delete paramsArray[i].engine;
271                 for (int j = 0; j <= M; j++) {
272                         pme_c[j] += paramsArray[i].pme_c[j];
273                         pve_c[j] += paramsArray[i].pve_c[j];
274                         pme_tpm[j] += paramsArray[i].pme_tpm[j];
275                         pme_fpkm[j] += paramsArray[i].pme_fpkm[j];
276                 }
277                 delete[] paramsArray[i].pme_c;
278                 delete[] paramsArray[i].pve_c;
279                 delete[] paramsArray[i].pme_tpm;
280                 delete[] paramsArray[i].pme_fpkm;
281         }
282         delete[] paramsArray;
283
284
285         for (int i = 0; i <= M; i++) {
286                 pme_c[i] /= NSAMPLES;
287                 pve_c[i] = (pve_c[i] - NSAMPLES * pme_c[i] * pme_c[i]) / (NSAMPLES - 1);
288                 pme_tpm[i] /= NSAMPLES;
289                 pme_fpkm[i] /= NSAMPLES;
290         }
291 }
292
293 int main(int argc, char* argv[]) {
294         if (argc < 7) {
295                 printf("Usage: rsem-run-gibbs reference_name imdName statName BURNIN NSAMPLES GAP [-p #Threads] [--var] [-q]\n");
296                 exit(-1);
297         }
298
299         strcpy(refName, argv[1]);
300         strcpy(imdName, argv[2]);
301         strcpy(statName, argv[3]);
302
303         BURNIN = atoi(argv[4]);
304         NSAMPLES = atoi(argv[5]);
305         GAP = atoi(argv[6]);
306
307         nThreads = 1;
308         var_opt = false;
309         quiet = false;
310
311         for (int i = 7; i < argc; i++) {
312                 if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
313                 if (!strcmp(argv[i], "--var")) var_opt = true;
314                 if (!strcmp(argv[i], "-q")) quiet = true;
315         }
316         verbose = !quiet;
317
318         assert(NSAMPLES > 1); // Otherwise, we cannot calculate posterior variance
319
320         if (nThreads > NSAMPLES) {
321                 nThreads = NSAMPLES;
322                 printf("Warning: Number of samples is less than number of threads! Change the number of threads to %d!\n", nThreads);
323         }
324
325         load_data(refName, statName, imdName);
326
327         sprintf(modelF, "%s.model", statName);
328         FILE *fi = fopen(modelF, "r");
329         general_assert(fi != NULL, "Cannot open " + cstrtos(modelF) + "!");
330         assert(fscanf(fi, "%d", &model_type) == 1);
331         fclose(fi);
332
333         mw = new double[M + 1]; // make an extra copy
334
335         switch(model_type) {
336         case 0 : init_model_related<SingleModel>(modelF); break;
337         case 1 : init_model_related<SingleQModel>(modelF); break;
338         case 2 : init_model_related<PairedEndModel>(modelF); break;
339         case 3 : init_model_related<PairedEndQModel>(modelF); break;
340         }
341
342         if (verbose) printf("Gibbs started!\n");
343
344         init();
345         for (int i = 0; i < nThreads; i++) {
346                 rc = pthread_create(&threads[i], &attr, Gibbs, (void*)(&paramsArray[i]));
347                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0)!");
348         }
349         for (int i = 0; i < nThreads; i++) {
350                 rc = pthread_join(threads[i], NULL);
351                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0)!");
352         }
353         release();
354
355         if (verbose) printf("Gibbs finished!\n");
356         
357         writeResultsGibbs(M, refName, imdName, pme_c, pme_fpkm, pme_tpm);
358
359         if (var_opt) {
360                 char varF[STRLEN];
361
362                 // Load group info
363                 int m;
364                 GroupInfo gi;
365                 char groupF[STRLEN];
366                 sprintf(groupF, "%s.grp", refName);
367                 gi.load(groupF);
368                 m = gi.getm();
369                 
370                 sprintf(varF, "%s.var", statName);
371                 FILE *fo = fopen(varF, "w");
372                 general_assert(fo != NULL, "Cannot open " + cstrtos(varF) + "!");
373                 for (int i = 0; i < m; i++) {
374                         int b = gi.spAt(i), e = gi.spAt(i + 1), number_of_isoforms = e - b;
375                         for (int j = b; j < e; j++) {
376                                 fprintf(fo, "%s\t%d\t%.15g\t%.15g\n", refs.getRef(j).getName().c_str(), number_of_isoforms, pme_c[j], pve_c[j]);
377                         }
378                 }
379                 fclose(fo);
380         }
381         
382         delete mw; // delete the copy
383
384         return 0;
385 }