]> git.donarmstrong.com Git - rsem.git/blobdiff - EM.cpp
Refactored wiggle code and added rsem-bam2readdepth program
[rsem.git] / EM.cpp
diff --git a/EM.cpp b/EM.cpp
index e381b470e002be865e2bdb8a30adb3f6f510ec07..08fc8c7d2400396a8fef40fc945717be4b166e83 100644 (file)
--- a/EM.cpp
+++ b/EM.cpp
@@ -61,9 +61,11 @@ bool genBamF; // If user wants to generate bam file, true; otherwise, false.
 bool updateModel, calcExpectedWeights;
 bool genGibbsOut; // generate file for Gibbs sampler
 
-char refName[STRLEN], imdName[STRLEN], outName[STRLEN];
+char refName[STRLEN], outName[STRLEN];
+char imdName[STRLEN], statName[STRLEN];
 char refF[STRLEN], groupF[STRLEN], cntF[STRLEN], tiF[STRLEN];
 char mparamsF[STRLEN], bmparamsF[STRLEN];
+char modelF[STRLEN], thetaF[STRLEN];
 
 char inpSamType;
 char *pt_fn_list, *pt_chr_list;
@@ -127,7 +129,7 @@ void init(ReadReader<ReadType> **&readers, HitContainer<HitType> **&hitvs, doubl
                if (!readers[i]->locate(curnr)) { fprintf(stderr, "Read indices files do not match!\n"); exit(-1); }
                //assert(readers[i]->locate(curnr));
 
-               while (nrLeft > ntLeft && hitvs[i]->getNHits() < nhT) {
+               while (nrLeft > ntLeft && (i == nThreads - 1 || hitvs[i]->getNHits() < nhT)) {
                        if (!hitvs[i]->read(fin)) { fprintf(stderr, "Cannot read alignments from .dat file!\n"); exit(-1); }
                        //assert(hitvs[i]->read(fin));
                        --nrLeft;
@@ -301,20 +303,12 @@ void calcExpectedEffectiveLengths(ModelType& model) {
 template<class ModelType>
 void writeResults(ModelType& model, double* counts) {
        double denom;
-       char modelF[STRLEN], thetaF[STRLEN];
        char outF[STRLEN];
        FILE *fo;
 
-       sprintf(modelF, "%s.model", outName);
+       sprintf(modelF, "%s.model", statName);
        model.write(modelF);
 
-       sprintf(thetaF, "%s.theta", outName);
-       fo = fopen(thetaF, "w");
-       fprintf(fo, "%d\n", M + 1);
-       for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
-       fprintf(fo, "%.15g\n", theta[M]);
-       fclose(fo);
-
        //calculate tau values
        double *tau = new double[M + 1];
        memset(tau, 0, sizeof(double) * (M + 1));
@@ -408,6 +402,8 @@ inline bool doesUpdateModel(int ROUND) {
 //Including initialize, algorithm and results saving
 template<class ReadType, class HitType, class ModelType>
 void EM() {
+       FILE *fo;
+
        int ROUND;
        double sum;
 
@@ -426,6 +422,7 @@ void EM() {
        void *status;
        int rc;
 
+
        //initialize boolean variables
        updateModel = calcExpectedWeights = false;
 
@@ -514,17 +511,6 @@ void EM() {
 
        if (totNum > 0) fprintf(stderr, "Warning: RSEM reaches %d iterations before meeting the convergence criteria.\n", MAX_ROUND);
 
-       //calculate expected effective lengths for each isoform
-       calcExpectedEffectiveLengths<ModelType>(model);
-
-       //correct theta vector
-       sum = theta[0];
-       for (int i = 1; i <= M; i++) 
-         if (eel[i] < EPSILON) { theta[i] = 0.0; }
-         else sum += theta[i];
-       if (sum < EPSILON) { fprintf(stderr, "No Expected Effective Length is no less than %.6g?!\n", MINEEL); exit(-1); }
-       for (int i = 0; i <= M; i++) theta[i] /= sum;
-
        //generate output file used by Gibbs sampler
        if (genGibbsOut) {
                if (model.getNeedCalcConPrb()) {
@@ -540,7 +526,7 @@ void EM() {
                model.setNeedCalcConPrb(false);
 
                sprintf(out_for_gibbs_F, "%s.ofg", imdName);
-               FILE *fo = fopen(out_for_gibbs_F, "w");
+               fo = fopen(out_for_gibbs_F, "w");
                fprintf(fo, "%d %d\n", M, N0);
                for (int i = 0; i < nThreads; i++) {
                        int numN = hitvs[i]->getN();
@@ -549,7 +535,7 @@ void EM() {
                                int to = hitvs[i]->getSAt(j + 1);
                                int totNum = 0;
 
-                               if (ncpvs[i][j] > 0.0) { ++totNum; fprintf(fo, "%d %.15g ", 0, ncpvs[i][j]); }
+                               if (ncpvs[i][j] >= EPSILON) { ++totNum; fprintf(fo, "%d %.15g ", 0, ncpvs[i][j]); }
                                for (int k = fr; k < to; k++) {
                                        HitType &hit = hitvs[i]->getHitAt(k);
                                        if (hit.getConPrb() >= EPSILON) {
@@ -564,6 +550,25 @@ void EM() {
                fclose(fo);
        }
 
+       sprintf(thetaF, "%s.theta", statName);
+       fo = fopen(thetaF, "w");
+       fprintf(fo, "%d\n", M + 1);
+
+       // output theta'
+       for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
+       fprintf(fo, "%.15g\n", theta[M]);
+       
+       //calculate expected effective lengths for each isoform
+       calcExpectedEffectiveLengths<ModelType>(model);
+
+       //correct theta vector
+       sum = theta[0];
+       for (int i = 1; i <= M; i++) 
+         if (eel[i] < EPSILON) { theta[i] = 0.0; }
+         else sum += theta[i];
+       if (sum < EPSILON) { fprintf(stderr, "No Expected Effective Length is no less than %.6g?!\n", MINEEL); exit(-1); }
+       for (int i = 0; i <= M; i++) theta[i] /= sum;
+
        //calculate expected weights and counts using learned parameters
        updateModel = false; calcExpectedWeights = true;
        for (int i = 0; i < nThreads; i++) {
@@ -587,8 +592,7 @@ void EM() {
        /* destroy attribute */
        pthread_attr_destroy(&attr);
 
-       
-       //for all
+       //convert theta' to theta
                double *mw = model.getMW();
        sum = 0.0;
        for (int i = 0; i <= M; i++) {
@@ -598,6 +602,12 @@ void EM() {
        assert(sum >= EPSILON);
        for (int i = 0; i <= M; i++) theta[i] /= sum;
 
+       // output theta
+       for (int i = 0; i < M; i++) fprintf(fo, "%.15g ", theta[i]);
+       fprintf(fo, "%.15g\n", theta[M]);
+
+       fclose(fo);
+
        writeResults<ModelType>(model, countvs[0]);
 
        if (genBamF) {
@@ -620,11 +630,11 @@ int main(int argc, char* argv[]) {
        bool quiet = false;
 
        if (argc < 5) {
-               printf("Usage : rsem-run-em refName read_type imdName outName [-p #Threads] [-b samInpType samInpF has_fn_list_? [fn_list]] [-q] [--gibbs-out]\n\n");
+               printf("Usage : rsem-run-em refName read_type sampleName sampleToken [-p #Threads] [-b samInpType samInpF has_fn_list_? [fn_list]] [-q] [--gibbs-out]\n\n");
                printf("  refName: reference name\n");
                printf("  read_type: 0 single read without quality score; 1 single read with quality score; 2 paired-end read without quality score; 3 paired-end read with quality score.\n");
-               printf("  imdName: name for all upstream/downstream user-unseen files. (different files have different suffices)\n");
-               printf("  outName: name for all output files. (different files have different suffices)\n");
+               printf("  sampleName: sample's name, including the path\n");
+               printf("  sampleToken: sampleName excludes the path\n");
                printf("  -p: number of threads which user wants to use. (default: 1)\n");
                printf("  -b: produce bam format output file. (default: off)\n");
                printf("  -q: set it quiet\n");
@@ -637,8 +647,9 @@ int main(int argc, char* argv[]) {
 
        strcpy(refName, argv[1]);
        read_type = atoi(argv[2]);
-       strcpy(imdName, argv[3]);
-       strcpy(outName, argv[4]);
+       strcpy(outName, argv[3]);
+       sprintf(imdName, "%s.temp/%s", argv[3], argv[4]);
+       sprintf(statName, "%s.stat/%s", argv[3], argv[4]);
 
        nThreads = 1;
 
@@ -676,7 +687,7 @@ int main(int argc, char* argv[]) {
        sprintf(tiF, "%s.ti", refName);
        transcripts.readFrom(tiF);
 
-       sprintf(cntF, "%s.cnt", imdName);
+       sprintf(cntF, "%s.cnt", statName);
        fin.open(cntF);
        if (!fin.is_open()) { fprintf(stderr, "Cannot open %s! It may not exist.\n", cntF); exit(-1); }
        fin>>N0>>N1>>N2>>N_tot;