]> git.donarmstrong.com Git - rsem.git/blobdiff - calcCI.cpp
rewrote parallelization of calcCI.cpp
[rsem.git] / calcCI.cpp
index 79a6bb2f924ddd0746e0cd816430541ad00e0370..3ae71157c58a023bb4cd7fe316e26c7b9e572663 100644 (file)
@@ -27,6 +27,8 @@ using namespace std;
 struct Params {
        int no;
        FILE *fi;
+       bufsize_type size;
+       int sp;
        engine_type *engine;
        double *mw;
 };
@@ -115,14 +117,14 @@ void* sample_theta_from_c(void* arg) {
        Params *params = (Params*)arg;
        FILE *fi = params->fi;
        double *mw = params->mw;
+       Buffer buffer(params->size, params->sp, nSamples, cvlen, tmpF);
 
        cvec = new int[cvlen];
        theta = new double[cvlen];
        gammas = new gamma_dist*[cvlen];
        rgs = new gamma_generator*[cvlen];
 
-       float **vecs = new float*[nSpC];
-       for (int i = 0; i < nSpC; i++) vecs[i] = new float[cvlen];
+       float *vec = new float[cvlen];
 
        int cnt = 0;
        while (fscanf(fi, "%d", &cvec[0]) == 1) {
@@ -154,19 +156,18 @@ void* sample_theta_from_c(void* arg) {
 
 
                        sum = 0.0;
-                       vecs[i][0] = theta[0];
+                       vec[0] = theta[0];
                        for (int j = 1; j < cvlen; j++)
                                if (eel[j] >= EPSILON) {
-                                       vecs[i][j] = theta[j] / eel[j];
-                                       sum += vecs[i][j];
+                                       vec[j] = theta[j] / eel[j];
+                                       sum += vec[j];
                                }
                                else assert(theta[j] < EPSILON);
-
                        assert(sum >= EPSILON);
-                       for (int j = 1; j < cvlen; j++) vecs[i][j] /= sum;
-               }
+                       for (int j = 1; j < cvlen; j++) vec[j] /= sum;
 
-               buffer->write(nSpC, vecs);
+                       buffer.write(vec);
+               }
 
                for (int j = 0; j < cvlen; j++) {
                        delete gammas[j];
@@ -181,8 +182,7 @@ void* sample_theta_from_c(void* arg) {
        delete[] gammas;
        delete[] rgs;
 
-       for (int i = 0; i < nSpC; i++) delete[] vecs[i];
-       delete[] vecs;
+       delete[] vec;
 
        return NULL;
 }
@@ -193,29 +193,57 @@ void sample_theta_vectors_from_count_vectors() {
        model.read(modelF);
        calcExpectedEffectiveLengths<ModelType>(model);
 
-       buffer = new Buffer(nMB, nSamples, cvlen, tmpF);
+       char splitF[STRLEN];
+       bufsize_type buf_maxcv = bufsize_type(nMB) * 1024 * 1024 / FLOATSIZE / cvlen;
+       bufsize_type quotient, left;
+       int sum = 0;
 
-       paramsArray = new Params[nThreads];
-       threads = new pthread_t[nThreads];
+       sprintf(splitF, "%s.split", imdName);
+       FILE *fi = fopen(splitF, "r");
+       int num_threads;
+       assert(fscanf(fi, "%d", &num_threads) == 1);
+       assert(num_threads <= nThreads);
+
+       quotient = buf_maxcv / num_threads;
+       assert(quotient > 0);
+       left = buf_maxcv % num_threads;
+
+       paramsArray = new Params[num_threads];
+       threads = new pthread_t[num_threads];
 
        char inpF[STRLEN];
-       for (int i = 0; i < nThreads; i++) {
+       for (int i = 0; i < num_threads; i++) {
                paramsArray[i].no = i;
                sprintf(inpF, "%s%d", cvsF, i);
                paramsArray[i].fi = fopen(inpF, "r");
+
+               int num_samples;
+               assert(fscanf(fi, "%d", &num_samples) == 1);
+               num_samples *= nSpC;
+               if (bufsize_type(num_samples) <= quotient) paramsArray[i].size = num_samples;
+               else {
+                       paramsArray[i].size = quotient;
+                       if (left > 0) { ++paramsArray[i].size; --left; }
+               }
+               paramsArray[i].size *=  cvlen * FLOATSIZE;
+               paramsArray[i].sp = sum;
+               sum += num_samples;
+
                paramsArray[i].engine = engineFactory::new_engine();
                paramsArray[i].mw = model.getMW();
        }
 
+       fclose(fi);
+
        /* set thread attribute to be joinable */
        pthread_attr_init(&attr);
        pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
 
-       for (int i = 0; i < nThreads; i++) {
+       for (int i = 0; i < num_threads; i++) {
                rc = pthread_create(&threads[i], &attr, &sample_theta_from_c, (void*)(&paramsArray[i]));
                pthread_assert(rc, "pthread_create", "Cannot create thread " + itos(i) + " (numbered from 0) in sample_theta_vectors_from_count_vectors!");
        }
-       for (int i = 0; i < nThreads; i++) {
+       for (int i = 0; i < num_threads; i++) {
                rc = pthread_join(threads[i], &status);
                pthread_assert(rc, "pthread_join", "Cannot join thread " + itos(i) + " (numbered from 0) in sample_theta_vectors_from_count_vectors!");
        }
@@ -224,14 +252,12 @@ void sample_theta_vectors_from_count_vectors() {
        pthread_attr_destroy(&attr);
        delete[] threads;
 
-       for (int i = 0; i < nThreads; i++) {
+       for (int i = 0; i < num_threads; i++) {
                fclose(paramsArray[i].fi);
                delete paramsArray[i].engine;
        }
        delete[] paramsArray;
 
-       delete buffer; // Must delete here, force the content left in the buffer be written into the disk
-
        if (verbose) { printf("Sampling is finished!\n"); }
 }
 
@@ -322,6 +348,7 @@ void calculate_credibility_intervals(char* imdName) {
        iso_tau = new CIType[M + 1];
        gene_tau = new CIType[m];
 
+       // nThreads must be intact here.
        assert(M > 0);
        int quotient = M / nThreads;
        if (quotient < 1) { nThreads = M; quotient = 1; }
@@ -410,11 +437,6 @@ int main(int argc, char* argv[]) {
        }
        verbose = !quiet;
 
-       if (nThreads > nCV) {
-               nThreads = nCV;
-               printf("Warning: Number of count vectors is less than number of threads! Change the number of threads to %d!\n", nThreads);
-       }
-
        sprintf(refF, "%s.seq", argv[1]);
        refs.loadRefs(refF, 1);
        M = refs.getM();