]> git.donarmstrong.com Git - rsem.git/blob - EM.cpp
fix one bug will lead using not the last theta value estimates
[rsem.git] / EM.cpp
1 #include<ctime>
2 #include<cmath>
3 #include<cstdio>
4 #include<cstdlib>
5 #include<cstring>
6 #include<cassert>
7 #include<string>
8 #include<vector>
9 #include<algorithm>
10 #include<pthread.h>
11
12 #include "utils.h"
13
14 #include "Read.h"
15 #include "SingleRead.h"
16 #include "SingleReadQ.h"
17 #include "PairedEndRead.h"
18 #include "PairedEndReadQ.h"
19
20 #include "SingleHit.h"
21 #include "PairedEndHit.h"
22
23 #include "Model.h"
24 #include "SingleModel.h"
25 #include "SingleQModel.h"
26 #include "PairedEndModel.h"
27 #include "PairedEndQModel.h"
28
29 #include "Transcript.h"
30 #include "Transcripts.h"
31
32 #include "Refs.h"
33 #include "GroupInfo.h"
34 #include "HitContainer.h"
35 #include "ReadIndex.h"
36 #include "ReadReader.h"
37
38 #include "ModelParams.h"
39
40 #include "HitWrapper.h"
41 #include "BamWriter.h"
42
43 using namespace std;
44
45 const double STOP_CRITERIA = 0.001;
46 const int MAX_ROUND = 10000;
47 const int MIN_ROUND = 20;
48
49 struct Params {
50         void *model;
51         void *reader, *hitv, *ncpv, *mhp, *countv;
52 };
53
54 int read_type;
55 int m, M; // m genes, M isoforms
56 int N0, N1, N2, N_tot;
57 int nThreads;
58
59
60 bool genBamF; // If user wants to generate bam file, true; otherwise, false.
61 bool updateModel, calcExpectedWeights;
62 bool genGibbsOut; // generate file for Gibbs sampler
63
64 char refName[STRLEN], outName[STRLEN];
65 char imdName[STRLEN], statName[STRLEN];
66 char refF[STRLEN], groupF[STRLEN], cntF[STRLEN], tiF[STRLEN];
67 char mparamsF[STRLEN], bmparamsF[STRLEN];
68 char modelF[STRLEN], thetaF[STRLEN];
69
70 char inpSamType;
71 char *pt_fn_list, *pt_chr_list;
72 char inpSamF[STRLEN], outBamF[STRLEN], fn_list[STRLEN], chr_list[STRLEN];
73
74 char out_for_gibbs_F[STRLEN];
75
76 vector<double> theta, eel; // eel : expected effective length
77
78 double *probv, **countvs;
79
80 Refs refs;
81 GroupInfo gi;
82 Transcripts transcripts;
83
84 ModelParams mparams;
85
86 template<class ReadType, class HitType, class ModelType>
87 void init(ReadReader<ReadType> **&readers, HitContainer<HitType> **&hitvs, double **&ncpvs, ModelType **&mhps) {
88         int nReads, nHits, rt;
89         int nrLeft, nhT, curnr; // nrLeft : number of reads left, nhT : hit threshold per thread, curnr: current number of reads
90         char datF[STRLEN];
91
92         int s;
93         char readFs[2][STRLEN];
94         ReadIndex *indices[2];
95         ifstream fin;
96
97         readers = new ReadReader<ReadType>*[nThreads];
98         genReadFileNames(imdName, 1, read_type, s, readFs);
99         for (int i = 0; i < s; i++) {
100                 indices[i] = new ReadIndex(readFs[i]);
101         }
102         for (int i = 0; i < nThreads; i++) {
103                 readers[i] = new ReadReader<ReadType>(s, readFs);
104                 readers[i]->setIndices(indices);
105         }
106
107         hitvs = new HitContainer<HitType>*[nThreads];
108         for (int i = 0; i < nThreads; i++) {
109                 hitvs[i] = new HitContainer<HitType>();
110         }
111
112         sprintf(datF, "%s.dat", imdName);
113         fin.open(datF);
114         if (!fin.is_open()) { fprintf(stderr, "Cannot open %s! It may not exist.\n", datF); exit(-1); }
115         fin>>nReads>>nHits>>rt;
116         if (nReads != N1) { fprintf(stderr, "Number of alignable reads does not match!\n"); exit(-1); }
117         //assert(nReads == N1);
118         if (rt != read_type) { fprintf(stderr, "Data file (.dat) does not have the right read type!\n"); exit(-1); }
119         //assert(rt == read_type);
120
121         //A just so so strategy for paralleling
122         nhT = nHits / nThreads;
123         nrLeft = N1;
124         curnr = 0;
125
126         ncpvs = new double*[nThreads];
127         for (int i = 0; i < nThreads; i++) {
128                 int ntLeft = nThreads - i - 1; // # of threads left
129                 if (!readers[i]->locate(curnr)) { fprintf(stderr, "Read indices files do not match!\n"); exit(-1); }
130                 //assert(readers[i]->locate(curnr));
131
132                 while (nrLeft > ntLeft && (i == nThreads - 1 || hitvs[i]->getNHits() < nhT)) {
133                         if (!hitvs[i]->read(fin)) { fprintf(stderr, "Cannot read alignments from .dat file!\n"); exit(-1); }
134                         //assert(hitvs[i]->read(fin));
135                         --nrLeft;
136                         if (verbose && nrLeft % 1000000 == 0) { printf("DAT %d reads left!\n", nrLeft); }
137                 }
138                 ncpvs[i] = new double[hitvs[i]->getN()];
139                 memset(ncpvs[i], 0, sizeof(double) * hitvs[i]->getN());
140                 curnr += hitvs[i]->getN();
141
142                 if (verbose) { printf("Thread %d : N = %d, NHit = %d\n", i, hitvs[i]->getN(), hitvs[i]->getNHits()); }
143         }
144
145         fin.close();
146
147         mhps = new ModelType*[nThreads];
148         for (int i = 0; i < nThreads; i++) {
149                 mhps[i] = new ModelType(mparams, false); // just model helper
150         }
151
152         probv = new double[M + 1];
153         countvs = new double*[nThreads];
154         for (int i = 0; i < nThreads; i++) {
155                 countvs[i] = new double[M + 1];
156         }
157
158
159         if (verbose) { printf("EM_init finished!\n"); }
160 }
161
162 template<class ReadType, class HitType, class ModelType>
163 void* E_STEP(void* arg) {
164         Params *params = (Params*)arg;
165         ModelType *model = (ModelType*)(params->model);
166         ReadReader<ReadType> *reader = (ReadReader<ReadType>*)(params->reader);
167         HitContainer<HitType> *hitv = (HitContainer<HitType>*)(params->hitv);
168         double *ncpv = (double*)(params->ncpv);
169         ModelType *mhp = (ModelType*)(params->mhp);
170         double *countv = (double*)(params->countv);
171
172         bool needCalcConPrb = model->getNeedCalcConPrb();
173
174         ReadType read;
175
176         int N = hitv->getN();
177         double sum;
178         vector<double> fracs; //to remove this, do calculation twice
179         int fr, to, id;
180
181         if (needCalcConPrb || updateModel) { reader->reset(); }
182         if (updateModel) { mhp->init(); }
183
184         memset(countv, 0, sizeof(double) * (M + 1));
185         for (int i = 0; i < N; i++) {
186                 if (needCalcConPrb || updateModel) {
187                         if (!reader->next(read)) {
188                                 fprintf(stderr, "Can not load a read!\n");
189                                 exit(-1);
190                         }
191                         //assert(reader->next(read));
192                 }
193                 fr = hitv->getSAt(i);
194                 to = hitv->getSAt(i + 1);
195                 fracs.resize(to - fr + 1);
196
197                 sum = 0.0;
198
199                 if (needCalcConPrb) { ncpv[i] = model->getNoiseConPrb(read); }
200                 fracs[0] = probv[0] * ncpv[i];
201                 if (fracs[0] < EPSILON) fracs[0] = 0.0;
202                 sum += fracs[0];
203                 for (int j = fr; j < to; j++) {
204                         HitType &hit = hitv->getHitAt(j);
205                         if (needCalcConPrb) { hit.setConPrb(model->getConPrb(read, hit)); }
206                         id = j - fr + 1;
207                         fracs[id] = probv[hit.getSid()] * hit.getConPrb();
208                         if (fracs[id] < EPSILON) fracs[id] = 0.0;
209                         sum += fracs[id];
210                 }
211
212                 if (sum >= EPSILON) {
213                         fracs[0] /= sum;
214                         countv[0] += fracs[0];
215                         if (updateModel) { mhp->updateNoise(read, fracs[0]); }
216                         if (calcExpectedWeights) { ncpv[i] = fracs[0]; }
217                         for (int j = fr; j < to; j++) {
218                                 HitType &hit = hitv->getHitAt(j);
219                                 id = j - fr + 1;
220                                 fracs[id] /= sum;
221                                 countv[hit.getSid()] += fracs[id];
222                                 if (updateModel) { mhp->update(read, hit, fracs[id]); }
223                                 if (calcExpectedWeights) { hit.setConPrb(fracs[id]); }
224                         }                       
225                 }
226                 else if (calcExpectedWeights) {
227                         ncpv[i] = 0.0;
228                         for (int j = fr; j < to; j++) {
229                                 HitType &hit = hitv->getHitAt(j);
230                                 hit.setConPrb(0.0);
231                         }
232                 }
233         }
234
235         return NULL;
236 }
237
238 template<class ReadType, class HitType, class ModelType>
239 void* calcConProbs(void* arg) {
240         Params *params = (Params*)arg;
241         ModelType *model = (ModelType*)(params->model);
242         ReadReader<ReadType> *reader = (ReadReader<ReadType>*)(params->reader);
243         HitContainer<HitType> *hitv = (HitContainer<HitType>*)(params->hitv);
244         double *ncpv = (double*)(params->ncpv);
245
246         ReadType read;
247         int N = hitv->getN();
248         int fr, to;
249
250         assert(model->getNeedCalcConPrb());
251         reader->reset();
252
253         for (int i = 0; i < N; i++) {
254                 if (!reader->next(read)) {
255                         fprintf(stderr, "Can not load a read!\n");
256                         exit(-1);
257                 }
258                 fr = hitv->getSAt(i);
259                 to = hitv->getSAt(i + 1);
260
261                 ncpv[i] = model->getNoiseConPrb(read);
262                 for (int j = fr; j < to; j++) {
263                         HitType &hit = hitv->getHitAt(j);
264                         hit.setConPrb(model->getConPrb(read, hit));
265                 }
266         }
267
268         return NULL;
269 }
270
271 template<class ModelType>
272 void calcExpectedEffectiveLengths(ModelType& model) {
273   int lb, ub, span;
274   double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = sigma_{j=1}^{i}pdf[i]*(lb+i)
275   
276   model.getGLD().copyTo(pdf, cdf, lb, ub, span);
277   clen = new double[span + 1];
278   clen[0] = 0.0;
279   for (int i = 1; i <= span; i++) {
280     clen[i] = clen[i - 1] + pdf[i] * (lb + i);
281   }
282
283   eel.clear();
284   eel.resize(M + 1, 0.0);
285   for (int i = 1; i <= M; i++) {
286     int totLen = refs.getRef(i).getTotLen();
287     int fullLen = refs.getRef(i).getFullLen();
288     int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
289     int pos2 = max(min(totLen, ub) - lb, 0);
290
291     if (pos2 == 0) { eel[i] = 0.0; continue; }
292     
293     eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
294     assert(eel[i] >= 0);
295     if (eel[i] < MINEEL) { eel[i] = 0.0; }
296   }
297   
298   delete[] pdf;
299   delete[] cdf;
300   delete[] clen;
301 }
302
303 template<class ModelType>
304 void writeResults(ModelType& model, double* counts) {
305         double denom;
306         char outF[STRLEN];
307         FILE *fo;
308
309         sprintf(modelF, "%s.model", statName);
310         model.write(modelF);
311
312         //calculate tau values
313         double *tau = new double[M + 1];
314         memset(tau, 0, sizeof(double) * (M + 1));
315
316         denom = 0.0;
317         for (int i = 1; i <= M; i++) 
318           if (eel[i] >= EPSILON) {
319             tau[i] = theta[i] / eel[i];
320             denom += tau[i];
321           }   
322         if (denom <= 0) { fprintf(stderr, "No alignable reads?!\n"); exit(-1); }
323         //assert(denom > 0);
324         for (int i = 1; i <= M; i++) {
325                 tau[i] /= denom;
326         }
327
328         //isoform level results
329         sprintf(outF, "%s.iso_res", imdName);
330         fo = fopen(outF, "w");
331         for (int i = 1; i <= M; i++) {
332                 const Transcript& transcript = transcripts.getTranscriptAt(i);
333                 fprintf(fo, "%s%c", transcript.getTranscriptID().c_str(), (i < M ? '\t' : '\n'));
334         }
335         for (int i = 1; i <= M; i++)
336                 fprintf(fo, "%.2f%c", counts[i], (i < M ? '\t' : '\n'));
337         for (int i = 1; i <= M; i++)
338                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
339         for (int i = 1; i <= M; i++) {
340                 const Transcript& transcript = transcripts.getTranscriptAt(i);
341                 fprintf(fo, "%s%c", transcript.getLeft().c_str(), (i < M ? '\t' : '\n'));
342         }
343         fclose(fo);
344
345         //gene level results
346         sprintf(outF, "%s.gene_res", imdName);
347         fo = fopen(outF, "w");
348         for (int i = 0; i < m; i++) {
349                 const string& gene_id = transcripts.getTranscriptAt(gi.spAt(i)).getGeneID();
350                 fprintf(fo, "%s%c", gene_id.c_str(), (i < m - 1 ? '\t' : '\n'));
351         }
352         for (int i = 0; i < m; i++) {
353                 double sumC = 0.0; // sum of counts
354                 int b = gi.spAt(i), e = gi.spAt(i + 1);
355                 for (int j = b; j < e; j++) sumC += counts[j];
356                 fprintf(fo, "%.2f%c", sumC, (i < m - 1 ? '\t' : '\n'));
357         }
358         for (int i = 0; i < m; i++) {
359                 double sumT = 0.0; // sum of tau values
360                 int b = gi.spAt(i), e = gi.spAt(i + 1);
361                 for (int j = b; j < e; j++) sumT += tau[j];
362                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
363         }
364         for (int i = 0; i < m; i++) {
365                 int b = gi.spAt(i), e = gi.spAt(i + 1);
366                 for (int j = b; j < e; j++) {
367                         fprintf(fo, "%s%c", transcripts.getTranscriptAt(j).getTranscriptID().c_str(), (j < e - 1 ? ',' : (i < m - 1 ? '\t' :'\n')));
368                 }
369         }
370         fclose(fo);
371
372         delete[] tau;
373
374         if (verbose) { printf("Expression Results are written!\n"); }
375 }
376
377 template<class ReadType, class HitType, class ModelType>
378 void release(ReadReader<ReadType> **readers, HitContainer<HitType> **hitvs, double **ncpvs, ModelType **mhps) {
379         delete[] probv;
380         for (int i = 0; i < nThreads; i++) {
381                 delete[] countvs[i];
382         }
383         delete[] countvs;
384
385         for (int i = 0; i < nThreads; i++) {
386                 delete readers[i];
387                 delete hitvs[i];
388                 delete[] ncpvs[i];
389                 delete mhps[i];
390         }
391         delete[] readers;
392         delete[] hitvs;
393         delete[] ncpvs;
394         delete[] mhps;
395 }
396
397 inline bool doesUpdateModel(int ROUND) {
398         //return false; // never update, for debugging only
399         return ROUND <= 20 || ROUND % 100 == 0;
400 }
401
402 //Including initialize, algorithm and results saving
403 template<class ReadType, class HitType, class ModelType>
404 void EM() {
405         FILE *fo;
406
407         int ROUND;
408         double sum;
409
410         double bChange = 0.0, change = 0.0; // bChange : biggest change
411         int totNum = 0;
412
413         ModelType model(mparams); //master model
414         ReadReader<ReadType> **readers;
415         HitContainer<HitType> **hitvs;
416         double **ncpvs;
417         ModelType **mhps; //model helpers
418
419         Params fparams[nThreads];
420         pthread_t threads[nThreads];
421         pthread_attr_t attr;
422         void *status;
423         int rc;
424
425
426         //initialize boolean variables
427         updateModel = calcExpectedWeights = false;
428
429         theta.clear();
430         theta.resize(M + 1, 0.0);
431         init<ReadType, HitType, ModelType>(readers, hitvs, ncpvs, mhps);
432
433         //set initial parameters
434         assert(N_tot > N2);
435         theta[0] = max(N0 * 1.0 / (N_tot - N2), 1e-8);
436         double val = (1.0 - theta[0]) / M;
437         for (int i = 1; i <= M; i++) theta[i] = val;
438
439         model.estimateFromReads(imdName);
440
441         for (int i = 0; i < nThreads; i++) {
442                 fparams[i].model = (void*)(&model);
443
444                 fparams[i].reader = (void*)readers[i];
445                 fparams[i].hitv = (void*)hitvs[i];
446                 fparams[i].ncpv = (void*)ncpvs[i];
447                 fparams[i].mhp = (void*)mhps[i];
448                 fparams[i].countv = (void*)countvs[i];
449         }
450
451         /* set thread attribute to be joinable */
452         pthread_attr_init(&attr);
453         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
454
455         ROUND = 0;
456         do {
457                 ++ROUND;
458
459                 updateModel = doesUpdateModel(ROUND);
460
461                 for (int i = 0; i <= M; i++) probv[i] = theta[i];
462
463                 //E step
464                 for (int i = 0; i < nThreads; i++) {
465                         rc = pthread_create(&threads[i], &attr, E_STEP<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
466                         if (rc != 0) { fprintf(stderr, "Cannot create thread %d at ROUND %d! (numbered from 0)\n", i, ROUND); exit(-1); }
467                         //assert(rc == 0);
468                 }
469
470                 for (int i = 0; i < nThreads; i++) {
471                         rc = pthread_join(threads[i], &status);
472                         if (rc != 0) { fprintf(stderr, "Cannot join thread %d at ROUND %d! (numbered from 0)\n", i, ROUND); exit(-1); }
473                         //assert(rc == 0);
474                 }
475
476                 model.setNeedCalcConPrb(false);
477
478                 for (int i = 1; i < nThreads; i++) {
479                         for (int j = 0; j <= M; j++) {
480                                 countvs[0][j] += countvs[i][j];
481                         }
482                 }
483
484                 //add N0 noise reads
485                 countvs[0][0] += N0;
486
487                 //M step;
488                 sum = 0.0;
489                 for (int i = 0; i <= M; i++) sum += countvs[0][i];
490                 assert(sum >= EPSILON);
491                 for (int i = 0; i <= M; i++) theta[i] = countvs[0][i] / sum;
492
493                 if (updateModel) {
494                         model.init();
495                         for (int i = 0; i < nThreads; i++) { model.collect(*mhps[i]); }
496                         model.finish();
497                 }
498
499                 // Relative error
500                 bChange = 0.0; totNum = 0;
501                 for (int i = 0; i <= M; i++)
502                         if (probv[i] >= 1e-7) {
503                                 change = fabs(theta[i] - probv[i]) / probv[i];
504                                 if (change >= STOP_CRITERIA) ++totNum;
505                                 if (bChange < change) bChange = change;
506                         }
507
508                 if (verbose) printf("ROUND = %d, SUM = %.15g, bChange = %f, totNum = %d\n", ROUND, sum, bChange, totNum);
509         } while (ROUND < MIN_ROUND || (totNum > 0 && ROUND < MAX_ROUND));
510           //while (ROUND < MAX_ROUND);
511
512         if (totNum > 0) fprintf(stderr, "Warning: RSEM reaches %d iterations before meeting the convergence criteria.\n", MAX_ROUND);
513
514         //generate output file used by Gibbs sampler
515         if (genGibbsOut) {
516                 if (model.getNeedCalcConPrb()) {
517                         for (int i = 0; i < nThreads; i++) {
518                                 rc = pthread_create(&threads[i], &attr, calcConProbs<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
519                                 if (rc != 0) { fprintf(stderr, "Cannot create thread %d when generate files for Gibbs sampler! (numbered from 0)\n", i); exit(-1); }
520                         }
521                         for (int i = 0; i < nThreads; i++) {
522                                 rc = pthread_join(threads[i], &status);
523                                 if (rc != 0) { fprintf(stderr, "Cannot join thread %d when generate files for Gibbs sampler! (numbered from 0)\n", i); exit(-1); }
524                         }
525                 }
526                 model.setNeedCalcConPrb(false);
527
528                 sprintf(out_for_gibbs_F, "%s.ofg", imdName);
529                 fo = fopen(out_for_gibbs_F, "w");
530                 fprintf(fo, "%d %d\n", M, N0);
531                 for (int i = 0; i < nThreads; i++) {
532                         int numN = hitvs[i]->getN();
533                         for (int j = 0; j < numN; j++) {
534                                 int fr = hitvs[i]->getSAt(j);
535                                 int to = hitvs[i]->getSAt(j + 1);
536                                 int totNum = 0;
537
538                                 if (ncpvs[i][j] >= EPSILON) { ++totNum; fprintf(fo, "%d %.15g ", 0, ncpvs[i][j]); }
539                                 for (int k = fr; k < to; k++) {
540                                         HitType &hit = hitvs[i]->getHitAt(k);
541                                         if (hit.getConPrb() >= EPSILON) {
542                                                 ++totNum;
543                                                 fprintf(fo, "%d %.15g ", hit.getSid(), hit.getConPrb());
544                                         }
545                                 }
546
547                                 if (totNum > 0) { fprintf(fo, "\n"); }
548                         }
549                 }
550                 fclose(fo);
551         }
552
553         sprintf(thetaF, "%s.theta", statName);
554         fo = fopen(thetaF, "w");
555         fprintf(fo, "%d\n", M + 1);
556
557         // output theta'
558         for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
559         fprintf(fo, "%.15g\n", theta[M]);
560         
561         //calculate expected effective lengths for each isoform
562         calcExpectedEffectiveLengths<ModelType>(model);
563
564         //correct theta vector
565         sum = theta[0];
566         for (int i = 1; i <= M; i++) 
567           if (eel[i] < EPSILON) { theta[i] = 0.0; }
568           else sum += theta[i];
569         if (sum < EPSILON) { fprintf(stderr, "No Expected Effective Length is no less than %.6g?!\n", MINEEL); exit(-1); }
570         for (int i = 0; i <= M; i++) theta[i] /= sum;
571
572         //calculate expected weights and counts using learned parameters
573         updateModel = false; calcExpectedWeights = true;
574         for (int i = 0; i <= M; i++) probv[i] = theta[i];
575         for (int i = 0; i < nThreads; i++) {
576                 rc = pthread_create(&threads[i], &attr, E_STEP<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
577                 if (rc != 0) { fprintf(stderr, "Cannot create thread %d when calculate expected weights! (numbered from 0)\n", i); exit(-1); }
578                 //assert(rc == 0);
579         }
580         for (int i = 0; i < nThreads; i++) {
581                 rc = pthread_join(threads[i], &status);
582                 if (rc != 0) { fprintf(stderr, "Cannot join thread %d! (numbered from 0) when calculate expected weights!\n", i); exit(-1); }
583                 //assert(rc == 0);
584         }
585         model.setNeedCalcConPrb(false);
586         for (int i = 1; i < nThreads; i++) {
587                 for (int j = 0; j <= M; j++) {
588                         countvs[0][j] += countvs[i][j];
589                 }
590         }
591         countvs[0][0] += N0;
592
593         /* destroy attribute */
594         pthread_attr_destroy(&attr);
595
596         //convert theta' to theta
597         double *mw = model.getMW();
598         sum = 0.0;
599         for (int i = 0; i <= M; i++) {
600           theta[i] = (mw[i] < EPSILON ? 0.0 : theta[i] / mw[i]);
601           sum += theta[i]; 
602         }
603         assert(sum >= EPSILON);
604         for (int i = 0; i <= M; i++) theta[i] /= sum;
605
606         // output theta
607         for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
608         fprintf(fo, "%.15g\n", theta[M]);
609
610         fclose(fo);
611
612         writeResults<ModelType>(model, countvs[0]);
613
614         if (genBamF) {
615                 sprintf(outBamF, "%s.bam", outName);
616                 if (transcripts.getType() == 0) {
617                         sprintf(chr_list, "%s.chrlist", refName);
618                         pt_chr_list = (char*)(&chr_list);
619                 }
620
621                 BamWriter writer(inpSamType, inpSamF, pt_fn_list, outBamF, pt_chr_list);
622                 HitWrapper<HitType> wrapper(nThreads, hitvs);
623                 writer.work(wrapper, transcripts);
624         }
625
626         release<ReadType, HitType, ModelType>(readers, hitvs, ncpvs, mhps);
627 }
628
629 int main(int argc, char* argv[]) {
630         ifstream fin;
631         bool quiet = false;
632
633         if (argc < 5) {
634                 printf("Usage : rsem-run-em refName read_type sampleName sampleToken [-p #Threads] [-b samInpType samInpF has_fn_list_? [fn_list]] [-q] [--gibbs-out]\n\n");
635                 printf("  refName: reference name\n");
636                 printf("  read_type: 0 single read without quality score; 1 single read with quality score; 2 paired-end read without quality score; 3 paired-end read with quality score.\n");
637                 printf("  sampleName: sample's name, including the path\n");
638                 printf("  sampleToken: sampleName excludes the path\n");
639                 printf("  -p: number of threads which user wants to use. (default: 1)\n");
640                 printf("  -b: produce bam format output file. (default: off)\n");
641                 printf("  -q: set it quiet\n");
642                 printf("  --gibbs-out: generate output file use by Gibbs sampler. (default: off)\n");
643                 printf("// model parameters should be in imdName.mparams.\n");
644                 exit(-1);
645         }
646
647         time_t a = time(NULL);
648
649         strcpy(refName, argv[1]);
650         read_type = atoi(argv[2]);
651         strcpy(outName, argv[3]);
652         sprintf(imdName, "%s.temp/%s", argv[3], argv[4]);
653         sprintf(statName, "%s.stat/%s", argv[3], argv[4]);
654
655         nThreads = 1;
656
657         genBamF = false;
658         genGibbsOut = false;
659         pt_fn_list = pt_chr_list = NULL;
660
661         for (int i = 5; i < argc; i++) {
662                 if (!strcmp(argv[i], "-p")) { nThreads = atoi(argv[i + 1]); }
663                 if (!strcmp(argv[i], "-b")) {
664                         genBamF = true;
665                         inpSamType = argv[i + 1][0];
666                         strcpy(inpSamF, argv[i + 2]);
667                         if (atoi(argv[i + 3]) == 1) {
668                                 strcpy(fn_list, argv[i + 4]);
669                                 pt_fn_list = (char*)(&fn_list);
670                         }
671                 }
672                 if (!strcmp(argv[i], "-q")) { quiet = true; }
673                 if (!strcmp(argv[i], "--gibbs-out")) { genGibbsOut = true; }
674         }
675         if (nThreads <= 0) { fprintf(stderr, "Number of threads should be bigger than 0!\n"); exit(-1); }
676         //assert(nThreads > 0);
677
678         verbose = !quiet;
679
680         //basic info loading
681         sprintf(refF, "%s.seq", refName);
682         refs.loadRefs(refF);
683         M = refs.getM();
684         sprintf(groupF, "%s.grp", refName);
685         gi.load(groupF);
686         m = gi.getm();
687
688         sprintf(tiF, "%s.ti", refName);
689         transcripts.readFrom(tiF);
690
691         sprintf(cntF, "%s.cnt", statName);
692         fin.open(cntF);
693         if (!fin.is_open()) { fprintf(stderr, "Cannot open %s! It may not exist.\n", cntF); exit(-1); }
694         fin>>N0>>N1>>N2>>N_tot;
695         fin.close();
696
697         if (N1 <= 0) { fprintf(stderr, "There are no alignable reads!\n"); exit(-1); }
698
699         if (nThreads > N1) nThreads = N1;
700
701         //set model parameters
702         mparams.M = M;
703         mparams.N[0] = N0; mparams.N[1] = N1; mparams.N[2] = N2;
704         mparams.refs = &refs;
705
706         sprintf(mparamsF, "%s.mparams", imdName);
707         fin.open(mparamsF);
708         if (!fin.is_open()) { fprintf(stderr, "Cannot open %s! It may not exist.\n", mparamsF); exit(-1); }
709         fin>> mparams.minL>> mparams.maxL>> mparams.probF;
710         int val; // 0 or 1 , for estRSPD
711         fin>>val;
712         mparams.estRSPD = (val != 0);
713         fin>> mparams.B>> mparams.mate_minL>> mparams.mate_maxL>> mparams.mean>> mparams.sd;
714         fin>> mparams.seedLen;
715         fin.close();
716
717         //run EM
718         switch(read_type) {
719         case 0 : EM<SingleRead, SingleHit, SingleModel>(); break;
720         case 1 : EM<SingleReadQ, SingleHit, SingleQModel>(); break;
721         case 2 : EM<PairedEndRead, PairedEndHit, PairedEndModel>(); break;
722         case 3 : EM<PairedEndReadQ, PairedEndHit, PairedEndQModel>(); break;
723         default : fprintf(stderr, "Unknown Read Type!\n"); exit(-1);
724         }
725
726         time_t b = time(NULL);
727
728         printTimeUsed(a, b);
729
730         return 0;
731 }