]> git.donarmstrong.com Git - rsem.git/blob - EM.cpp
The order of @SQ tags in SAM/BAM files can be arbitrary now
[rsem.git] / EM.cpp
1 #include<ctime>
2 #include<cmath>
3 #include<cstdio>
4 #include<cstdlib>
5 #include<cstring>
6 #include<cassert>
7 #include<string>
8 #include<vector>
9 #include<algorithm>
10 #include<pthread.h>
11
12 #include "utils.h"
13 #include "my_assert.h"
14 #include "sampling.h"
15
16 #include "Read.h"
17 #include "SingleRead.h"
18 #include "SingleReadQ.h"
19 #include "PairedEndRead.h"
20 #include "PairedEndReadQ.h"
21
22 #include "SingleHit.h"
23 #include "PairedEndHit.h"
24
25 #include "Model.h"
26 #include "SingleModel.h"
27 #include "SingleQModel.h"
28 #include "PairedEndModel.h"
29 #include "PairedEndQModel.h"
30
31 #include "Transcript.h"
32 #include "Transcripts.h"
33
34 #include "Refs.h"
35 #include "GroupInfo.h"
36 #include "HitContainer.h"
37 #include "ReadIndex.h"
38 #include "ReadReader.h"
39
40 #include "ModelParams.h"
41
42 #include "HitWrapper.h"
43 #include "BamWriter.h"
44
45 using namespace std;
46
47 const double STOP_CRITERIA = 0.001;
48 const int MAX_ROUND = 10000;
49 const int MIN_ROUND = 20;
50
51 struct Params {
52         void *model;
53         void *reader, *hitv, *ncpv, *mhp, *countv;
54 };
55
56 int read_type;
57 int m, M; // m genes, M isoforms
58 int N0, N1, N2, N_tot;
59 int nThreads;
60
61
62 bool genBamF; // If user wants to generate bam file, true; otherwise, false.
63 bool bamSampling; // true if sampling from read posterior distribution when bam file is generated
64 bool updateModel, calcExpectedWeights;
65 bool genGibbsOut; // generate file for Gibbs sampler
66
67 char refName[STRLEN], outName[STRLEN];
68 char imdName[STRLEN], statName[STRLEN];
69 char refF[STRLEN], groupF[STRLEN], cntF[STRLEN], tiF[STRLEN];
70 char mparamsF[STRLEN], bmparamsF[STRLEN];
71 char modelF[STRLEN], thetaF[STRLEN];
72
73 char inpSamType;
74 char *pt_fn_list, *pt_chr_list;
75 char inpSamF[STRLEN], outBamF[STRLEN], fn_list[STRLEN], chr_list[STRLEN];
76
77 char out_for_gibbs_F[STRLEN];
78
79 vector<double> theta, eel; // eel : expected effective length
80
81 double *probv, **countvs;
82
83 Refs refs;
84 GroupInfo gi;
85 Transcripts transcripts;
86
87 ModelParams mparams;
88
89 template<class ReadType, class HitType, class ModelType>
90 void init(ReadReader<ReadType> **&readers, HitContainer<HitType> **&hitvs, double **&ncpvs, ModelType **&mhps) {
91         int nReads, nHits, rt;
92         int nrLeft, nhT, curnr; // nrLeft : number of reads left, nhT : hit threshold per thread, curnr: current number of reads
93         char datF[STRLEN];
94
95         int s;
96         char readFs[2][STRLEN];
97         ReadIndex *indices[2];
98         ifstream fin;
99
100         readers = new ReadReader<ReadType>*[nThreads];
101         genReadFileNames(imdName, 1, read_type, s, readFs);
102         for (int i = 0; i < s; i++) {
103                 indices[i] = new ReadIndex(readFs[i]);
104         }
105         for (int i = 0; i < nThreads; i++) {
106                 readers[i] = new ReadReader<ReadType>(s, readFs, refs.hasPolyA(), mparams.seedLen); // allow calculation of calc_lq() function
107                 readers[i]->setIndices(indices);
108         }
109
110         hitvs = new HitContainer<HitType>*[nThreads];
111         for (int i = 0; i < nThreads; i++) {
112                 hitvs[i] = new HitContainer<HitType>();
113         }
114
115         sprintf(datF, "%s.dat", imdName);
116         fin.open(datF);
117         general_assert(fin.is_open(), "Cannot open " + cstrtos(datF) + "! It may not exist.");
118         fin>>nReads>>nHits>>rt;
119         general_assert(nReads == N1, "Number of alignable reads does not match!");
120         general_assert(rt == read_type, "Data file (.dat) does not have the right read type!");
121
122
123         //A just so so strategy for paralleling
124         nhT = nHits / nThreads;
125         nrLeft = N1;
126         curnr = 0;
127
128         ncpvs = new double*[nThreads];
129         for (int i = 0; i < nThreads; i++) {
130                 int ntLeft = nThreads - i - 1; // # of threads left
131
132                 general_assert(readers[i]->locate(curnr), "Read indices files do not match!");
133
134                 while (nrLeft > ntLeft && (i == nThreads - 1 || hitvs[i]->getNHits() < nhT)) {
135                         general_assert(hitvs[i]->read(fin), "Cannot read alignments from .dat file!");
136
137                         --nrLeft;
138                         if (verbose && nrLeft % 1000000 == 0) { printf("DAT %d reads left!\n", nrLeft); }
139                 }
140                 ncpvs[i] = new double[hitvs[i]->getN()];
141                 memset(ncpvs[i], 0, sizeof(double) * hitvs[i]->getN());
142                 curnr += hitvs[i]->getN();
143
144                 if (verbose) { printf("Thread %d : N = %d, NHit = %d\n", i, hitvs[i]->getN(), hitvs[i]->getNHits()); }
145         }
146
147         fin.close();
148
149         mhps = new ModelType*[nThreads];
150         for (int i = 0; i < nThreads; i++) {
151                 mhps[i] = new ModelType(mparams, false); // just model helper
152         }
153
154         probv = new double[M + 1];
155         countvs = new double*[nThreads];
156         for (int i = 0; i < nThreads; i++) {
157                 countvs[i] = new double[M + 1];
158         }
159
160
161         if (verbose) { printf("EM_init finished!\n"); }
162 }
163
164 template<class ReadType, class HitType, class ModelType>
165 void* E_STEP(void* arg) {
166         Params *params = (Params*)arg;
167         ModelType *model = (ModelType*)(params->model);
168         ReadReader<ReadType> *reader = (ReadReader<ReadType>*)(params->reader);
169         HitContainer<HitType> *hitv = (HitContainer<HitType>*)(params->hitv);
170         double *ncpv = (double*)(params->ncpv);
171         ModelType *mhp = (ModelType*)(params->mhp);
172         double *countv = (double*)(params->countv);
173
174         bool needCalcConPrb = model->getNeedCalcConPrb();
175
176         ReadType read;
177
178         int N = hitv->getN();
179         double sum;
180         vector<double> fracs; //to remove this, do calculation twice
181         int fr, to, id;
182
183         if (needCalcConPrb || updateModel) { reader->reset(); }
184         if (updateModel) { mhp->init(); }
185
186         memset(countv, 0, sizeof(double) * (M + 1));
187         for (int i = 0; i < N; i++) {
188                 if (needCalcConPrb || updateModel) {
189                         general_assert(reader->next(read), "Can not load a read!");
190                 }
191
192                 fr = hitv->getSAt(i);
193                 to = hitv->getSAt(i + 1);
194                 fracs.resize(to - fr + 1);
195
196                 sum = 0.0;
197
198                 if (needCalcConPrb) { ncpv[i] = model->getNoiseConPrb(read); }
199                 fracs[0] = probv[0] * ncpv[i];
200                 if (fracs[0] < EPSILON) fracs[0] = 0.0;
201                 sum += fracs[0];
202                 for (int j = fr; j < to; j++) {
203                         HitType &hit = hitv->getHitAt(j);
204                         if (needCalcConPrb) { hit.setConPrb(model->getConPrb(read, hit)); }
205                         id = j - fr + 1;
206                         fracs[id] = probv[hit.getSid()] * hit.getConPrb();
207                         if (fracs[id] < EPSILON) fracs[id] = 0.0;
208                         sum += fracs[id];
209                 }
210
211                 if (sum >= EPSILON) {
212                         fracs[0] /= sum;
213                         countv[0] += fracs[0];
214                         if (updateModel) { mhp->updateNoise(read, fracs[0]); }
215                         if (calcExpectedWeights) { ncpv[i] = fracs[0]; }
216                         for (int j = fr; j < to; j++) {
217                                 HitType &hit = hitv->getHitAt(j);
218                                 id = j - fr + 1;
219                                 fracs[id] /= sum;
220                                 countv[hit.getSid()] += fracs[id];
221                                 if (updateModel) { mhp->update(read, hit, fracs[id]); }
222                                 if (calcExpectedWeights) { hit.setConPrb(fracs[id]); }
223                         }                       
224                 }
225                 else if (calcExpectedWeights) {
226                         ncpv[i] = 0.0;
227                         for (int j = fr; j < to; j++) {
228                                 HitType &hit = hitv->getHitAt(j);
229                                 hit.setConPrb(0.0);
230                         }
231                 }
232         }
233
234         return NULL;
235 }
236
237 template<class ReadType, class HitType, class ModelType>
238 void* calcConProbs(void* arg) {
239         Params *params = (Params*)arg;
240         ModelType *model = (ModelType*)(params->model);
241         ReadReader<ReadType> *reader = (ReadReader<ReadType>*)(params->reader);
242         HitContainer<HitType> *hitv = (HitContainer<HitType>*)(params->hitv);
243         double *ncpv = (double*)(params->ncpv);
244
245         ReadType read;
246         int N = hitv->getN();
247         int fr, to;
248
249         assert(model->getNeedCalcConPrb());
250         reader->reset();
251
252         for (int i = 0; i < N; i++) {
253                 general_assert(reader->next(read), "Can not load a read!");
254
255                 fr = hitv->getSAt(i);
256                 to = hitv->getSAt(i + 1);
257
258                 ncpv[i] = model->getNoiseConPrb(read);
259                 for (int j = fr; j < to; j++) {
260                         HitType &hit = hitv->getHitAt(j);
261                         hit.setConPrb(model->getConPrb(read, hit));
262                 }
263         }
264
265         return NULL;
266 }
267
268 template<class ModelType>
269 void calcExpectedEffectiveLengths(ModelType& model) {
270   int lb, ub, span;
271   double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = sigma_{j=1}^{i}pdf[i]*(lb+i)
272   
273   model.getGLD().copyTo(pdf, cdf, lb, ub, span);
274   clen = new double[span + 1];
275   clen[0] = 0.0;
276   for (int i = 1; i <= span; i++) {
277     clen[i] = clen[i - 1] + pdf[i] * (lb + i);
278   }
279
280   eel.clear();
281   eel.resize(M + 1, 0.0);
282   for (int i = 1; i <= M; i++) {
283     int totLen = refs.getRef(i).getTotLen();
284     int fullLen = refs.getRef(i).getFullLen();
285     int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
286     int pos2 = max(min(totLen, ub) - lb, 0);
287
288     if (pos2 == 0) { eel[i] = 0.0; continue; }
289     
290     eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
291     assert(eel[i] >= 0);
292     if (eel[i] < MINEEL) { eel[i] = 0.0; }
293   }
294   
295   delete[] pdf;
296   delete[] cdf;
297   delete[] clen;
298 }
299
300 template<class ModelType>
301 void writeResults(ModelType& model, double* counts) {
302         double denom;
303         char outF[STRLEN];
304         FILE *fo;
305
306         sprintf(modelF, "%s.model", statName);
307         model.write(modelF);
308
309         //calculate tau values
310         double *tau = new double[M + 1];
311         memset(tau, 0, sizeof(double) * (M + 1));
312
313         denom = 0.0;
314         for (int i = 1; i <= M; i++) 
315           if (eel[i] >= EPSILON) {
316             tau[i] = theta[i] / eel[i];
317             denom += tau[i];
318           }   
319
320         general_assert(denom > 0, "No alignable reads?!");
321
322         for (int i = 1; i <= M; i++) {
323                 tau[i] /= denom;
324         }
325
326         //isoform level results
327         sprintf(outF, "%s.iso_res", imdName);
328         fo = fopen(outF, "w");
329         for (int i = 1; i <= M; i++) {
330                 const Transcript& transcript = transcripts.getTranscriptAt(i);
331                 fprintf(fo, "%s%c", transcript.getTranscriptID().c_str(), (i < M ? '\t' : '\n'));
332         }
333         for (int i = 1; i <= M; i++)
334                 fprintf(fo, "%.2f%c", counts[i], (i < M ? '\t' : '\n'));
335         for (int i = 1; i <= M; i++)
336                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
337         for (int i = 1; i <= M; i++) {
338                 const Transcript& transcript = transcripts.getTranscriptAt(i);
339                 fprintf(fo, "%s%c", transcript.getGeneID().c_str(), (i < M ? '\t' : '\n'));
340         }
341         fclose(fo);
342
343         //gene level results
344         sprintf(outF, "%s.gene_res", imdName);
345         fo = fopen(outF, "w");
346         for (int i = 0; i < m; i++) {
347                 const string& gene_id = transcripts.getTranscriptAt(gi.spAt(i)).getGeneID();
348                 fprintf(fo, "%s%c", gene_id.c_str(), (i < m - 1 ? '\t' : '\n'));
349         }
350         for (int i = 0; i < m; i++) {
351                 double sumC = 0.0; // sum of counts
352                 int b = gi.spAt(i), e = gi.spAt(i + 1);
353                 for (int j = b; j < e; j++) sumC += counts[j];
354                 fprintf(fo, "%.2f%c", sumC, (i < m - 1 ? '\t' : '\n'));
355         }
356         for (int i = 0; i < m; i++) {
357                 double sumT = 0.0; // sum of tau values
358                 int b = gi.spAt(i), e = gi.spAt(i + 1);
359                 for (int j = b; j < e; j++) sumT += tau[j];
360                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
361         }
362         for (int i = 0; i < m; i++) {
363                 int b = gi.spAt(i), e = gi.spAt(i + 1);
364                 for (int j = b; j < e; j++) {
365                         fprintf(fo, "%s%c", transcripts.getTranscriptAt(j).getTranscriptID().c_str(), (j < e - 1 ? ',' : (i < m - 1 ? '\t' :'\n')));
366                 }
367         }
368         fclose(fo);
369
370         delete[] tau;
371
372         if (verbose) { printf("Expression Results are written!\n"); }
373 }
374
375 template<class ReadType, class HitType, class ModelType>
376 void release(ReadReader<ReadType> **readers, HitContainer<HitType> **hitvs, double **ncpvs, ModelType **mhps) {
377         delete[] probv;
378         for (int i = 0; i < nThreads; i++) {
379                 delete[] countvs[i];
380         }
381         delete[] countvs;
382
383         for (int i = 0; i < nThreads; i++) {
384                 delete readers[i];
385                 delete hitvs[i];
386                 delete[] ncpvs[i];
387                 delete mhps[i];
388         }
389         delete[] readers;
390         delete[] hitvs;
391         delete[] ncpvs;
392         delete[] mhps;
393 }
394
395 inline bool doesUpdateModel(int ROUND) {
396   //  return ROUND <= 20 || ROUND % 100 == 0;
397   return ROUND <= 10;
398 }
399
400 //Including initialize, algorithm and results saving
401 template<class ReadType, class HitType, class ModelType>
402 void EM() {
403         FILE *fo;
404
405         int ROUND;
406         double sum;
407
408         double bChange = 0.0, change = 0.0; // bChange : biggest change
409         int totNum = 0;
410
411         ModelType model(mparams); //master model
412         ReadReader<ReadType> **readers;
413         HitContainer<HitType> **hitvs;
414         double **ncpvs;
415         ModelType **mhps; //model helpers
416
417         Params fparams[nThreads];
418         pthread_t threads[nThreads];
419         pthread_attr_t attr;
420         void *status;
421         int rc;
422
423
424         //initialize boolean variables
425         updateModel = calcExpectedWeights = false;
426
427         theta.clear();
428         theta.resize(M + 1, 0.0);
429         init<ReadType, HitType, ModelType>(readers, hitvs, ncpvs, mhps);
430
431         //set initial parameters
432         assert(N_tot > N2);
433         theta[0] = max(N0 * 1.0 / (N_tot - N2), 1e-8);
434         double val = (1.0 - theta[0]) / M;
435         for (int i = 1; i <= M; i++) theta[i] = val;
436
437         model.estimateFromReads(imdName);
438
439         for (int i = 0; i < nThreads; i++) {
440                 fparams[i].model = (void*)(&model);
441
442                 fparams[i].reader = (void*)readers[i];
443                 fparams[i].hitv = (void*)hitvs[i];
444                 fparams[i].ncpv = (void*)ncpvs[i];
445                 fparams[i].mhp = (void*)mhps[i];
446                 fparams[i].countv = (void*)countvs[i];
447         }
448
449         /* set thread attribute to be joinable */
450         pthread_attr_init(&attr);
451         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
452
453         ROUND = 0;
454         do {
455                 ++ROUND;
456
457                 updateModel = doesUpdateModel(ROUND);
458
459                 for (int i = 0; i <= M; i++) probv[i] = theta[i];
460
461                 //E step
462                 for (int i = 0; i < nThreads; i++) {
463                         rc = pthread_create(&threads[i], &attr, E_STEP<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
464                         pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0) at ROUND " + itos(ROUND) + "!");
465                 }
466
467                 for (int i = 0; i < nThreads; i++) {
468                         rc = pthread_join(threads[i], &status);
469                         pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0) at ROUND " + itos(ROUND) + "!");
470                 }
471
472                 model.setNeedCalcConPrb(false);
473
474                 for (int i = 1; i < nThreads; i++) {
475                         for (int j = 0; j <= M; j++) {
476                                 countvs[0][j] += countvs[i][j];
477                         }
478                 }
479
480                 //add N0 noise reads
481                 countvs[0][0] += N0;
482
483                 //M step;
484                 sum = 0.0;
485                 for (int i = 0; i <= M; i++) sum += countvs[0][i];
486                 assert(sum >= EPSILON);
487                 for (int i = 0; i <= M; i++) theta[i] = countvs[0][i] / sum;
488
489                 if (updateModel) {
490                         model.init();
491                         for (int i = 0; i < nThreads; i++) { model.collect(*mhps[i]); }
492                         model.finish();
493                 }
494
495                 // Relative error
496                 bChange = 0.0; totNum = 0;
497                 for (int i = 0; i <= M; i++)
498                         if (probv[i] >= 1e-7) {
499                                 change = fabs(theta[i] - probv[i]) / probv[i];
500                                 if (change >= STOP_CRITERIA) ++totNum;
501                                 if (bChange < change) bChange = change;
502                         }
503
504                 if (verbose) printf("ROUND = %d, SUM = %.15g, bChange = %f, totNum = %d\n", ROUND, sum, bChange, totNum);
505         } while (ROUND < MIN_ROUND || (totNum > 0 && ROUND < MAX_ROUND));
506           //while (ROUND < MAX_ROUND);
507
508         if (totNum > 0) fprintf(stderr, "Warning: RSEM reaches %d iterations before meeting the convergence criteria.\n", MAX_ROUND);
509
510         //generate output file used by Gibbs sampler
511         if (genGibbsOut) {
512                 if (model.getNeedCalcConPrb()) {
513                         for (int i = 0; i < nThreads; i++) {
514                                 rc = pthread_create(&threads[i], &attr, calcConProbs<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
515                                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0) when generating files for Gibbs sampler!");
516                         }
517                         for (int i = 0; i < nThreads; i++) {
518                                 rc = pthread_join(threads[i], &status);
519                                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0) when generating files for Gibbs sampler!");
520                         }
521                 }
522                 model.setNeedCalcConPrb(false);
523
524                 sprintf(out_for_gibbs_F, "%s.ofg", imdName);
525                 fo = fopen(out_for_gibbs_F, "w");
526                 fprintf(fo, "%d %d\n", M, N0);
527                 for (int i = 0; i < nThreads; i++) {
528                         int numN = hitvs[i]->getN();
529                         for (int j = 0; j < numN; j++) {
530                                 int fr = hitvs[i]->getSAt(j);
531                                 int to = hitvs[i]->getSAt(j + 1);
532                                 int totNum = 0;
533
534                                 if (ncpvs[i][j] >= EPSILON) { ++totNum; fprintf(fo, "%d %.15g ", 0, ncpvs[i][j]); }
535                                 for (int k = fr; k < to; k++) {
536                                         HitType &hit = hitvs[i]->getHitAt(k);
537                                         if (hit.getConPrb() >= EPSILON) {
538                                                 ++totNum;
539                                                 fprintf(fo, "%d %.15g ", hit.getSid(), hit.getConPrb());
540                                         }
541                                 }
542
543                                 if (totNum > 0) { fprintf(fo, "\n"); }
544                         }
545                 }
546                 fclose(fo);
547         }
548
549         sprintf(thetaF, "%s.theta", statName);
550         fo = fopen(thetaF, "w");
551         fprintf(fo, "%d\n", M + 1);
552
553         // output theta'
554         for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
555         fprintf(fo, "%.15g\n", theta[M]);
556         
557         //calculate expected effective lengths for each isoform
558         calcExpectedEffectiveLengths<ModelType>(model);
559
560         //correct theta vector
561         sum = theta[0];
562         for (int i = 1; i <= M; i++) 
563           if (eel[i] < EPSILON) { theta[i] = 0.0; }
564           else sum += theta[i];
565
566         general_assert(sum >= EPSILON, "No Expected Effective Length is no less than" + ftos(MINEEL, 6) + "?!");
567
568         for (int i = 0; i <= M; i++) theta[i] /= sum;
569
570         //calculate expected weights and counts using learned parameters
571         updateModel = false; calcExpectedWeights = true;
572         for (int i = 0; i <= M; i++) probv[i] = theta[i];
573         for (int i = 0; i < nThreads; i++) {
574                 rc = pthread_create(&threads[i], &attr, E_STEP<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
575                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0) when calculating expected weights!");
576         }
577         for (int i = 0; i < nThreads; i++) {
578                 rc = pthread_join(threads[i], &status);
579                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0) when calculating expected weights!");
580         }
581         model.setNeedCalcConPrb(false);
582         for (int i = 1; i < nThreads; i++) {
583                 for (int j = 0; j <= M; j++) {
584                         countvs[0][j] += countvs[i][j];
585                 }
586         }
587         countvs[0][0] += N0;
588
589         /* destroy attribute */
590         pthread_attr_destroy(&attr);
591
592         //convert theta' to theta
593         double *mw = model.getMW();
594         sum = 0.0;
595         for (int i = 0; i <= M; i++) {
596           theta[i] = (mw[i] < EPSILON ? 0.0 : theta[i] / mw[i]);
597           sum += theta[i]; 
598         }
599         assert(sum >= EPSILON);
600         for (int i = 0; i <= M; i++) theta[i] /= sum;
601
602         // output theta
603         for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
604         fprintf(fo, "%.15g\n", theta[M]);
605
606         fclose(fo);
607
608         writeResults<ModelType>(model, countvs[0]);
609
610         if (genBamF) {
611                 sprintf(outBamF, "%s.transcript.bam", outName);
612                 
613                 if (bamSampling) {
614                         int local_N;
615                         int fr, to, len, id;
616                         vector<double> arr;
617                         uniform01 rg(engine_type(time(NULL)));
618
619                         if (verbose) printf("Begin to sample reads from their posteriors.\n");
620                         for (int i = 0; i < nThreads; i++) {
621                                 local_N = hitvs[i]->getN();
622                                 for (int j = 0; j < local_N; j++) {
623                                         fr = hitvs[i]->getSAt(j);
624                                         to = hitvs[i]->getSAt(j + 1);
625                                         len = to - fr + 1;
626                                         arr.assign(len, 0);
627                                         arr[0] = ncpvs[i][j];
628                                         for (int k = fr; k < to; k++) arr[k - fr + 1] = arr[k - fr] + hitvs[i]->getHitAt(k).getConPrb();
629                                         id = (arr[len - 1] < EPSILON ? -1 : sample(rg, arr, len)); // if all entries in arr are 0, let id be -1
630                                         for (int k = fr; k < to; k++) hitvs[i]->getHitAt(k).setConPrb(k - fr + 1 == id ? 1.0 : 0.0);
631                                 }
632                         }
633
634                         if (verbose) printf("Sampling is finished.\n");
635                 }
636
637                 BamWriter writer(inpSamType, inpSamF, pt_fn_list, outBamF, transcripts);
638                 HitWrapper<HitType> wrapper(nThreads, hitvs);
639                 writer.work(wrapper);
640         }
641
642         release<ReadType, HitType, ModelType>(readers, hitvs, ncpvs, mhps);
643 }
644
645 int main(int argc, char* argv[]) {
646         ifstream fin;
647         bool quiet = false;
648
649         if (argc < 5) {
650                 printf("Usage : rsem-run-em refName read_type sampleName sampleToken [-p #Threads] [-b samInpType samInpF has_fn_list_? [fn_list]] [-q] [--gibbs-out] [--sampling]\n\n");
651                 printf("  refName: reference name\n");
652                 printf("  read_type: 0 single read without quality score; 1 single read with quality score; 2 paired-end read without quality score; 3 paired-end read with quality score.\n");
653                 printf("  sampleName: sample's name, including the path\n");
654                 printf("  sampleToken: sampleName excludes the path\n");
655                 printf("  -p: number of threads which user wants to use. (default: 1)\n");
656                 printf("  -b: produce bam format output file. (default: off)\n");
657                 printf("  -q: set it quiet\n");
658                 printf("  --gibbs-out: generate output file used by Gibbs sampler. (default: off)\n");
659                 printf("  --sampling: sample each read from its posterior distribution when bam file is generated. (default: off)\n");
660                 printf("// model parameters should be in imdName.mparams.\n");
661                 exit(-1);
662         }
663
664         time_t a = time(NULL);
665
666         strcpy(refName, argv[1]);
667         read_type = atoi(argv[2]);
668         strcpy(outName, argv[3]);
669         sprintf(imdName, "%s.temp/%s", argv[3], argv[4]);
670         sprintf(statName, "%s.stat/%s", argv[3], argv[4]);
671
672         nThreads = 1;
673
674         genBamF = false;
675         bamSampling = false;
676         genGibbsOut = false;
677         pt_fn_list = pt_chr_list = NULL;
678
679         for (int i = 5; i < argc; i++) {
680                 if (!strcmp(argv[i], "-p")) { nThreads = atoi(argv[i + 1]); }
681                 if (!strcmp(argv[i], "-b")) {
682                         genBamF = true;
683                         inpSamType = argv[i + 1][0];
684                         strcpy(inpSamF, argv[i + 2]);
685                         if (atoi(argv[i + 3]) == 1) {
686                                 strcpy(fn_list, argv[i + 4]);
687                                 pt_fn_list = (char*)(&fn_list);
688                         }
689                 }
690                 if (!strcmp(argv[i], "-q")) { quiet = true; }
691                 if (!strcmp(argv[i], "--gibbs-out")) { genGibbsOut = true; }
692                 if (!strcmp(argv[i], "--sampling")) { bamSampling = true; }
693         }
694
695         general_assert(nThreads > 0, "Number of threads should be bigger than 0!");
696
697         verbose = !quiet;
698
699         //basic info loading
700         sprintf(refF, "%s.seq", refName);
701         refs.loadRefs(refF);
702         M = refs.getM();
703         sprintf(groupF, "%s.grp", refName);
704         gi.load(groupF);
705         m = gi.getm();
706
707         sprintf(tiF, "%s.ti", refName);
708         transcripts.readFrom(tiF);
709
710         sprintf(cntF, "%s.cnt", statName);
711         fin.open(cntF);
712
713         general_assert(fin.is_open(), "Cannot open " + cstrtos(cntF) + "! It may not exist.");
714
715         fin>>N0>>N1>>N2>>N_tot;
716         fin.close();
717
718         general_assert(N1 > 0, "There are no alignable reads!");
719
720         if (nThreads > N1) nThreads = N1;
721
722         //set model parameters
723         mparams.M = M;
724         mparams.N[0] = N0; mparams.N[1] = N1; mparams.N[2] = N2;
725         mparams.refs = &refs;
726
727         sprintf(mparamsF, "%s.mparams", imdName);
728         fin.open(mparamsF);
729
730         general_assert(fin.is_open(), "Cannot open " + cstrtos(mparamsF) + "It may not exist.");
731
732         fin>> mparams.minL>> mparams.maxL>> mparams.probF;
733         int val; // 0 or 1 , for estRSPD
734         fin>>val;
735         mparams.estRSPD = (val != 0);
736         fin>> mparams.B>> mparams.mate_minL>> mparams.mate_maxL>> mparams.mean>> mparams.sd;
737         fin>> mparams.seedLen;
738         fin.close();
739
740         //run EM
741         switch(read_type) {
742         case 0 : EM<SingleRead, SingleHit, SingleModel>(); break;
743         case 1 : EM<SingleReadQ, SingleHit, SingleQModel>(); break;
744         case 2 : EM<PairedEndRead, PairedEndHit, PairedEndModel>(); break;
745         case 3 : EM<PairedEndReadQ, PairedEndHit, PairedEndQModel>(); break;
746         default : fprintf(stderr, "Unknown Read Type!\n"); exit(-1);
747         }
748
749         time_t b = time(NULL);
750
751         printTimeUsed(a, b, "EM.cpp");
752
753         return 0;
754 }