]> git.donarmstrong.com Git - rsem.git/blob - EM.cpp
tested version for tbam2gbam
[rsem.git] / EM.cpp
1 #include<ctime>
2 #include<cmath>
3 #include<cstdio>
4 #include<cstdlib>
5 #include<cstring>
6 #include<cassert>
7 #include<string>
8 #include<vector>
9 #include<algorithm>
10 #include<pthread.h>
11
12 #include "utils.h"
13 #include "sampling.h"
14
15 #include "Read.h"
16 #include "SingleRead.h"
17 #include "SingleReadQ.h"
18 #include "PairedEndRead.h"
19 #include "PairedEndReadQ.h"
20
21 #include "SingleHit.h"
22 #include "PairedEndHit.h"
23
24 #include "Model.h"
25 #include "SingleModel.h"
26 #include "SingleQModel.h"
27 #include "PairedEndModel.h"
28 #include "PairedEndQModel.h"
29
30 #include "Transcript.h"
31 #include "Transcripts.h"
32
33 #include "Refs.h"
34 #include "GroupInfo.h"
35 #include "HitContainer.h"
36 #include "ReadIndex.h"
37 #include "ReadReader.h"
38
39 #include "ModelParams.h"
40
41 #include "HitWrapper.h"
42 #include "BamWriter.h"
43
44 using namespace std;
45
46 const double STOP_CRITERIA = 0.001;
47 const int MAX_ROUND = 10000;
48 const int MIN_ROUND = 20;
49
50 struct Params {
51         void *model;
52         void *reader, *hitv, *ncpv, *mhp, *countv;
53 };
54
55 int read_type;
56 int m, M; // m genes, M isoforms
57 int N0, N1, N2, N_tot;
58 int nThreads;
59
60
61 bool genBamF; // If user wants to generate bam file, true; otherwise, false.
62 bool bamSampling; // true if sampling from read posterior distribution when bam file is generated
63 bool updateModel, calcExpectedWeights;
64 bool genGibbsOut; // generate file for Gibbs sampler
65
66 char refName[STRLEN], outName[STRLEN];
67 char imdName[STRLEN], statName[STRLEN];
68 char refF[STRLEN], groupF[STRLEN], cntF[STRLEN], tiF[STRLEN];
69 char mparamsF[STRLEN], bmparamsF[STRLEN];
70 char modelF[STRLEN], thetaF[STRLEN];
71
72 char inpSamType;
73 char *pt_fn_list, *pt_chr_list;
74 char inpSamF[STRLEN], outBamF[STRLEN], fn_list[STRLEN], chr_list[STRLEN];
75
76 char out_for_gibbs_F[STRLEN];
77
78 vector<double> theta, eel; // eel : expected effective length
79
80 double *probv, **countvs;
81
82 Refs refs;
83 GroupInfo gi;
84 Transcripts transcripts;
85
86 ModelParams mparams;
87
88 template<class ReadType, class HitType, class ModelType>
89 void init(ReadReader<ReadType> **&readers, HitContainer<HitType> **&hitvs, double **&ncpvs, ModelType **&mhps) {
90         int nReads, nHits, rt;
91         int nrLeft, nhT, curnr; // nrLeft : number of reads left, nhT : hit threshold per thread, curnr: current number of reads
92         char datF[STRLEN];
93
94         int s;
95         char readFs[2][STRLEN];
96         ReadIndex *indices[2];
97         ifstream fin;
98
99         readers = new ReadReader<ReadType>*[nThreads];
100         genReadFileNames(imdName, 1, read_type, s, readFs);
101         for (int i = 0; i < s; i++) {
102                 indices[i] = new ReadIndex(readFs[i]);
103         }
104         for (int i = 0; i < nThreads; i++) {
105                 readers[i] = new ReadReader<ReadType>(s, readFs);
106                 readers[i]->setIndices(indices);
107         }
108
109         hitvs = new HitContainer<HitType>*[nThreads];
110         for (int i = 0; i < nThreads; i++) {
111                 hitvs[i] = new HitContainer<HitType>();
112         }
113
114         sprintf(datF, "%s.dat", imdName);
115         fin.open(datF);
116         if (!fin.is_open()) { fprintf(stderr, "Cannot open %s! It may not exist.\n", datF); exit(-1); }
117         fin>>nReads>>nHits>>rt;
118         if (nReads != N1) { fprintf(stderr, "Number of alignable reads does not match!\n"); exit(-1); }
119         //assert(nReads == N1);
120         if (rt != read_type) { fprintf(stderr, "Data file (.dat) does not have the right read type!\n"); exit(-1); }
121         //assert(rt == read_type);
122
123         //A just so so strategy for paralleling
124         nhT = nHits / nThreads;
125         nrLeft = N1;
126         curnr = 0;
127
128         ncpvs = new double*[nThreads];
129         for (int i = 0; i < nThreads; i++) {
130                 int ntLeft = nThreads - i - 1; // # of threads left
131                 if (!readers[i]->locate(curnr)) { fprintf(stderr, "Read indices files do not match!\n"); exit(-1); }
132                 //assert(readers[i]->locate(curnr));
133
134                 while (nrLeft > ntLeft && (i == nThreads - 1 || hitvs[i]->getNHits() < nhT)) {
135                         if (!hitvs[i]->read(fin)) { fprintf(stderr, "Cannot read alignments from .dat file!\n"); exit(-1); }
136                         //assert(hitvs[i]->read(fin));
137                         --nrLeft;
138                         if (verbose && nrLeft % 1000000 == 0) { printf("DAT %d reads left!\n", nrLeft); }
139                 }
140                 ncpvs[i] = new double[hitvs[i]->getN()];
141                 memset(ncpvs[i], 0, sizeof(double) * hitvs[i]->getN());
142                 curnr += hitvs[i]->getN();
143
144                 if (verbose) { printf("Thread %d : N = %d, NHit = %d\n", i, hitvs[i]->getN(), hitvs[i]->getNHits()); }
145         }
146
147         fin.close();
148
149         mhps = new ModelType*[nThreads];
150         for (int i = 0; i < nThreads; i++) {
151                 mhps[i] = new ModelType(mparams, false); // just model helper
152         }
153
154         probv = new double[M + 1];
155         countvs = new double*[nThreads];
156         for (int i = 0; i < nThreads; i++) {
157                 countvs[i] = new double[M + 1];
158         }
159
160
161         if (verbose) { printf("EM_init finished!\n"); }
162 }
163
164 template<class ReadType, class HitType, class ModelType>
165 void* E_STEP(void* arg) {
166         Params *params = (Params*)arg;
167         ModelType *model = (ModelType*)(params->model);
168         ReadReader<ReadType> *reader = (ReadReader<ReadType>*)(params->reader);
169         HitContainer<HitType> *hitv = (HitContainer<HitType>*)(params->hitv);
170         double *ncpv = (double*)(params->ncpv);
171         ModelType *mhp = (ModelType*)(params->mhp);
172         double *countv = (double*)(params->countv);
173
174         bool needCalcConPrb = model->getNeedCalcConPrb();
175
176         ReadType read;
177
178         int N = hitv->getN();
179         double sum;
180         vector<double> fracs; //to remove this, do calculation twice
181         int fr, to, id;
182
183         if (needCalcConPrb || updateModel) { reader->reset(); }
184         if (updateModel) { mhp->init(); }
185
186         memset(countv, 0, sizeof(double) * (M + 1));
187         for (int i = 0; i < N; i++) {
188                 if (needCalcConPrb || updateModel) {
189                         if (!reader->next(read)) {
190                                 fprintf(stderr, "Can not load a read!\n");
191                                 exit(-1);
192                         }
193                         //assert(reader->next(read));
194                 }
195                 fr = hitv->getSAt(i);
196                 to = hitv->getSAt(i + 1);
197                 fracs.resize(to - fr + 1);
198
199                 sum = 0.0;
200
201                 if (needCalcConPrb) { ncpv[i] = model->getNoiseConPrb(read); }
202                 fracs[0] = probv[0] * ncpv[i];
203                 if (fracs[0] < EPSILON) fracs[0] = 0.0;
204                 sum += fracs[0];
205                 for (int j = fr; j < to; j++) {
206                         HitType &hit = hitv->getHitAt(j);
207                         if (needCalcConPrb) { hit.setConPrb(model->getConPrb(read, hit)); }
208                         id = j - fr + 1;
209                         fracs[id] = probv[hit.getSid()] * hit.getConPrb();
210                         if (fracs[id] < EPSILON) fracs[id] = 0.0;
211                         sum += fracs[id];
212                 }
213
214                 if (sum >= EPSILON) {
215                         fracs[0] /= sum;
216                         countv[0] += fracs[0];
217                         if (updateModel) { mhp->updateNoise(read, fracs[0]); }
218                         if (calcExpectedWeights) { ncpv[i] = fracs[0]; }
219                         for (int j = fr; j < to; j++) {
220                                 HitType &hit = hitv->getHitAt(j);
221                                 id = j - fr + 1;
222                                 fracs[id] /= sum;
223                                 countv[hit.getSid()] += fracs[id];
224                                 if (updateModel) { mhp->update(read, hit, fracs[id]); }
225                                 if (calcExpectedWeights) { hit.setConPrb(fracs[id]); }
226                         }                       
227                 }
228                 else if (calcExpectedWeights) {
229                         ncpv[i] = 0.0;
230                         for (int j = fr; j < to; j++) {
231                                 HitType &hit = hitv->getHitAt(j);
232                                 hit.setConPrb(0.0);
233                         }
234                 }
235         }
236
237         return NULL;
238 }
239
240 template<class ReadType, class HitType, class ModelType>
241 void* calcConProbs(void* arg) {
242         Params *params = (Params*)arg;
243         ModelType *model = (ModelType*)(params->model);
244         ReadReader<ReadType> *reader = (ReadReader<ReadType>*)(params->reader);
245         HitContainer<HitType> *hitv = (HitContainer<HitType>*)(params->hitv);
246         double *ncpv = (double*)(params->ncpv);
247
248         ReadType read;
249         int N = hitv->getN();
250         int fr, to;
251
252         assert(model->getNeedCalcConPrb());
253         reader->reset();
254
255         for (int i = 0; i < N; i++) {
256                 if (!reader->next(read)) {
257                         fprintf(stderr, "Can not load a read!\n");
258                         exit(-1);
259                 }
260                 fr = hitv->getSAt(i);
261                 to = hitv->getSAt(i + 1);
262
263                 ncpv[i] = model->getNoiseConPrb(read);
264                 for (int j = fr; j < to; j++) {
265                         HitType &hit = hitv->getHitAt(j);
266                         hit.setConPrb(model->getConPrb(read, hit));
267                 }
268         }
269
270         return NULL;
271 }
272
273 template<class ModelType>
274 void calcExpectedEffectiveLengths(ModelType& model) {
275   int lb, ub, span;
276   double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = sigma_{j=1}^{i}pdf[i]*(lb+i)
277   
278   model.getGLD().copyTo(pdf, cdf, lb, ub, span);
279   clen = new double[span + 1];
280   clen[0] = 0.0;
281   for (int i = 1; i <= span; i++) {
282     clen[i] = clen[i - 1] + pdf[i] * (lb + i);
283   }
284
285   eel.clear();
286   eel.resize(M + 1, 0.0);
287   for (int i = 1; i <= M; i++) {
288     int totLen = refs.getRef(i).getTotLen();
289     int fullLen = refs.getRef(i).getFullLen();
290     int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
291     int pos2 = max(min(totLen, ub) - lb, 0);
292
293     if (pos2 == 0) { eel[i] = 0.0; continue; }
294     
295     eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
296     assert(eel[i] >= 0);
297     if (eel[i] < MINEEL) { eel[i] = 0.0; }
298   }
299   
300   delete[] pdf;
301   delete[] cdf;
302   delete[] clen;
303 }
304
305 template<class ModelType>
306 void writeResults(ModelType& model, double* counts) {
307         double denom;
308         char outF[STRLEN];
309         FILE *fo;
310
311         sprintf(modelF, "%s.model", statName);
312         model.write(modelF);
313
314         //calculate tau values
315         double *tau = new double[M + 1];
316         memset(tau, 0, sizeof(double) * (M + 1));
317
318         denom = 0.0;
319         for (int i = 1; i <= M; i++) 
320           if (eel[i] >= EPSILON) {
321             tau[i] = theta[i] / eel[i];
322             denom += tau[i];
323           }   
324         if (denom <= 0) { fprintf(stderr, "No alignable reads?!\n"); exit(-1); }
325         //assert(denom > 0);
326         for (int i = 1; i <= M; i++) {
327                 tau[i] /= denom;
328         }
329
330         //isoform level results
331         sprintf(outF, "%s.iso_res", imdName);
332         fo = fopen(outF, "w");
333         for (int i = 1; i <= M; i++) {
334                 const Transcript& transcript = transcripts.getTranscriptAt(i);
335                 fprintf(fo, "%s%c", transcript.getTranscriptID().c_str(), (i < M ? '\t' : '\n'));
336         }
337         for (int i = 1; i <= M; i++)
338                 fprintf(fo, "%.2f%c", counts[i], (i < M ? '\t' : '\n'));
339         for (int i = 1; i <= M; i++)
340                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
341         for (int i = 1; i <= M; i++) {
342                 const Transcript& transcript = transcripts.getTranscriptAt(i);
343                 fprintf(fo, "%s%c", transcript.getLeft().c_str(), (i < M ? '\t' : '\n'));
344         }
345         fclose(fo);
346
347         //gene level results
348         sprintf(outF, "%s.gene_res", imdName);
349         fo = fopen(outF, "w");
350         for (int i = 0; i < m; i++) {
351                 const string& gene_id = transcripts.getTranscriptAt(gi.spAt(i)).getGeneID();
352                 fprintf(fo, "%s%c", gene_id.c_str(), (i < m - 1 ? '\t' : '\n'));
353         }
354         for (int i = 0; i < m; i++) {
355                 double sumC = 0.0; // sum of counts
356                 int b = gi.spAt(i), e = gi.spAt(i + 1);
357                 for (int j = b; j < e; j++) sumC += counts[j];
358                 fprintf(fo, "%.2f%c", sumC, (i < m - 1 ? '\t' : '\n'));
359         }
360         for (int i = 0; i < m; i++) {
361                 double sumT = 0.0; // sum of tau values
362                 int b = gi.spAt(i), e = gi.spAt(i + 1);
363                 for (int j = b; j < e; j++) sumT += tau[j];
364                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
365         }
366         for (int i = 0; i < m; i++) {
367                 int b = gi.spAt(i), e = gi.spAt(i + 1);
368                 for (int j = b; j < e; j++) {
369                         fprintf(fo, "%s%c", transcripts.getTranscriptAt(j).getTranscriptID().c_str(), (j < e - 1 ? ',' : (i < m - 1 ? '\t' :'\n')));
370                 }
371         }
372         fclose(fo);
373
374         delete[] tau;
375
376         if (verbose) { printf("Expression Results are written!\n"); }
377 }
378
379 template<class ReadType, class HitType, class ModelType>
380 void release(ReadReader<ReadType> **readers, HitContainer<HitType> **hitvs, double **ncpvs, ModelType **mhps) {
381         delete[] probv;
382         for (int i = 0; i < nThreads; i++) {
383                 delete[] countvs[i];
384         }
385         delete[] countvs;
386
387         for (int i = 0; i < nThreads; i++) {
388                 delete readers[i];
389                 delete hitvs[i];
390                 delete[] ncpvs[i];
391                 delete mhps[i];
392         }
393         delete[] readers;
394         delete[] hitvs;
395         delete[] ncpvs;
396         delete[] mhps;
397 }
398
399 inline bool doesUpdateModel(int ROUND) {
400   //  return ROUND <= 20 || ROUND % 100 == 0;
401   return ROUND <= 10;
402 }
403
404 //Including initialize, algorithm and results saving
405 template<class ReadType, class HitType, class ModelType>
406 void EM() {
407         FILE *fo;
408
409         int ROUND;
410         double sum;
411
412         double bChange = 0.0, change = 0.0; // bChange : biggest change
413         int totNum = 0;
414
415         ModelType model(mparams); //master model
416         ReadReader<ReadType> **readers;
417         HitContainer<HitType> **hitvs;
418         double **ncpvs;
419         ModelType **mhps; //model helpers
420
421         Params fparams[nThreads];
422         pthread_t threads[nThreads];
423         pthread_attr_t attr;
424         void *status;
425         int rc;
426
427
428         //initialize boolean variables
429         updateModel = calcExpectedWeights = false;
430
431         theta.clear();
432         theta.resize(M + 1, 0.0);
433         init<ReadType, HitType, ModelType>(readers, hitvs, ncpvs, mhps);
434
435         //set initial parameters
436         assert(N_tot > N2);
437         theta[0] = max(N0 * 1.0 / (N_tot - N2), 1e-8);
438         double val = (1.0 - theta[0]) / M;
439         for (int i = 1; i <= M; i++) theta[i] = val;
440
441         model.estimateFromReads(imdName);
442
443         for (int i = 0; i < nThreads; i++) {
444                 fparams[i].model = (void*)(&model);
445
446                 fparams[i].reader = (void*)readers[i];
447                 fparams[i].hitv = (void*)hitvs[i];
448                 fparams[i].ncpv = (void*)ncpvs[i];
449                 fparams[i].mhp = (void*)mhps[i];
450                 fparams[i].countv = (void*)countvs[i];
451         }
452
453         /* set thread attribute to be joinable */
454         pthread_attr_init(&attr);
455         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
456
457         ROUND = 0;
458         do {
459                 ++ROUND;
460
461                 updateModel = doesUpdateModel(ROUND);
462
463                 for (int i = 0; i <= M; i++) probv[i] = theta[i];
464
465                 //E step
466                 for (int i = 0; i < nThreads; i++) {
467                         rc = pthread_create(&threads[i], &attr, E_STEP<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
468                         if (rc != 0) { fprintf(stderr, "Cannot create thread %d at ROUND %d! (numbered from 0)\n", i, ROUND); exit(-1); }
469                         //assert(rc == 0);
470                 }
471
472                 for (int i = 0; i < nThreads; i++) {
473                         rc = pthread_join(threads[i], &status);
474                         if (rc != 0) { fprintf(stderr, "Cannot join thread %d at ROUND %d! (numbered from 0)\n", i, ROUND); exit(-1); }
475                         //assert(rc == 0);
476                 }
477
478                 model.setNeedCalcConPrb(false);
479
480                 for (int i = 1; i < nThreads; i++) {
481                         for (int j = 0; j <= M; j++) {
482                                 countvs[0][j] += countvs[i][j];
483                         }
484                 }
485
486                 //add N0 noise reads
487                 countvs[0][0] += N0;
488
489                 //M step;
490                 sum = 0.0;
491                 for (int i = 0; i <= M; i++) sum += countvs[0][i];
492                 assert(sum >= EPSILON);
493                 for (int i = 0; i <= M; i++) theta[i] = countvs[0][i] / sum;
494
495                 if (updateModel) {
496                         model.init();
497                         for (int i = 0; i < nThreads; i++) { model.collect(*mhps[i]); }
498                         model.finish();
499                 }
500
501                 // Relative error
502                 bChange = 0.0; totNum = 0;
503                 for (int i = 0; i <= M; i++)
504                         if (probv[i] >= 1e-7) {
505                                 change = fabs(theta[i] - probv[i]) / probv[i];
506                                 if (change >= STOP_CRITERIA) ++totNum;
507                                 if (bChange < change) bChange = change;
508                         }
509
510                 if (verbose) printf("ROUND = %d, SUM = %.15g, bChange = %f, totNum = %d\n", ROUND, sum, bChange, totNum);
511         } while (ROUND < MIN_ROUND || (totNum > 0 && ROUND < MAX_ROUND));
512           //while (ROUND < MAX_ROUND);
513
514         if (totNum > 0) fprintf(stderr, "Warning: RSEM reaches %d iterations before meeting the convergence criteria.\n", MAX_ROUND);
515
516         //generate output file used by Gibbs sampler
517         if (genGibbsOut) {
518                 if (model.getNeedCalcConPrb()) {
519                         for (int i = 0; i < nThreads; i++) {
520                                 rc = pthread_create(&threads[i], &attr, calcConProbs<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
521                                 if (rc != 0) { fprintf(stderr, "Cannot create thread %d when generate files for Gibbs sampler! (numbered from 0)\n", i); exit(-1); }
522                         }
523                         for (int i = 0; i < nThreads; i++) {
524                                 rc = pthread_join(threads[i], &status);
525                                 if (rc != 0) { fprintf(stderr, "Cannot join thread %d when generate files for Gibbs sampler! (numbered from 0)\n", i); exit(-1); }
526                         }
527                 }
528                 model.setNeedCalcConPrb(false);
529
530                 sprintf(out_for_gibbs_F, "%s.ofg", imdName);
531                 fo = fopen(out_for_gibbs_F, "w");
532                 fprintf(fo, "%d %d\n", M, N0);
533                 for (int i = 0; i < nThreads; i++) {
534                         int numN = hitvs[i]->getN();
535                         for (int j = 0; j < numN; j++) {
536                                 int fr = hitvs[i]->getSAt(j);
537                                 int to = hitvs[i]->getSAt(j + 1);
538                                 int totNum = 0;
539
540                                 if (ncpvs[i][j] >= EPSILON) { ++totNum; fprintf(fo, "%d %.15g ", 0, ncpvs[i][j]); }
541                                 for (int k = fr; k < to; k++) {
542                                         HitType &hit = hitvs[i]->getHitAt(k);
543                                         if (hit.getConPrb() >= EPSILON) {
544                                                 ++totNum;
545                                                 fprintf(fo, "%d %.15g ", hit.getSid(), hit.getConPrb());
546                                         }
547                                 }
548
549                                 if (totNum > 0) { fprintf(fo, "\n"); }
550                         }
551                 }
552                 fclose(fo);
553         }
554
555         sprintf(thetaF, "%s.theta", statName);
556         fo = fopen(thetaF, "w");
557         fprintf(fo, "%d\n", M + 1);
558
559         // output theta'
560         for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
561         fprintf(fo, "%.15g\n", theta[M]);
562         
563         //calculate expected effective lengths for each isoform
564         calcExpectedEffectiveLengths<ModelType>(model);
565
566         //correct theta vector
567         sum = theta[0];
568         for (int i = 1; i <= M; i++) 
569           if (eel[i] < EPSILON) { theta[i] = 0.0; }
570           else sum += theta[i];
571         if (sum < EPSILON) { fprintf(stderr, "No Expected Effective Length is no less than %.6g?!\n", MINEEL); exit(-1); }
572         for (int i = 0; i <= M; i++) theta[i] /= sum;
573
574         //calculate expected weights and counts using learned parameters
575         updateModel = false; calcExpectedWeights = true;
576         for (int i = 0; i <= M; i++) probv[i] = theta[i];
577         for (int i = 0; i < nThreads; i++) {
578                 rc = pthread_create(&threads[i], &attr, E_STEP<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
579                 if (rc != 0) { fprintf(stderr, "Cannot create thread %d when calculate expected weights! (numbered from 0)\n", i); exit(-1); }
580                 //assert(rc == 0);
581         }
582         for (int i = 0; i < nThreads; i++) {
583                 rc = pthread_join(threads[i], &status);
584                 if (rc != 0) { fprintf(stderr, "Cannot join thread %d! (numbered from 0) when calculate expected weights!\n", i); exit(-1); }
585                 //assert(rc == 0);
586         }
587         model.setNeedCalcConPrb(false);
588         for (int i = 1; i < nThreads; i++) {
589                 for (int j = 0; j <= M; j++) {
590                         countvs[0][j] += countvs[i][j];
591                 }
592         }
593         countvs[0][0] += N0;
594
595         /* destroy attribute */
596         pthread_attr_destroy(&attr);
597
598         //convert theta' to theta
599         double *mw = model.getMW();
600         sum = 0.0;
601         for (int i = 0; i <= M; i++) {
602           theta[i] = (mw[i] < EPSILON ? 0.0 : theta[i] / mw[i]);
603           sum += theta[i]; 
604         }
605         assert(sum >= EPSILON);
606         for (int i = 0; i <= M; i++) theta[i] /= sum;
607
608         // output theta
609         for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
610         fprintf(fo, "%.15g\n", theta[M]);
611
612         fclose(fo);
613
614         writeResults<ModelType>(model, countvs[0]);
615
616         if (genBamF) {
617                 sprintf(outBamF, "%s.transcript.bam", outName);
618                 
619                 if (bamSampling) {
620                         int local_N;
621                         int fr, to, len, id;
622                         vector<double> arr;
623                         arr.clear();
624
625                         if (verbose) printf("Begin to sample reads from their posteriors.\n");
626                         for (int i = 0; i < nThreads; i++) {
627                                 local_N = hitvs[i]->getN();
628                                 for (int j = 0; j < local_N; j++) {
629                                         fr = hitvs[i]->getSAt(j);
630                                         to = hitvs[i]->getSAt(j + 1);
631                                         len = to - fr + 1;
632                                         arr.resize(len);
633                                         arr[0] = ncpvs[i][j];
634                                         for (int k = fr; k < to; k++) arr[k - fr + 1] = arr[k - fr] + hitvs[i]->getHitAt(k).getConPrb();
635                                         id = (arr[len - 1] < EPSILON ? -1 : sample(arr, len)); // if all entries in arr are 0, let id be -1
636                                         for (int k = fr; k < to; k++) hitvs[i]->getHitAt(k).setConPrb(k - fr + 1 == id ? 1.0 : 0.0);
637                                 }
638                         }
639                         if (verbose) printf("Sampling is finished.\n");
640                 }
641
642                 BamWriter writer(inpSamType, inpSamF, pt_fn_list, outBamF, transcripts);
643                 HitWrapper<HitType> wrapper(nThreads, hitvs);
644                 writer.work(wrapper);
645         }
646
647         release<ReadType, HitType, ModelType>(readers, hitvs, ncpvs, mhps);
648 }
649
650 int main(int argc, char* argv[]) {
651         ifstream fin;
652         bool quiet = false;
653
654         if (argc < 5) {
655                 printf("Usage : rsem-run-em refName read_type sampleName sampleToken [-p #Threads] [-b samInpType samInpF has_fn_list_? [fn_list]] [-q] [--gibbs-out] [--sampling]\n\n");
656                 printf("  refName: reference name\n");
657                 printf("  read_type: 0 single read without quality score; 1 single read with quality score; 2 paired-end read without quality score; 3 paired-end read with quality score.\n");
658                 printf("  sampleName: sample's name, including the path\n");
659                 printf("  sampleToken: sampleName excludes the path\n");
660                 printf("  -p: number of threads which user wants to use. (default: 1)\n");
661                 printf("  -b: produce bam format output file. (default: off)\n");
662                 printf("  -q: set it quiet\n");
663                 printf("  --gibbs-out: generate output file used by Gibbs sampler. (default: off)\n");
664                 printf("  --sampling: sample each read from its posterior distribution when bam file is generated. (default: off)\n");
665                 printf("// model parameters should be in imdName.mparams.\n");
666                 exit(-1);
667         }
668
669         time_t a = time(NULL);
670
671         strcpy(refName, argv[1]);
672         read_type = atoi(argv[2]);
673         strcpy(outName, argv[3]);
674         sprintf(imdName, "%s.temp/%s", argv[3], argv[4]);
675         sprintf(statName, "%s.stat/%s", argv[3], argv[4]);
676
677         nThreads = 1;
678
679         genBamF = false;
680         bamSampling = false;
681         genGibbsOut = false;
682         pt_fn_list = pt_chr_list = NULL;
683
684         for (int i = 5; i < argc; i++) {
685                 if (!strcmp(argv[i], "-p")) { nThreads = atoi(argv[i + 1]); }
686                 if (!strcmp(argv[i], "-b")) {
687                         genBamF = true;
688                         inpSamType = argv[i + 1][0];
689                         strcpy(inpSamF, argv[i + 2]);
690                         if (atoi(argv[i + 3]) == 1) {
691                                 strcpy(fn_list, argv[i + 4]);
692                                 pt_fn_list = (char*)(&fn_list);
693                         }
694                 }
695                 if (!strcmp(argv[i], "-q")) { quiet = true; }
696                 if (!strcmp(argv[i], "--gibbs-out")) { genGibbsOut = true; }
697                 if (!strcmp(argv[i], "--sampling")) { bamSampling = true; }
698         }
699         if (nThreads <= 0) { fprintf(stderr, "Number of threads should be bigger than 0!\n"); exit(-1); }
700         //assert(nThreads > 0);
701
702         verbose = !quiet;
703
704         //basic info loading
705         sprintf(refF, "%s.seq", refName);
706         refs.loadRefs(refF);
707         M = refs.getM();
708         sprintf(groupF, "%s.grp", refName);
709         gi.load(groupF);
710         m = gi.getm();
711
712         sprintf(tiF, "%s.ti", refName);
713         transcripts.readFrom(tiF);
714
715         sprintf(cntF, "%s.cnt", statName);
716         fin.open(cntF);
717         if (!fin.is_open()) { fprintf(stderr, "Cannot open %s! It may not exist.\n", cntF); exit(-1); }
718         fin>>N0>>N1>>N2>>N_tot;
719         fin.close();
720
721         if (N1 <= 0) { fprintf(stderr, "There are no alignable reads!\n"); exit(-1); }
722
723         if (nThreads > N1) nThreads = N1;
724
725         //set model parameters
726         mparams.M = M;
727         mparams.N[0] = N0; mparams.N[1] = N1; mparams.N[2] = N2;
728         mparams.refs = &refs;
729
730         sprintf(mparamsF, "%s.mparams", imdName);
731         fin.open(mparamsF);
732         if (!fin.is_open()) { fprintf(stderr, "Cannot open %s! It may not exist.\n", mparamsF); exit(-1); }
733         fin>> mparams.minL>> mparams.maxL>> mparams.probF;
734         int val; // 0 or 1 , for estRSPD
735         fin>>val;
736         mparams.estRSPD = (val != 0);
737         fin>> mparams.B>> mparams.mate_minL>> mparams.mate_maxL>> mparams.mean>> mparams.sd;
738         fin>> mparams.seedLen;
739         fin.close();
740
741         //run EM
742         switch(read_type) {
743         case 0 : EM<SingleRead, SingleHit, SingleModel>(); break;
744         case 1 : EM<SingleReadQ, SingleHit, SingleQModel>(); break;
745         case 2 : EM<PairedEndRead, PairedEndHit, PairedEndModel>(); break;
746         case 3 : EM<PairedEndReadQ, PairedEndHit, PairedEndQModel>(); break;
747         default : fprintf(stderr, "Unknown Read Type!\n"); exit(-1);
748         }
749
750         time_t b = time(NULL);
751
752         printTimeUsed(a, b);
753
754         return 0;
755 }