]> git.donarmstrong.com Git - rsem.git/blob - EM.cpp
Modified the acknowledgement section of README.md
[rsem.git] / EM.cpp
1 #include<ctime>
2 #include<cmath>
3 #include<cstdio>
4 #include<cstdlib>
5 #include<cstring>
6 #include<cassert>
7 #include<string>
8 #include<vector>
9 #include<algorithm>
10 #include<fstream>
11 #include<iostream>
12 #include<pthread.h>
13
14 #include "utils.h"
15 #include "my_assert.h"
16 #include "sampling.h"
17
18 #include "Read.h"
19 #include "SingleRead.h"
20 #include "SingleReadQ.h"
21 #include "PairedEndRead.h"
22 #include "PairedEndReadQ.h"
23
24 #include "SingleHit.h"
25 #include "PairedEndHit.h"
26
27 #include "Model.h"
28 #include "SingleModel.h"
29 #include "SingleQModel.h"
30 #include "PairedEndModel.h"
31 #include "PairedEndQModel.h"
32
33 #include "Transcript.h"
34 #include "Transcripts.h"
35
36 #include "Refs.h"
37 #include "GroupInfo.h"
38 #include "HitContainer.h"
39 #include "ReadIndex.h"
40 #include "ReadReader.h"
41
42 #include "ModelParams.h"
43
44 #include "HitWrapper.h"
45 #include "BamWriter.h"
46
47 #include "WriteResults.h"
48
49 using namespace std;
50
51 const double STOP_CRITERIA = 0.001;
52 const int MAX_ROUND = 10000;
53 const int MIN_ROUND = 20;
54
55 struct Params {
56         void *model;
57         void *reader, *hitv, *ncpv, *mhp, *countv;
58 };
59
60 int read_type;
61 int m, M; // m genes, M isoforms
62 READ_INT_TYPE N0, N1, N2, N_tot;
63 int nThreads;
64
65
66 bool genBamF; // If user wants to generate bam file, true; otherwise, false.
67 bool bamSampling; // true if sampling from read posterior distribution when bam file is generated
68 bool updateModel, calcExpectedWeights;
69 bool genGibbsOut; // generate file for Gibbs sampler
70
71 char refName[STRLEN], outName[STRLEN];
72 char imdName[STRLEN], statName[STRLEN];
73 char refF[STRLEN], cntF[STRLEN], tiF[STRLEN];
74 char mparamsF[STRLEN];
75 char modelF[STRLEN], thetaF[STRLEN];
76
77 char inpSamType;
78 char *pt_fn_list, *pt_chr_list;
79 char inpSamF[STRLEN], outBamF[STRLEN], fn_list[STRLEN], chr_list[STRLEN];
80
81 char out_for_gibbs_F[STRLEN];
82
83 vector<double> theta, eel; // eel : expected effective length
84
85 double *probv, **countvs;
86
87 Refs refs;
88 Transcripts transcripts;
89
90 ModelParams mparams;
91
92 template<class ReadType, class HitType, class ModelType>
93 void init(ReadReader<ReadType> **&readers, HitContainer<HitType> **&hitvs, double **&ncpvs, ModelType **&mhps) {
94         READ_INT_TYPE nReads;
95         HIT_INT_TYPE nHits;
96         int rt; // read type
97
98         READ_INT_TYPE nrLeft, curnr; // nrLeft : number of reads left, curnr: current number of reads
99         HIT_INT_TYPE nhT; // nhT : hit threshold per thread
100         char datF[STRLEN];
101
102         int s;
103         char readFs[2][STRLEN];
104         ReadIndex *indices[2];
105         ifstream fin;
106
107         readers = new ReadReader<ReadType>*[nThreads];
108         genReadFileNames(imdName, 1, read_type, s, readFs);
109         for (int i = 0; i < s; i++) {
110                 indices[i] = new ReadIndex(readFs[i]);
111         }
112         for (int i = 0; i < nThreads; i++) {
113                 readers[i] = new ReadReader<ReadType>(s, readFs, refs.hasPolyA(), mparams.seedLen); // allow calculation of calc_lq() function
114                 readers[i]->setIndices(indices);
115         }
116
117         hitvs = new HitContainer<HitType>*[nThreads];
118         for (int i = 0; i < nThreads; i++) {
119                 hitvs[i] = new HitContainer<HitType>();
120         }
121
122         sprintf(datF, "%s.dat", imdName);
123         fin.open(datF);
124         general_assert(fin.is_open(), "Cannot open " + cstrtos(datF) + "! It may not exist.");
125         fin>>nReads>>nHits>>rt;
126         general_assert(nReads == N1, "Number of alignable reads does not match!");
127         general_assert(rt == read_type, "Data file (.dat) does not have the right read type!");
128
129
130         //A just so so strategy for paralleling
131         nhT = nHits / nThreads;
132         nrLeft = N1;
133         curnr = 0;
134
135         ncpvs = new double*[nThreads];
136         for (int i = 0; i < nThreads; i++) {
137                 HIT_INT_TYPE ntLeft = nThreads - i - 1; // # of threads left
138
139                 general_assert(readers[i]->locate(curnr), "Read indices files do not match!");
140
141                 while (nrLeft > ntLeft && (i == nThreads - 1 || hitvs[i]->getNHits() < nhT)) {
142                         general_assert(hitvs[i]->read(fin), "Cannot read alignments from .dat file!");
143
144                         --nrLeft;
145                         if (verbose && nrLeft % 1000000 == 0) { cout<< "DAT "<< nrLeft << " reads left"<< endl; }
146                 }
147                 ncpvs[i] = new double[hitvs[i]->getN()];
148                 memset(ncpvs[i], 0, sizeof(double) * hitvs[i]->getN());
149                 curnr += hitvs[i]->getN();
150
151                 if (verbose) { cout<<"Thread "<< i<< " : N = "<< hitvs[i]->getN()<< ", NHit = "<< hitvs[i]->getNHits()<< endl; }
152         }
153
154         fin.close();
155
156         mhps = new ModelType*[nThreads];
157         for (int i = 0; i < nThreads; i++) {
158                 mhps[i] = new ModelType(mparams, false); // just model helper
159         }
160
161         probv = new double[M + 1];
162         countvs = new double*[nThreads];
163         for (int i = 0; i < nThreads; i++) {
164                 countvs[i] = new double[M + 1];
165         }
166
167
168         if (verbose) { printf("EM_init finished!\n"); }
169 }
170
171 template<class ReadType, class HitType, class ModelType>
172 void* E_STEP(void* arg) {
173         Params *params = (Params*)arg;
174         ModelType *model = (ModelType*)(params->model);
175         ReadReader<ReadType> *reader = (ReadReader<ReadType>*)(params->reader);
176         HitContainer<HitType> *hitv = (HitContainer<HitType>*)(params->hitv);
177         double *ncpv = (double*)(params->ncpv);
178         ModelType *mhp = (ModelType*)(params->mhp);
179         double *countv = (double*)(params->countv);
180
181         bool needCalcConPrb = model->getNeedCalcConPrb();
182
183         ReadType read;
184
185         READ_INT_TYPE N = hitv->getN();
186         double sum;
187         vector<double> fracs; //to remove this, do calculation twice
188         HIT_INT_TYPE fr, to, id;
189
190         if (needCalcConPrb || updateModel) { reader->reset(); }
191         if (updateModel) { mhp->init(); }
192
193         memset(countv, 0, sizeof(double) * (M + 1));
194         for (READ_INT_TYPE i = 0; i < N; i++) {
195                 if (needCalcConPrb || updateModel) {
196                         general_assert(reader->next(read), "Can not load a read!");
197                 }
198
199                 fr = hitv->getSAt(i);
200                 to = hitv->getSAt(i + 1);
201                 fracs.resize(to - fr + 1);
202
203                 sum = 0.0;
204
205                 if (needCalcConPrb) { ncpv[i] = model->getNoiseConPrb(read); }
206                 fracs[0] = probv[0] * ncpv[i];
207                 if (fracs[0] < EPSILON) fracs[0] = 0.0;
208                 sum += fracs[0];
209                 for (HIT_INT_TYPE j = fr; j < to; j++) {
210                         HitType &hit = hitv->getHitAt(j);
211                         if (needCalcConPrb) { hit.setConPrb(model->getConPrb(read, hit)); }
212                         id = j - fr + 1;
213                         fracs[id] = probv[hit.getSid()] * hit.getConPrb();
214                         if (fracs[id] < EPSILON) fracs[id] = 0.0;
215                         sum += fracs[id];
216                 }
217
218                 if (sum >= EPSILON) {
219                         fracs[0] /= sum;
220                         countv[0] += fracs[0];
221                         if (updateModel) { mhp->updateNoise(read, fracs[0]); }
222                         if (calcExpectedWeights) { ncpv[i] = fracs[0]; }
223                         for (HIT_INT_TYPE j = fr; j < to; j++) {
224                                 HitType &hit = hitv->getHitAt(j);
225                                 id = j - fr + 1;
226                                 fracs[id] /= sum;
227                                 countv[hit.getSid()] += fracs[id];
228                                 if (updateModel) { mhp->update(read, hit, fracs[id]); }
229                                 if (calcExpectedWeights) { hit.setConPrb(fracs[id]); }
230                         }                       
231                 }
232                 else if (calcExpectedWeights) {
233                         ncpv[i] = 0.0;
234                         for (HIT_INT_TYPE j = fr; j < to; j++) {
235                                 HitType &hit = hitv->getHitAt(j);
236                                 hit.setConPrb(0.0);
237                         }
238                 }
239         }
240
241         return NULL;
242 }
243
244 template<class ReadType, class HitType, class ModelType>
245 void* calcConProbs(void* arg) {
246         Params *params = (Params*)arg;
247         ModelType *model = (ModelType*)(params->model);
248         ReadReader<ReadType> *reader = (ReadReader<ReadType>*)(params->reader);
249         HitContainer<HitType> *hitv = (HitContainer<HitType>*)(params->hitv);
250         double *ncpv = (double*)(params->ncpv);
251
252         ReadType read;
253         READ_INT_TYPE N = hitv->getN();
254         HIT_INT_TYPE fr, to;
255
256         assert(model->getNeedCalcConPrb());
257         reader->reset();
258
259         for (READ_INT_TYPE i = 0; i < N; i++) {
260                 general_assert(reader->next(read), "Can not load a read!");
261
262                 fr = hitv->getSAt(i);
263                 to = hitv->getSAt(i + 1);
264
265                 ncpv[i] = model->getNoiseConPrb(read);
266                 for (HIT_INT_TYPE j = fr; j < to; j++) {
267                         HitType &hit = hitv->getHitAt(j);
268                         hit.setConPrb(model->getConPrb(read, hit));
269                 }
270         }
271
272         return NULL;
273 }
274
275 template<class ModelType>
276 void writeResults(ModelType& model, double* counts) {
277   sprintf(modelF, "%s.model", statName);
278   model.write(modelF);
279   writeResultsEM(M, refName, imdName, transcripts, theta, eel, countvs[0]);
280 }
281
282 template<class ReadType, class HitType, class ModelType>
283 void release(ReadReader<ReadType> **readers, HitContainer<HitType> **hitvs, double **ncpvs, ModelType **mhps) {
284         delete[] probv;
285         for (int i = 0; i < nThreads; i++) {
286                 delete[] countvs[i];
287         }
288         delete[] countvs;
289
290         for (int i = 0; i < nThreads; i++) {
291                 delete readers[i];
292                 delete hitvs[i];
293                 delete[] ncpvs[i];
294                 delete mhps[i];
295         }
296         delete[] readers;
297         delete[] hitvs;
298         delete[] ncpvs;
299         delete[] mhps;
300 }
301
302 inline bool doesUpdateModel(int ROUND) {
303   //  return ROUND <= 20 || ROUND % 100 == 0;
304   return ROUND <= 10;
305 }
306
307 //Including initialize, algorithm and results saving
308 template<class ReadType, class HitType, class ModelType>
309 void EM() {
310         FILE *fo;
311
312         int ROUND;
313         double sum;
314
315         double bChange = 0.0, change = 0.0; // bChange : biggest change
316         int totNum = 0;
317
318         ModelType model(mparams); //master model
319         ReadReader<ReadType> **readers;
320         HitContainer<HitType> **hitvs;
321         double **ncpvs;
322         ModelType **mhps; //model helpers
323
324         Params fparams[nThreads];
325         pthread_t threads[nThreads];
326         pthread_attr_t attr;
327         int rc;
328
329
330         //initialize boolean variables
331         updateModel = calcExpectedWeights = false;
332
333         theta.clear();
334         theta.resize(M + 1, 0.0);
335         init<ReadType, HitType, ModelType>(readers, hitvs, ncpvs, mhps);
336
337         //set initial parameters
338         assert(N_tot > N2);
339         theta[0] = max(N0 * 1.0 / (N_tot - N2), 1e-8);
340         double val = (1.0 - theta[0]) / M;
341         for (int i = 1; i <= M; i++) theta[i] = val;
342
343         model.estimateFromReads(imdName);
344
345         for (int i = 0; i < nThreads; i++) {
346                 fparams[i].model = (void*)(&model);
347
348                 fparams[i].reader = (void*)readers[i];
349                 fparams[i].hitv = (void*)hitvs[i];
350                 fparams[i].ncpv = (void*)ncpvs[i];
351                 fparams[i].mhp = (void*)mhps[i];
352                 fparams[i].countv = (void*)countvs[i];
353         }
354
355         /* set thread attribute to be joinable */
356         pthread_attr_init(&attr);
357         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
358
359         ROUND = 0;
360         do {
361                 ++ROUND;
362
363                 updateModel = doesUpdateModel(ROUND);
364
365                 for (int i = 0; i <= M; i++) probv[i] = theta[i];
366
367                 //E step
368                 for (int i = 0; i < nThreads; i++) {
369                         rc = pthread_create(&threads[i], &attr, E_STEP<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
370                         pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0) at ROUND " + itos(ROUND) + "!");
371                 }
372
373                 for (int i = 0; i < nThreads; i++) {
374                         rc = pthread_join(threads[i], NULL);
375                         pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0) at ROUND " + itos(ROUND) + "!");
376                 }
377
378                 model.setNeedCalcConPrb(false);
379
380                 for (int i = 1; i < nThreads; i++) {
381                         for (int j = 0; j <= M; j++) {
382                                 countvs[0][j] += countvs[i][j];
383                         }
384                 }
385
386                 //add N0 noise reads
387                 countvs[0][0] += N0;
388
389                 //M step;
390                 sum = 0.0;
391                 for (int i = 0; i <= M; i++) sum += countvs[0][i];
392                 assert(sum >= EPSILON);
393                 for (int i = 0; i <= M; i++) theta[i] = countvs[0][i] / sum;
394
395                 if (updateModel) {
396                         model.init();
397                         for (int i = 0; i < nThreads; i++) { model.collect(*mhps[i]); }
398                         model.finish();
399                 }
400
401                 // Relative error
402                 bChange = 0.0; totNum = 0;
403                 for (int i = 0; i <= M; i++)
404                         if (probv[i] >= 1e-7) {
405                                 change = fabs(theta[i] - probv[i]) / probv[i];
406                                 if (change >= STOP_CRITERIA) ++totNum;
407                                 if (bChange < change) bChange = change;
408                         }
409
410                 if (verbose) { cout<< "ROUND = "<< ROUND<< ", SUM = "<< setprecision(15)<< sum<< ", bChange = " << setprecision(6)<< bChange<< ", totNum = " << totNum<< endl; }
411         } while (ROUND < MIN_ROUND || (totNum > 0 && ROUND < MAX_ROUND));
412 //      } while (ROUND < 1);
413
414         if (totNum > 0) { cout<< "Warning: RSEM reaches "<< MAX_ROUND<< " iterations before meeting the convergence criteria."<< endl; }
415
416         //generate output file used by Gibbs sampler
417         if (genGibbsOut) {
418                 if (model.getNeedCalcConPrb()) {
419                         for (int i = 0; i < nThreads; i++) {
420                                 rc = pthread_create(&threads[i], &attr, calcConProbs<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
421                                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0) when generating files for Gibbs sampler!");
422                         }
423                         for (int i = 0; i < nThreads; i++) {
424                                 rc = pthread_join(threads[i], NULL);
425                                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0) when generating files for Gibbs sampler!");
426                         }
427                 }
428                 model.setNeedCalcConPrb(false);
429
430                 sprintf(out_for_gibbs_F, "%s.ofg", imdName);
431                 ofstream fout(out_for_gibbs_F);
432                 fout<< M<< " "<< N0<< endl;
433                 for (int i = 0; i < nThreads; i++) {
434                         READ_INT_TYPE numN = hitvs[i]->getN();
435                         for (READ_INT_TYPE j = 0; j < numN; j++) {
436                                 HIT_INT_TYPE fr = hitvs[i]->getSAt(j);
437                                 HIT_INT_TYPE to = hitvs[i]->getSAt(j + 1);
438                                 HIT_INT_TYPE totNum = 0;
439
440                                 if (ncpvs[i][j] >= EPSILON) { ++totNum; fout<< "0 "<< setprecision(15)<< ncpvs[i][j]<< " "; }
441                                 for (HIT_INT_TYPE k = fr; k < to; k++) {
442                                         HitType &hit = hitvs[i]->getHitAt(k);
443                                         if (hit.getConPrb() >= EPSILON) {
444                                                 ++totNum;
445                                                 fout<< hit.getSid()<< " "<< setprecision(15)<< hit.getConPrb()<< " ";
446                                         }
447                                 }
448
449                                 if (totNum > 0) { fout<< endl; }
450                         }
451                 }
452                 fout.close();
453         }
454
455         //calculate expected weights and counts using learned parameters
456         //just use the raw theta learned from the data, do not correct for eel or mw
457         updateModel = false; calcExpectedWeights = true;
458         for (int i = 0; i <= M; i++) probv[i] = theta[i];
459         for (int i = 0; i < nThreads; i++) {
460                 rc = pthread_create(&threads[i], &attr, E_STEP<ReadType, HitType, ModelType>, (void*)(&fparams[i]));
461                 pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0) when calculating expected weights!");
462         }
463         for (int i = 0; i < nThreads; i++) {
464                 rc = pthread_join(threads[i], NULL);
465                 pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0) when calculating expected weights!");
466         }
467         model.setNeedCalcConPrb(false);
468         for (int i = 1; i < nThreads; i++) {
469                 for (int j = 0; j <= M; j++) {
470                         countvs[0][j] += countvs[i][j];
471                 }
472         }
473         countvs[0][0] += N0;
474
475         /* destroy attribute */
476         pthread_attr_destroy(&attr);
477
478
479         sprintf(thetaF, "%s.theta", statName);
480         fo = fopen(thetaF, "w");
481         fprintf(fo, "%d\n", M + 1);
482
483         // output theta'
484         for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
485         fprintf(fo, "%.15g\n", theta[M]);
486         
487         //calculate expected effective lengths for each isoform
488         calcExpectedEffectiveLengths<ModelType>(M, refs, model, eel);
489         polishTheta(M, theta, eel, model.getMW());
490
491         // output theta
492         for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
493         fprintf(fo, "%.15g\n", theta[M]);
494
495         fclose(fo);
496
497         writeResults<ModelType>(model, countvs[0]);
498
499         if (genBamF) {
500                 sprintf(outBamF, "%s.transcript.bam", outName);
501                 
502                 if (bamSampling) {
503                         READ_INT_TYPE local_N;
504                         HIT_INT_TYPE fr, to, len, id;
505                         vector<double> arr;
506                         uniform01 rg(engine_type(time(NULL)));
507
508                         if (verbose) cout<< "Begin to sample reads from their posteriors."<< endl;
509                         for (int i = 0; i < nThreads; i++) {
510                                 local_N = hitvs[i]->getN();
511                                 for (READ_INT_TYPE j = 0; j < local_N; j++) {
512                                         fr = hitvs[i]->getSAt(j);
513                                         to = hitvs[i]->getSAt(j + 1);
514                                         len = to - fr + 1;
515                                         arr.assign(len, 0);
516                                         arr[0] = ncpvs[i][j];
517                                         for (HIT_INT_TYPE k = fr; k < to; k++) arr[k - fr + 1] = arr[k - fr] + hitvs[i]->getHitAt(k).getConPrb();
518                                         id = (arr[len - 1] < EPSILON ? -1 : sample(rg, arr, len)); // if all entries in arr are 0, let id be -1
519                                         for (HIT_INT_TYPE k = fr; k < to; k++) hitvs[i]->getHitAt(k).setConPrb(k - fr + 1 == id ? 1.0 : 0.0);
520                                 }
521                         }
522
523                         if (verbose) cout<< "Sampling is finished."<< endl;
524                 }
525
526                 BamWriter writer(inpSamType, inpSamF, pt_fn_list, outBamF, transcripts);
527                 HitWrapper<HitType> wrapper(nThreads, hitvs);
528                 writer.work(wrapper);
529         }
530
531         release<ReadType, HitType, ModelType>(readers, hitvs, ncpvs, mhps);
532 }
533
534 int main(int argc, char* argv[]) {
535         ifstream fin;
536         bool quiet = false;
537
538         if (argc < 6) {
539                 printf("Usage : rsem-run-em refName read_type sampleName imdName statName [-p #Threads] [-b samInpType samInpF has_fn_list_? [fn_list]] [-q] [--gibbs-out] [--sampling]\n\n");
540                 printf("  refName: reference name\n");
541                 printf("  read_type: 0 single read without quality score; 1 single read with quality score; 2 paired-end read without quality score; 3 paired-end read with quality score.\n");
542                 printf("  sampleName: sample's name, including the path\n");
543                 printf("  sampleToken: sampleName excludes the path\n");
544                 printf("  -p: number of threads which user wants to use. (default: 1)\n");
545                 printf("  -b: produce bam format output file. (default: off)\n");
546                 printf("  -q: set it quiet\n");
547                 printf("  --gibbs-out: generate output file used by Gibbs sampler. (default: off)\n");
548                 printf("  --sampling: sample each read from its posterior distribution when bam file is generated. (default: off)\n");
549                 printf("// model parameters should be in imdName.mparams.\n");
550                 exit(-1);
551         }
552
553         time_t a = time(NULL);
554
555         strcpy(refName, argv[1]);
556         read_type = atoi(argv[2]);
557         strcpy(outName, argv[3]);
558         strcpy(imdName, argv[4]);
559         strcpy(statName, argv[5]);
560
561         nThreads = 1;
562
563         genBamF = false;
564         bamSampling = false;
565         genGibbsOut = false;
566         pt_fn_list = pt_chr_list = NULL;
567
568         for (int i = 6; i < argc; i++) {
569                 if (!strcmp(argv[i], "-p")) { nThreads = atoi(argv[i + 1]); }
570                 if (!strcmp(argv[i], "-b")) {
571                         genBamF = true;
572                         inpSamType = argv[i + 1][0];
573                         strcpy(inpSamF, argv[i + 2]);
574                         if (atoi(argv[i + 3]) == 1) {
575                                 strcpy(fn_list, argv[i + 4]);
576                                 pt_fn_list = (char*)(&fn_list);
577                         }
578                 }
579                 if (!strcmp(argv[i], "-q")) { quiet = true; }
580                 if (!strcmp(argv[i], "--gibbs-out")) { genGibbsOut = true; }
581                 if (!strcmp(argv[i], "--sampling")) { bamSampling = true; }
582         }
583
584         general_assert(nThreads > 0, "Number of threads should be bigger than 0!");
585
586         verbose = !quiet;
587
588         //basic info loading
589         sprintf(refF, "%s.seq", refName);
590         refs.loadRefs(refF);
591         M = refs.getM();
592
593         sprintf(tiF, "%s.ti", refName);
594         transcripts.readFrom(tiF);
595
596         sprintf(cntF, "%s.cnt", statName);
597         fin.open(cntF);
598
599         general_assert(fin.is_open(), "Cannot open " + cstrtos(cntF) + "! It may not exist.");
600
601         fin>>N0>>N1>>N2>>N_tot;
602         fin.close();
603
604         general_assert(N1 > 0, "There are no alignable reads!");
605
606         if ((READ_INT_TYPE)nThreads > N1) nThreads = N1;
607
608         //set model parameters
609         mparams.M = M;
610         mparams.N[0] = N0; mparams.N[1] = N1; mparams.N[2] = N2;
611         mparams.refs = &refs;
612
613         sprintf(mparamsF, "%s.mparams", imdName);
614         fin.open(mparamsF);
615
616         general_assert(fin.is_open(), "Cannot open " + cstrtos(mparamsF) + "It may not exist.");
617
618         fin>> mparams.minL>> mparams.maxL>> mparams.probF;
619         int val; // 0 or 1 , for estRSPD
620         fin>>val;
621         mparams.estRSPD = (val != 0);
622         fin>> mparams.B>> mparams.mate_minL>> mparams.mate_maxL>> mparams.mean>> mparams.sd;
623         fin>> mparams.seedLen;
624         fin.close();
625
626         //run EM
627         switch(read_type) {
628         case 0 : EM<SingleRead, SingleHit, SingleModel>(); break;
629         case 1 : EM<SingleReadQ, SingleHit, SingleQModel>(); break;
630         case 2 : EM<PairedEndRead, PairedEndHit, PairedEndModel>(); break;
631         case 3 : EM<PairedEndReadQ, PairedEndHit, PairedEndQModel>(); break;
632         default : fprintf(stderr, "Unknown Read Type!\n"); exit(-1);
633         }
634
635         time_t b = time(NULL);
636
637         printTimeUsed(a, b, "EM.cpp");
638
639         return 0;
640 }