X-Git-Url: https://git.donarmstrong.com/?p=rsem.git;a=blobdiff_plain;f=EM.cpp;h=e106335f7b6a17b31e3b479a09849518802907cb;hp=f8849790d5a35b3f3dbeaf16b26a2415558b1f0e;hb=9eef8b58056b7cdaad1b4bdb2b2904d9fc0ff430;hpb=f67ec16ff8add74c17df026f77cf39e5a1aca051 diff --git a/EM.cpp b/EM.cpp index f884979..e106335 100644 --- a/EM.cpp +++ b/EM.cpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include "utils.h" @@ -55,7 +57,7 @@ struct Params { int read_type; int m, M; // m genes, M isoforms -int N0, N1, N2, N_tot; +READ_INT_TYPE N0, N1, N2, N_tot; int nThreads; @@ -67,7 +69,7 @@ bool genGibbsOut; // generate file for Gibbs sampler char refName[STRLEN], outName[STRLEN]; char imdName[STRLEN], statName[STRLEN]; char refF[STRLEN], groupF[STRLEN], cntF[STRLEN], tiF[STRLEN]; -char mparamsF[STRLEN], bmparamsF[STRLEN]; +char mparamsF[STRLEN]; char modelF[STRLEN], thetaF[STRLEN]; char inpSamType; @@ -88,8 +90,12 @@ ModelParams mparams; template void init(ReadReader **&readers, HitContainer **&hitvs, double **&ncpvs, ModelType **&mhps) { - int nReads, nHits, rt; - int nrLeft, nhT, curnr; // nrLeft : number of reads left, nhT : hit threshold per thread, curnr: current number of reads + READ_INT_TYPE nReads; + HIT_INT_TYPE nHits; + int rt; // read type + + READ_INT_TYPE nrLeft, curnr; // nrLeft : number of reads left, curnr: current number of reads + HIT_INT_TYPE nhT; // nhT : hit threshold per thread char datF[STRLEN]; int s; @@ -127,7 +133,7 @@ void init(ReadReader **&readers, HitContainer **&hitvs, doubl ncpvs = new double*[nThreads]; for (int i = 0; i < nThreads; i++) { - int ntLeft = nThreads - i - 1; // # of threads left + HIT_INT_TYPE ntLeft = nThreads - i - 1; // # of threads left general_assert(readers[i]->locate(curnr), "Read indices files do not match!"); @@ -135,13 +141,13 @@ void init(ReadReader **&readers, HitContainer **&hitvs, doubl general_assert(hitvs[i]->read(fin), "Cannot read alignments from .dat file!"); --nrLeft; - if (verbose && nrLeft % 1000000 == 0) { printf("DAT %d reads left!\n", nrLeft); } + if (verbose && nrLeft % 1000000 == 0) { cout<< "DAT "<< nrLeft << " reads left"<< endl; } } ncpvs[i] = new double[hitvs[i]->getN()]; memset(ncpvs[i], 0, sizeof(double) * hitvs[i]->getN()); curnr += hitvs[i]->getN(); - if (verbose) { printf("Thread %d : N = %d, NHit = %d\n", i, hitvs[i]->getN(), hitvs[i]->getNHits()); } + if (verbose) { cout<<"Thread "<< i<< " : N = "<< hitvs[i]->getN()<< ", NHit = "<< hitvs[i]->getNHits()<< endl; } } fin.close(); @@ -175,16 +181,16 @@ void* E_STEP(void* arg) { ReadType read; - int N = hitv->getN(); + READ_INT_TYPE N = hitv->getN(); double sum; vector fracs; //to remove this, do calculation twice - int fr, to, id; + HIT_INT_TYPE fr, to, id; if (needCalcConPrb || updateModel) { reader->reset(); } if (updateModel) { mhp->init(); } memset(countv, 0, sizeof(double) * (M + 1)); - for (int i = 0; i < N; i++) { + for (READ_INT_TYPE i = 0; i < N; i++) { if (needCalcConPrb || updateModel) { general_assert(reader->next(read), "Can not load a read!"); } @@ -199,7 +205,7 @@ void* E_STEP(void* arg) { fracs[0] = probv[0] * ncpv[i]; if (fracs[0] < EPSILON) fracs[0] = 0.0; sum += fracs[0]; - for (int j = fr; j < to; j++) { + for (HIT_INT_TYPE j = fr; j < to; j++) { HitType &hit = hitv->getHitAt(j); if (needCalcConPrb) { hit.setConPrb(model->getConPrb(read, hit)); } id = j - fr + 1; @@ -213,7 +219,7 @@ void* E_STEP(void* arg) { countv[0] += fracs[0]; if (updateModel) { mhp->updateNoise(read, fracs[0]); } if (calcExpectedWeights) { ncpv[i] = fracs[0]; } - for (int j = fr; j < to; j++) { + for (HIT_INT_TYPE j = fr; j < to; j++) { HitType &hit = hitv->getHitAt(j); id = j - fr + 1; fracs[id] /= sum; @@ -224,7 +230,7 @@ void* E_STEP(void* arg) { } else if (calcExpectedWeights) { ncpv[i] = 0.0; - for (int j = fr; j < to; j++) { + for (HIT_INT_TYPE j = fr; j < to; j++) { HitType &hit = hitv->getHitAt(j); hit.setConPrb(0.0); } @@ -243,20 +249,20 @@ void* calcConProbs(void* arg) { double *ncpv = (double*)(params->ncpv); ReadType read; - int N = hitv->getN(); - int fr, to; + READ_INT_TYPE N = hitv->getN(); + HIT_INT_TYPE fr, to; assert(model->getNeedCalcConPrb()); reader->reset(); - for (int i = 0; i < N; i++) { + for (READ_INT_TYPE i = 0; i < N; i++) { general_assert(reader->next(read), "Can not load a read!"); fr = hitv->getSAt(i); to = hitv->getSAt(i + 1); ncpv[i] = model->getNoiseConPrb(read); - for (int j = fr; j < to; j++) { + for (HIT_INT_TYPE j = fr; j < to; j++) { HitType &hit = hitv->getHitAt(j); hit.setConPrb(model->getConPrb(read, hit)); } @@ -406,7 +412,7 @@ void EM() { double sum; double bChange = 0.0, change = 0.0; // bChange : biggest change - int totNum = 0; + READ_INT_TYPE totNum = 0; ModelType model(mparams); //master model ReadReader **readers; @@ -501,11 +507,11 @@ void EM() { if (bChange < change) bChange = change; } - if (verbose) printf("ROUND = %d, SUM = %.15g, bChange = %f, totNum = %d\n", ROUND, sum, bChange, totNum); + if (verbose) { cout<< "ROUND = "<< ROUND<< ", SUM = "<< setprecision(15)<< sum<< ", bChange = " << setprecision(6)<< bChange<< ", totNum = %" << totNum<< endl; } } while (ROUND < MIN_ROUND || (totNum > 0 && ROUND < MAX_ROUND)); // } while (ROUND < 1); - if (totNum > 0) fprintf(stderr, "Warning: RSEM reaches %d iterations before meeting the convergence criteria.\n", MAX_ROUND); + if (totNum > 0) { cout<< "Warning: RSEM reaches "<< MAX_ROUND<< " iterations before meeting the convergence criteria."<< endl; } //generate output file used by Gibbs sampler if (genGibbsOut) { @@ -522,28 +528,28 @@ void EM() { model.setNeedCalcConPrb(false); sprintf(out_for_gibbs_F, "%s.ofg", imdName); - fo = fopen(out_for_gibbs_F, "w"); - fprintf(fo, "%d %d\n", M, N0); + ofstream fout(out_for_gibbs_F); + fout<< M<< " "<< N0<< endl; for (int i = 0; i < nThreads; i++) { - int numN = hitvs[i]->getN(); - for (int j = 0; j < numN; j++) { - int fr = hitvs[i]->getSAt(j); - int to = hitvs[i]->getSAt(j + 1); - int totNum = 0; - - if (ncpvs[i][j] >= EPSILON) { ++totNum; fprintf(fo, "%d %.15g ", 0, ncpvs[i][j]); } - for (int k = fr; k < to; k++) { + READ_INT_TYPE numN = hitvs[i]->getN(); + for (READ_INT_TYPE j = 0; j < numN; j++) { + HIT_INT_TYPE fr = hitvs[i]->getSAt(j); + HIT_INT_TYPE to = hitvs[i]->getSAt(j + 1); + HIT_INT_TYPE totNum = 0; + + if (ncpvs[i][j] >= EPSILON) { ++totNum; fout<< "0 "<< setprecision(15)<< ncpvs[i][j]<< " "; } + for (HIT_INT_TYPE k = fr; k < to; k++) { HitType &hit = hitvs[i]->getHitAt(k); if (hit.getConPrb() >= EPSILON) { ++totNum; - fprintf(fo, "%d %.15g ", hit.getSid(), hit.getConPrb()); + fout<< hit.getSid()<< " "<< setprecision(15)<< hit.getConPrb()<< " "; } } - if (totNum > 0) { fprintf(fo, "\n"); } + if (totNum > 0) { fout<< endl; } } } - fclose(fo); + fout.close(); } sprintf(thetaF, "%s.theta", statName); @@ -611,27 +617,27 @@ void EM() { sprintf(outBamF, "%s.transcript.bam", outName); if (bamSampling) { - int local_N; - int fr, to, len, id; + READ_INT_TYPE local_N; + HIT_INT_TYPE fr, to, len, id; vector arr; uniform01 rg(engine_type(time(NULL))); - if (verbose) printf("Begin to sample reads from their posteriors.\n"); + if (verbose) cout<< "Begin to sample reads from their posteriors."<< endl; for (int i = 0; i < nThreads; i++) { local_N = hitvs[i]->getN(); - for (int j = 0; j < local_N; j++) { + for (READ_INT_TYPE j = 0; j < local_N; j++) { fr = hitvs[i]->getSAt(j); to = hitvs[i]->getSAt(j + 1); len = to - fr + 1; arr.assign(len, 0); arr[0] = ncpvs[i][j]; - for (int k = fr; k < to; k++) arr[k - fr + 1] = arr[k - fr] + hitvs[i]->getHitAt(k).getConPrb(); + for (HIT_INT_TYPE k = fr; k < to; k++) arr[k - fr + 1] = arr[k - fr] + hitvs[i]->getHitAt(k).getConPrb(); id = (arr[len - 1] < EPSILON ? -1 : sample(rg, arr, len)); // if all entries in arr are 0, let id be -1 - for (int k = fr; k < to; k++) hitvs[i]->getHitAt(k).setConPrb(k - fr + 1 == id ? 1.0 : 0.0); + for (HIT_INT_TYPE k = fr; k < to; k++) hitvs[i]->getHitAt(k).setConPrb(k - fr + 1 == id ? 1.0 : 0.0); } } - if (verbose) printf("Sampling is finished.\n"); + if (verbose) cout<< "Sampling is finished."<< endl; } BamWriter writer(inpSamType, inpSamF, pt_fn_list, outBamF, transcripts); @@ -717,7 +723,7 @@ int main(int argc, char* argv[]) { general_assert(N1 > 0, "There are no alignable reads!"); - if (nThreads > N1) nThreads = N1; + if ((READ_INT_TYPE)nThreads > N1) nThreads = N1; //set model parameters mparams.M = M;