]> git.donarmstrong.com Git - rsem.git/blob - EM.cpp
rsem v1.1.13, speed up EM by only updating model parameters for first 10 iterations...
[rsem.git] / EM.cpp
1 #include<ctime>
2 #include<cmath>
3 #include<cstdio>
4 #include<cstdlib>
5 #include<cstring>
6 #include<cassert>
7 #include<string>
8 #include<vector>
9 #include<algorithm>
10 #include<pthread.h>
11
12 #include "utils.h"
13
14 #include "Read.h"
15 #include "SingleRead.h"
16 #include "SingleReadQ.h"
17 #include "PairedEndRead.h"
18 #include "PairedEndReadQ.h"
19
20 #include "SingleHit.h"
21 #include "PairedEndHit.h"
22
23 #include "Model.h"
24 #include "SingleModel.h"
25 #include "SingleQModel.h"
26 #include "PairedEndModel.h"
27 #include "PairedEndQModel.h"
28
29 #include "Transcript.h"
30 #include "Transcripts.h"
31
32 #include "Refs.h"
33 #include "GroupInfo.h"
34 #include "HitContainer.h"
35 #include "ReadIndex.h"
36 #include "ReadReader.h"
37
38 #include "ModelParams.h"
39
40 #include "HitWrapper.h"
41 #include "BamWriter.h"
42
43 using namespace std;
44
45 const double STOP_CRITERIA = 0.001;
46 const int MAX_ROUND = 10000;
47 const int MIN_ROUND = 20;
48
49 struct Params {
50         void *model;
51         void *reader, *hitv, *ncpv, *mhp, *countv;
52 };
53
54 int read_type;
55 int m, M; // m genes, M isoforms
56 int N0, N1, N2, N_tot;
57 int nThreads;
58
59
60 bool genBamF; // If user wants to generate bam file, true; otherwise, false.
61 bool updateModel, calcExpectedWeights;
62 bool genGibbsOut; // generate file for Gibbs sampler
63
64 char refName[STRLEN], outName[STRLEN];
65 char imdName[STRLEN], statName[STRLEN];
66 char refF[STRLEN], groupF[STRLEN], cntF[STRLEN], tiF[STRLEN];
67 char mparamsF[STRLEN], bmparamsF[STRLEN];
68 char modelF[STRLEN], thetaF[STRLEN];
69
70 char inpSamType;
71 char *pt_fn_list, *pt_chr_list;
72 char inpSamF[STRLEN], outBamF[STRLEN], fn_list[STRLEN], chr_list[STRLEN];
73
74 char out_for_gibbs_F[STRLEN];
75
76 vector<double> theta, eel; // eel : expected effective length
77
78 double *probv, **countvs;
79
80 Refs refs;
81 GroupInfo gi;
82 Transcripts transcripts;
83
84 ModelParams mparams;
85
86 template<class ReadType, class HitType, class ModelType>
87 void init(ReadReader<ReadType> **&readers, HitContainer<HitType> **&hitvs, double **&ncpvs, ModelType **&mhps) {
88         int nReads, nHits, rt;
89         int nrLeft, nhT, curnr; // nrLeft : number of reads left, nhT : hit threshold per thread, curnr: current number of reads
90         char datF[STRLEN];
91
92         int s;
93         char readFs[2][STRLEN];
94         ReadIndex *indices[2];
95         ifstream fin;
96
97         readers = new ReadReader<ReadType>*[nThreads];
98         genReadFileNames(imdName, 1, read_type, s, readFs);
99         for (int i = 0; i < s; i++) {
100                 indices[i] = new ReadIndex(readFs[i]);
101         }
102         for (int i = 0; i < nThreads; i++) {
103                 readers[i] = new ReadReader<ReadType>(s, readFs);
104                 readers[i]->setIndices(indices);
105         }
106
107         hitvs = new HitContainer<HitType>*[nThreads];
108         for (int i = 0; i < nThreads; i++) {
109                 hitvs[i] = new HitContainer<HitType>();
110         }
111
112         sprintf(datF, "%s.dat", imdName);
113         fin.open(datF);
114         if (!fin.is_open()) { fprintf(stderr, "Cannot open %s! It may not exist.\n", datF); exit(-1); }
115         fin>>nReads>>nHits>>rt;
116         if (nReads != N1) { fprintf(stderr, "Number of alignable reads does not match!\n"); exit(-1); }
117         //assert(nReads == N1);
118         if (rt != read_type) { fprintf(stderr, "Data file (.dat) does not have the right read type!\n"); exit(-1); }
119         //assert(rt == read_type);
120
121         //A just so so strategy for paralleling
122         nhT = nHits / nThreads;
123         nrLeft = N1;
124         curnr = 0;
125
126         ncpvs = new double*[nThreads];
127         for (int i = 0; i < nThreads; i++) {
128                 int ntLeft = nThreads - i - 1; // # of threads left
129                 if (!readers[i]->locate(curnr)) { fprintf(stderr, "Read indices files do not match!\n"); exit(-1); }
130                 //assert(readers[i]->locate(curnr));
131
132                 while (nrLeft > ntLeft && (i == nThreads - 1 || hitvs[i]->getNHits() < nhT)) {
133                         if (!hitvs[i]->read(fin)) { fprintf(stderr, "Cannot read alignments from .dat file!\n"); exit(-1); }
134                         //assert(hitvs[i]->read(fin));
135                         --nrLeft;
136                         if (verbose && nrLeft % 1000000 == 0) { printf("DAT %d reads left!\n", nrLeft); }
137                 }
138                 ncpvs[i] = new double[hitvs[i]->getN()];
139                 memset(ncpvs[i], 0, sizeof(double) * hitvs[i]->getN());
140                 curnr += hitvs[i]->getN();
141
142                 if (verbose) { printf("Thread %d : N = %d, NHit = %d\n", i, hitvs[i]->getN(), hitvs[i]->getNHits()); }
143         }
144
145         fin.close();
146
147         mhps = new ModelType*[nThreads];
148         for (int i = 0; i < nThreads; i++) {
149                 mhps[i] = new ModelType(mparams, false); // just model helper
150         }
151
152         probv = new double[M + 1];
153         countvs = new double*[nThreads];
154         for (int i = 0; i < nThreads; i++) {
155                 countvs[i] = new double[M + 1];
156         }
157
158
159         if (verbose) { printf("EM_init finished!\n"); }
160 }
161
162 template<class ReadType, class HitType, class ModelType>
163 void* E_STEP(void* arg) {
164         Params *params = (Params*)arg;
165         ModelType *model = (ModelType*)(params->model);
166         ReadReader<ReadType> *reader = (ReadReader<ReadType>*)(params->reader);
167         HitContainer<HitType> *hitv = (HitContainer<HitType>*)(params->hitv);
168         double *ncpv = (double*)(params->ncpv);
169         ModelType *mhp = (ModelType*)(params->mhp);
170         double *countv = (double*)(params->countv);
171
172         bool needCalcConPrb = model->getNeedCalcConPrb();
173
174         ReadType read;
175
176         int N = hitv->getN();
177         double sum;
178         vector<double> fracs; //to remove this, do calculation twice
179         int fr, to, id;
180
181         if (needCalcConPrb || updateModel) { reader->reset(); }
182         if (updateModel) { mhp->init(); }
183
184         memset(countv, 0, sizeof(double) * (M + 1));
185         for (int i = 0; i < N; i++) {
186                 if (needCalcConPrb || updateModel) {
187                         if (!reader->next(read)) {
188                                 fprintf(stderr, "Can not load a read!\n");
189                                 exit(-1);
190                         }
191                         //assert(reader->next(read));
192                 }
193                 fr = hitv->getSAt(i);
194                 to = hitv->getSAt(i + 1);
195                 fracs.resize(to - fr + 1);
196
197                 sum = 0.0;
198
199                 if (needCalcConPrb) { ncpv[i] = model->getNoiseConPrb(read); }
200                 fracs[0] = probv[0] * ncpv[i];
201                 if (fracs[0] < EPSILON) fracs[0] = 0.0;
202                 sum += fracs[0];
203                 for (int j = fr; j < to; j++) {
204                         HitType &hit = hitv->getHitAt(j);
205                         if (needCalcConPrb) { hit.setConPrb(model->getConPrb(read, hit)); }
206                         id = j - fr + 1;
207                         fracs[id] = probv[hit.getSid()] * hit.getConPrb();
208                         if (fracs[id] < EPSILON) fracs[id] = 0.0;
209                         sum += fracs[id];
210                 }
211
212                 if (sum >= EPSILON) {
213                         fracs[0] /= sum;
214                         countv[0] += fracs[0];
215                         if (updateModel) { mhp->updateNoise(read, fracs[0]); }
216                         if (calcExpectedWeights) { ncpv[i] = fracs[0]; }
217                         for (int j = fr; j < to; j++) {
218                                 HitType &hit = hitv->getHitAt(j);
219                                 id = j - fr + 1;
220                                 fracs[id] /= sum;
221                                 countv[hit.getSid()] += fracs[id];
222                                 if (updateModel) { mhp->update(read, hit, fracs[id]); }
223                                 if (calcExpectedWeights) { hit.setConPrb(fracs[id]); }
224                         }                       
225                 }
226                 else if (calcExpectedWeights) {
227                         ncpv[i] = 0.0;
228                         for (int j = fr; j < to; j++) {
229                                 HitType &hit = hitv->getHitAt(j);
230                                 hit.setConPrb(0.0);
231                         }
232                 }
233         }
234
235         return NULL;
236 }
237
238 template<class ReadType, class HitType, class ModelType>
239 void* calcConProbs(void* arg) {
240         Params *params = (Params*)arg;
241         ModelType *model = (ModelType*)(params->model);
242         ReadReader<ReadType> *reader = (ReadReader<ReadType>*)(params->reader);
243         HitContainer<HitType> *hitv = (HitContainer<HitType>*)(params->hitv);
244         double *ncpv = (double*)(params->ncpv);
245
246         ReadType read;
247         int N = hitv->getN();
248         int fr, to;
249
250         assert(model->getNeedCalcConPrb());
251         reader->reset();
252
253         for (int i = 0; i < N; i++) {
254                 if (!reader->next(read)) {
255                         fprintf(stderr, "Can not load a read!\n");
256                         exit(-1);
257                 }
258                 fr = hitv->getSAt(i);
259                 to = hitv->getSAt(i + 1);
260
261                 ncpv[i] = model->getNoiseConPrb(read);
262                 for (int j = fr; j < to; j++) {
263                         HitType &hit = hitv->getHitAt(j);
264                         hit.setConPrb(model->getConPrb(read, hit));
265                 }
266         }
267
268         return NULL;
269 }
270
271 template<class ModelType>
272 void calcExpectedEffectiveLengths(ModelType& model) {
273   int lb, ub, span;
274   double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = sigma_{j=1}^{i}pdf[i]*(lb+i)
275   
276   model.getGLD().copyTo(pdf, cdf, lb, ub, span);
277   clen = new double[span + 1];
278   clen[0] = 0.0;
279   for (int i = 1; i <= span; i++) {
280     clen[i] = clen[i - 1] + pdf[i] * (lb + i);
281   }
282
283   eel.clear();
284   eel.resize(M + 1, 0.0);
285   for (int i = 1; i <= M; i++) {
286     int totLen = refs.getRef(i).getTotLen();
287     int fullLen = refs.getRef(i).getFullLen();
288     int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
289     int pos2 = max(min(totLen, ub) - lb, 0);
290
291     if (pos2 == 0) { eel[i] = 0.0; continue; }
292     
293     eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
294     assert(eel[i] >= 0);
295     if (eel[i] < MINEEL) { eel[i] = 0.0; }
296   }
297   
298   delete[] pdf;
299   delete[] cdf;
300   delete[] clen;
301 }
302
303 template<class ModelType>
304 void writeResults(ModelType& model, double* counts) {
305         double denom;
306         char outF[STRLEN];
307         FILE *fo;
308
309         sprintf(modelF, "%s.model", statName);
310         model.write(modelF);
311
312         //calculate tau values
313         double *tau = new double[M + 1];
314         memset(tau, 0, sizeof(double) * (M + 1));
315
316         denom = 0.0;
317         for (int i = 1; i <= M; i++) 
318           if (eel[i] >= EPSILON) {
319             tau[i] = theta[i] / eel[i];
320             denom += tau[i];
321           }   
322         if (denom <= 0) { fprintf(stderr, "No alignable reads?!\n"); exit(-1); }
323         //assert(denom > 0);
324         for (int i = 1; i <= M; i++) {
325                 tau[i] /= denom;
326         }
327
328         //isoform level results
329         sprintf(outF, "%s.iso_res", imdName);
330         fo = fopen(outF, "w");
331         for (int i = 1; i <= M; i++) {
332                 const Transcript& transcript = transcripts.getTranscriptAt(i);
333                 fprintf(fo, "%s%c", transcript.getTranscriptID().c_str(), (i < M ? '\t' : '\n'));
334         }
335         for (int i = 1; i <= M; i++)
336                 fprintf(fo, "%.2f%c", counts[i], (i < M ? '\t' : '\n'));
337         for (int i = 1; i <= M; i++)
338                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
339         for (int i = 1; i <= M; i++) {
340                 const Transcript& transcript = transcripts.getTranscriptAt(i);
341                 fprintf(fo, "%s%c", transcript.getLeft().c_str(), (i < M ? '\t' : '\n'));
342         }
343         fclose(fo);
344
345         //gene level results
346         sprintf(outF, "%s.gene_res", imdName);
347         fo = fopen(outF, "w");
348         for (int i = 0; i < m; i++) {
349                 const string& gene_id = transcripts.getTranscriptAt(gi.spAt(i)).getGeneID();
350                 fprintf(fo, "%s%c", gene_id.c_str(), (i < m - 1 ? '\t' : '\n'));
351         }
352         for (int i = 0; i < m; i++) {
353                 double sumC = 0.0; // sum of counts
354                 int b = gi.spAt(i), e = gi.spAt(i + 1);
355                 for (int j = b; j < e; j++) sumC += counts[j];
356                 fprintf(fo, "%.2f%c", sumC, (i < m - 1 ? '\t' : '\n'));
357         }
358         for (int i = 0; i < m; i++) {
359                 double sumT = 0.0; // sum of tau values
360                 int b = gi.spAt(i), e = gi.spAt(i + 1);
361                 for (int j = b; j < e; j++) sumT += tau[j];
362                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
363         }
364         for (int i = 0; i < m; i++) {
365                 int b = gi.spAt(i), e = gi.spAt(i + 1);
366                 for (int j = b; j < e; j++) {
367                         fprintf(fo, "%s%c", transcripts.getTranscriptAt(j).getTranscriptID().c_str(), (j < e - 1 ? ',' : (i < m - 1 ? '\t' :'\n')));
368                 }
369         }
370         fclose(fo);
371
372         delete[] tau;
373
374         if (verbose) { printf("Expression Results are written!\n"); }
375 }
376
377 template<class ReadType, class HitType, class ModelType>
378 void release(ReadReader<ReadType> **readers, HitContainer<HitType> **hitvs, double **ncpvs, ModelType **mhps) {
379         delete[] probv;
380         for (int i = 0; i < nThreads; i++) {
381                 delete[] countvs[i];
382         }
383         delete[] countvs;
384
385         for (int i = 0; i < nThreads; i++) {
386                 delete readers[i];
387                 delete hitvs[i];
388                 delete[] ncpvs[i];
389                 delete mhps[i];
390         }
391         delete[] readers;
392         delete[] hitvs;
393         delete[] ncpvs;
394         delete[] mhps;
395 }
396
397 int tmp_n;
398
399 inline bool doesUpdateModel(int ROUND) {
400   //  return ROUND <= 20 || ROUND % 100 == 0;
401   return ROUND <= 10;
402 }
403
404 //Including initialize, algorithm and results saving
405 template<class ReadType, class HitType, class ModelType>
406 void EM() {
407         FILE *fo;
408
409         int ROUND;
410         double sum;
411
412         double bChange = 0.0, change = 0.0; // bChange : biggest change
413         int totNum = 0;
414
415         ModelType model(mparams); //master model
416         ReadReader<ReadType> **readers;
417         HitContainer<HitType> **hitvs;
418         double **ncpvs;
419         ModelType **mhps; //model helpers
420
421         Params fparams[nThreads];
422         pthread_t threads[nThreads];
423         pthread_attr_t attr;
424         void *status;
425         int rc;
426
427
428         //initialize boolean variables
429         updateModel = calcExpectedWeights = false;
430
431         theta.clear();
432         theta.resize(M + 1, 0.0);
433         init<ReadType, HitType, ModelType>(readers, hitvs, ncpvs, mhps);
434
435         //set initial parameters
436         assert(N_tot > N2);
437         theta[0] = max(N0 * 1.0 / (N_tot - N2), 1e-8);
438         double val = (1.0 - theta[0]) / M;
439         for (int i = 1; i <= M; i++) theta[i] = val;
440
441         model.estimateFromReads(imdName);
442
443         for (int i = 0; i < nThreads; i++) {
444                 fparams[i].model = (void*)(&model);
445
446                 fparams[i].reader = (void*)readers[i];
447                 fparams[i].hitv = (void*)hitvs[i];
448                 fparams[i].ncpv = (void*)ncpvs[i];
449                 fparams[i].mhp = (void*)mhps[i];
450                 fparams[i].countv = (void*)countvs[i];
451         }
452
453         /* set thread attribute to be joinable */
454         pthread_attr_init(&attr);
455         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
456
457         ROUND = 0;
458         do {
459                 ++ROUND;
460
461                 updateModel = doesUpdateModel(ROUND);
462
463                 for (int i = 0; i <= M; i++) probv[i] = theta[i];
464
465                 //E step
466                 for (int i = 0; i < nThreads; i++) {
467                         rc = pthread_create(&threads[i], &attr, E_STEP<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
468                         if (rc != 0) { fprintf(stderr, "Cannot create thread %d at ROUND %d! (numbered from 0)\n", i, ROUND); exit(-1); }
469                         //assert(rc == 0);
470                 }
471
472                 for (int i = 0; i < nThreads; i++) {
473                         rc = pthread_join(threads[i], &status);
474                         if (rc != 0) { fprintf(stderr, "Cannot join thread %d at ROUND %d! (numbered from 0)\n", i, ROUND); exit(-1); }
475                         //assert(rc == 0);
476                 }
477
478                 model.setNeedCalcConPrb(false);
479
480                 for (int i = 1; i < nThreads; i++) {
481                         for (int j = 0; j <= M; j++) {
482                                 countvs[0][j] += countvs[i][j];
483                         }
484                 }
485
486                 //add N0 noise reads
487                 countvs[0][0] += N0;
488
489                 //M step;
490                 sum = 0.0;
491                 for (int i = 0; i <= M; i++) sum += countvs[0][i];
492                 assert(sum >= EPSILON);
493                 for (int i = 0; i <= M; i++) theta[i] = countvs[0][i] / sum;
494
495                 if (updateModel) {
496                         model.init();
497                         for (int i = 0; i < nThreads; i++) { model.collect(*mhps[i]); }
498                         model.finish();
499                 }
500
501                 // Relative error
502                 bChange = 0.0; totNum = 0;
503                 for (int i = 0; i <= M; i++)
504                         if (probv[i] >= 1e-7) {
505                                 change = fabs(theta[i] - probv[i]) / probv[i];
506                                 if (change >= STOP_CRITERIA) ++totNum;
507                                 if (bChange < change) bChange = change;
508                         }
509
510                 if (verbose) printf("ROUND = %d, SUM = %.15g, bChange = %f, totNum = %d\n", ROUND, sum, bChange, totNum);
511         } while (ROUND < MIN_ROUND || (totNum > 0 && ROUND < MAX_ROUND));
512           //while (ROUND < MAX_ROUND);
513
514         if (totNum > 0) fprintf(stderr, "Warning: RSEM reaches %d iterations before meeting the convergence criteria.\n", MAX_ROUND);
515
516         //generate output file used by Gibbs sampler
517         if (genGibbsOut) {
518                 if (model.getNeedCalcConPrb()) {
519                         for (int i = 0; i < nThreads; i++) {
520                                 rc = pthread_create(&threads[i], &attr, calcConProbs<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
521                                 if (rc != 0) { fprintf(stderr, "Cannot create thread %d when generate files for Gibbs sampler! (numbered from 0)\n", i); exit(-1); }
522                         }
523                         for (int i = 0; i < nThreads; i++) {
524                                 rc = pthread_join(threads[i], &status);
525                                 if (rc != 0) { fprintf(stderr, "Cannot join thread %d when generate files for Gibbs sampler! (numbered from 0)\n", i); exit(-1); }
526                         }
527                 }
528                 model.setNeedCalcConPrb(false);
529
530                 sprintf(out_for_gibbs_F, "%s.ofg", imdName);
531                 fo = fopen(out_for_gibbs_F, "w");
532                 fprintf(fo, "%d %d\n", M, N0);
533                 for (int i = 0; i < nThreads; i++) {
534                         int numN = hitvs[i]->getN();
535                         for (int j = 0; j < numN; j++) {
536                                 int fr = hitvs[i]->getSAt(j);
537                                 int to = hitvs[i]->getSAt(j + 1);
538                                 int totNum = 0;
539
540                                 if (ncpvs[i][j] >= EPSILON) { ++totNum; fprintf(fo, "%d %.15g ", 0, ncpvs[i][j]); }
541                                 for (int k = fr; k < to; k++) {
542                                         HitType &hit = hitvs[i]->getHitAt(k);
543                                         if (hit.getConPrb() >= EPSILON) {
544                                                 ++totNum;
545                                                 fprintf(fo, "%d %.15g ", hit.getSid(), hit.getConPrb());
546                                         }
547                                 }
548
549                                 if (totNum > 0) { fprintf(fo, "\n"); }
550                         }
551                 }
552                 fclose(fo);
553         }
554
555         sprintf(thetaF, "%s.theta", statName);
556         fo = fopen(thetaF, "w");
557         fprintf(fo, "%d\n", M + 1);
558
559         // output theta'
560         for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
561         fprintf(fo, "%.15g\n", theta[M]);
562         
563         //calculate expected effective lengths for each isoform
564         calcExpectedEffectiveLengths<ModelType>(model);
565
566         //correct theta vector
567         sum = theta[0];
568         for (int i = 1; i <= M; i++) 
569           if (eel[i] < EPSILON) { theta[i] = 0.0; }
570           else sum += theta[i];
571         if (sum < EPSILON) { fprintf(stderr, "No Expected Effective Length is no less than %.6g?!\n", MINEEL); exit(-1); }
572         for (int i = 0; i <= M; i++) theta[i] /= sum;
573
574         //calculate expected weights and counts using learned parameters
575         updateModel = false; calcExpectedWeights = true;
576         for (int i = 0; i <= M; i++) probv[i] = theta[i];
577         for (int i = 0; i < nThreads; i++) {
578                 rc = pthread_create(&threads[i], &attr, E_STEP<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
579                 if (rc != 0) { fprintf(stderr, "Cannot create thread %d when calculate expected weights! (numbered from 0)\n", i); exit(-1); }
580                 //assert(rc == 0);
581         }
582         for (int i = 0; i < nThreads; i++) {
583                 rc = pthread_join(threads[i], &status);
584                 if (rc != 0) { fprintf(stderr, "Cannot join thread %d! (numbered from 0) when calculate expected weights!\n", i); exit(-1); }
585                 //assert(rc == 0);
586         }
587         model.setNeedCalcConPrb(false);
588         for (int i = 1; i < nThreads; i++) {
589                 for (int j = 0; j <= M; j++) {
590                         countvs[0][j] += countvs[i][j];
591                 }
592         }
593         countvs[0][0] += N0;
594
595         /* destroy attribute */
596         pthread_attr_destroy(&attr);
597
598         //convert theta' to theta
599         double *mw = model.getMW();
600         sum = 0.0;
601         for (int i = 0; i <= M; i++) {
602           theta[i] = (mw[i] < EPSILON ? 0.0 : theta[i] / mw[i]);
603           sum += theta[i]; 
604         }
605         assert(sum >= EPSILON);
606         for (int i = 0; i <= M; i++) theta[i] /= sum;
607
608         // output theta
609         for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
610         fprintf(fo, "%.15g\n", theta[M]);
611
612         fclose(fo);
613
614         writeResults<ModelType>(model, countvs[0]);
615
616         if (genBamF) {
617                 sprintf(outBamF, "%s.bam", outName);
618                 if (transcripts.getType() == 0) {
619                         sprintf(chr_list, "%s.chrlist", refName);
620                         pt_chr_list = (char*)(&chr_list);
621                 }
622
623                 BamWriter writer(inpSamType, inpSamF, pt_fn_list, outBamF, pt_chr_list);
624                 HitWrapper<HitType> wrapper(nThreads, hitvs);
625                 writer.work(wrapper, transcripts);
626         }
627
628         release<ReadType, HitType, ModelType>(readers, hitvs, ncpvs, mhps);
629 }
630
631 int main(int argc, char* argv[]) {
632         ifstream fin;
633         bool quiet = false;
634
635         if (argc < 5) {
636                 printf("Usage : rsem-run-em refName read_type sampleName sampleToken [-p #Threads] [-b samInpType samInpF has_fn_list_? [fn_list]] [-q] [--gibbs-out]\n\n");
637                 printf("  refName: reference name\n");
638                 printf("  read_type: 0 single read without quality score; 1 single read with quality score; 2 paired-end read without quality score; 3 paired-end read with quality score.\n");
639                 printf("  sampleName: sample's name, including the path\n");
640                 printf("  sampleToken: sampleName excludes the path\n");
641                 printf("  -p: number of threads which user wants to use. (default: 1)\n");
642                 printf("  -b: produce bam format output file. (default: off)\n");
643                 printf("  -q: set it quiet\n");
644                 printf("  --gibbs-out: generate output file use by Gibbs sampler. (default: off)\n");
645                 printf("// model parameters should be in imdName.mparams.\n");
646                 exit(-1);
647         }
648
649         time_t a = time(NULL);
650
651         strcpy(refName, argv[1]);
652         read_type = atoi(argv[2]);
653         strcpy(outName, argv[3]);
654         sprintf(imdName, "%s.temp/%s", argv[3], argv[4]);
655         sprintf(statName, "%s.stat/%s", argv[3], argv[4]);
656
657         nThreads = 1;
658
659         genBamF = false;
660         genGibbsOut = false;
661         pt_fn_list = pt_chr_list = NULL;
662
663         for (int i = 5; i < argc; i++) {
664                 if (!strcmp(argv[i], "-p")) { nThreads = atoi(argv[i + 1]); }
665                 if (!strcmp(argv[i], "-b")) {
666                         genBamF = true;
667                         inpSamType = argv[i + 1][0];
668                         strcpy(inpSamF, argv[i + 2]);
669                         if (atoi(argv[i + 3]) == 1) {
670                                 strcpy(fn_list, argv[i + 4]);
671                                 pt_fn_list = (char*)(&fn_list);
672                         }
673                 }
674                 if (!strcmp(argv[i], "-q")) { quiet = true; }
675                 if (!strcmp(argv[i], "--gibbs-out")) { genGibbsOut = true; }
676         }
677         if (nThreads <= 0) { fprintf(stderr, "Number of threads should be bigger than 0!\n"); exit(-1); }
678         //assert(nThreads > 0);
679
680         verbose = !quiet;
681
682         //basic info loading
683         sprintf(refF, "%s.seq", refName);
684         refs.loadRefs(refF);
685         M = refs.getM();
686         sprintf(groupF, "%s.grp", refName);
687         gi.load(groupF);
688         m = gi.getm();
689
690         sprintf(tiF, "%s.ti", refName);
691         transcripts.readFrom(tiF);
692
693         sprintf(cntF, "%s.cnt", statName);
694         fin.open(cntF);
695         if (!fin.is_open()) { fprintf(stderr, "Cannot open %s! It may not exist.\n", cntF); exit(-1); }
696         fin>>N0>>N1>>N2>>N_tot;
697         fin.close();
698
699         if (N1 <= 0) { fprintf(stderr, "There are no alignable reads!\n"); exit(-1); }
700
701         if (nThreads > N1) nThreads = N1;
702
703         //set model parameters
704         mparams.M = M;
705         mparams.N[0] = N0; mparams.N[1] = N1; mparams.N[2] = N2;
706         mparams.refs = &refs;
707
708         sprintf(mparamsF, "%s.mparams", imdName);
709         fin.open(mparamsF);
710         if (!fin.is_open()) { fprintf(stderr, "Cannot open %s! It may not exist.\n", mparamsF); exit(-1); }
711         fin>> mparams.minL>> mparams.maxL>> mparams.probF;
712         int val; // 0 or 1 , for estRSPD
713         fin>>val;
714         mparams.estRSPD = (val != 0);
715         fin>> mparams.B>> mparams.mate_minL>> mparams.mate_maxL>> mparams.mean>> mparams.sd;
716         fin>> mparams.seedLen;
717         fin.close();
718
719         //run EM
720         switch(read_type) {
721         case 0 : EM<SingleRead, SingleHit, SingleModel>(); break;
722         case 1 : EM<SingleReadQ, SingleHit, SingleQModel>(); break;
723         case 2 : EM<PairedEndRead, PairedEndHit, PairedEndModel>(); break;
724         case 3 : EM<PairedEndReadQ, PairedEndHit, PairedEndQModel>(); break;
725         default : fprintf(stderr, "Unknown Read Type!\n"); exit(-1);
726         }
727
728         time_t b = time(NULL);
729
730         printTimeUsed(a, b);
731
732         return 0;
733 }