]> git.donarmstrong.com Git - rsem.git/commitdiff
Added --seed option to set random number generator seeds in 'rsem-calculate-expression'
authorBo Li <bli@cs.wisc.edu>
Fri, 6 Jun 2014 03:53:28 +0000 (22:53 -0500)
committerBo Li <bli@cs.wisc.edu>
Fri, 6 Jun 2014 03:53:28 +0000 (22:53 -0500)
EM.cpp
Gibbs.cpp
calcCI.cpp
rsem-calculate-expression
sampling.h

diff --git a/EM.cpp b/EM.cpp
index bf15c2b4b6493ca3674804428bc456540f4aa345..4bda1d8ce8ae7431f7251724fee856aa37b5650f 100644 (file)
--- a/EM.cpp
+++ b/EM.cpp
@@ -89,6 +89,9 @@ Transcripts transcripts;
 
 ModelParams mparams;
 
+bool hasSeed;
+seedType seed;
+
 template<class ReadType, class HitType, class ModelType>
 void init(ReadReader<ReadType> **&readers, HitContainer<HitType> **&hitvs, double **&ncpvs, ModelType **&mhps) {
        READ_INT_TYPE nReads;
@@ -503,7 +506,7 @@ void EM() {
                        READ_INT_TYPE local_N;
                        HIT_INT_TYPE fr, to, len, id;
                        vector<double> arr;
-                       uniform01 rg(engine_type(time(NULL)));
+                       uniform01 rg(engine_type(hasSeed ? seed : time(NULL)));
 
                        if (verbose) cout<< "Begin to sample reads from their posteriors."<< endl;
                        for (int i = 0; i < nThreads; i++) {
@@ -536,7 +539,7 @@ int main(int argc, char* argv[]) {
        bool quiet = false;
 
        if (argc < 6) {
-               printf("Usage : rsem-run-em refName read_type sampleName imdName statName [-p #Threads] [-b samInpType samInpF has_fn_list_? [fn_list]] [-q] [--gibbs-out] [--sampling]\n\n");
+               printf("Usage : rsem-run-em refName read_type sampleName imdName statName [-p #Threads] [-b samInpType samInpF has_fn_list_? [fn_list]] [-q] [--gibbs-out] [--sampling] [--seed seed]\n\n");
                printf("  refName: reference name\n");
                printf("  read_type: 0 single read without quality score; 1 single read with quality score; 2 paired-end read without quality score; 3 paired-end read with quality score.\n");
                printf("  sampleName: sample's name, including the path\n");
@@ -546,6 +549,7 @@ int main(int argc, char* argv[]) {
                printf("  -q: set it quiet\n");
                printf("  --gibbs-out: generate output file used by Gibbs sampler. (default: off)\n");
                printf("  --sampling: sample each read from its posterior distribution when bam file is generated. (default: off)\n");
+               printf("  --seed uint32: the seed used for the BAM sampling. (default: off)\n");
                printf("// model parameters should be in imdName.mparams.\n");
                exit(-1);
        }
@@ -564,6 +568,7 @@ int main(int argc, char* argv[]) {
        bamSampling = false;
        genGibbsOut = false;
        pt_fn_list = pt_chr_list = NULL;
+       hasSeed = false;
 
        for (int i = 6; i < argc; i++) {
                if (!strcmp(argv[i], "-p")) { nThreads = atoi(argv[i + 1]); }
@@ -579,6 +584,12 @@ int main(int argc, char* argv[]) {
                if (!strcmp(argv[i], "-q")) { quiet = true; }
                if (!strcmp(argv[i], "--gibbs-out")) { genGibbsOut = true; }
                if (!strcmp(argv[i], "--sampling")) { bamSampling = true; }
+               if (!strcmp(argv[i], "--seed")) {
+                 hasSeed = true;
+                 int len = strlen(argv[i + 1]);
+                 seed = 0;
+                 for (int k = 0; k < len; k++) seed = seed * 10 + (argv[i + 1][k] - '0');
+               }
        }
 
        general_assert(nThreads > 0, "Number of threads should be bigger than 0!");
index e7a1182ce4b8ca78e5a41f59a0cdf5ddf657020a..7e26d864d9258c97ee693379c88330636ef67c5c 100644 (file)
--- a/Gibbs.cpp
+++ b/Gibbs.cpp
@@ -74,6 +74,9 @@ pthread_t *threads;
 pthread_attr_t attr;
 int rc;
 
+bool hasSeed;
+seedType seed;
+
 void load_data(char* refName, char* statName, char* imdName) {
        ifstream fin;
        string line;
@@ -135,6 +138,7 @@ void init() {
        paramsArray = new Params[nThreads];
        threads = new pthread_t[nThreads];
 
+       hasSeed ? engineFactory::init(seed) : engineFactory::init();
        for (int i = 0; i < nThreads; i++) {
                paramsArray[i].no = i;
 
@@ -154,6 +158,7 @@ void init() {
                paramsArray[i].pme_fpkm = new double[M + 1];
                memset(paramsArray[i].pme_fpkm, 0, sizeof(double) * (M + 1));
        }
+       engineFactory::finish();
 
        /* set thread attribute to be joinable */
        pthread_attr_init(&attr);
@@ -292,7 +297,7 @@ void release() {
 
 int main(int argc, char* argv[]) {
        if (argc < 7) {
-               printf("Usage: rsem-run-gibbs reference_name imdName statName BURNIN NSAMPLES GAP [-p #Threads] [--var] [-q]\n");
+               printf("Usage: rsem-run-gibbs reference_name imdName statName BURNIN NSAMPLES GAP [-p #Threads] [--var] [--seed seed] [-q]\n");
                exit(-1);
        }
 
@@ -306,11 +311,18 @@ int main(int argc, char* argv[]) {
 
        nThreads = 1;
        var_opt = false;
+       hasSeed = false;
        quiet = false;
 
        for (int i = 7; i < argc; i++) {
                if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
                if (!strcmp(argv[i], "--var")) var_opt = true;
+               if (!strcmp(argv[i], "--seed")) {
+                 hasSeed = true;
+                 int len = strlen(argv[i + 1]);
+                 seed = 0;
+                 for (int k = 0; k < len; k++) seed = seed * 10 + (argv[i + 1][k] - '0');
+               }
                if (!strcmp(argv[i], "-q")) quiet = true;
        }
        verbose = !quiet;
index 97eba64e1ea9965c3983409090dd252abb78cbad..86b3937ed8009fda5a40cfefdcc7825a0de0fc80 100644 (file)
@@ -80,6 +80,9 @@ pthread_t *threads;
 pthread_attr_t attr;
 int rc;
 
+bool hasSeed;
+seedType seed;
+
 CIParams *ciParamsArray;
 
 void* sample_theta_from_c(void* arg) {
@@ -165,6 +168,7 @@ void sample_theta_vectors_from_count_vectors() {
        threads = new pthread_t[num_threads];
 
        char inpF[STRLEN];
+       hasSeed ? engineFactory::init(seed) : engineFactory::init();
        for (int i = 0; i < num_threads; i++) {
                paramsArray[i].no = i;
                sprintf(inpF, "%s%d", cvsF, i);
@@ -172,6 +176,7 @@ void sample_theta_vectors_from_count_vectors() {
                paramsArray[i].engine = engineFactory::new_engine();
                paramsArray[i].mw = model.getMW();
        }
+       engineFactory::finish();
 
        /* set thread attribute to be joinable */
        pthread_attr_init(&attr);
@@ -458,7 +463,7 @@ void calculate_credibility_intervals(char* imdName) {
 
 int main(int argc, char* argv[]) {
        if (argc < 8) {
-               printf("Usage: rsem-calculate-credibility-intervals reference_name imdName statName confidence nCV nSpC nMB [-p #Threads] [-q]\n");
+               printf("Usage: rsem-calculate-credibility-intervals reference_name imdName statName confidence nCV nSpC nMB [-p #Threads] [--seed seed] [-q]\n");
                exit(-1);
        }
 
@@ -473,8 +478,15 @@ int main(int argc, char* argv[]) {
 
        nThreads = 1;
        quiet = false;
+       hasSeed = false;
        for (int i = 8; i < argc; i++) {
                if (!strcmp(argv[i], "-p")) nThreads = atoi(argv[i + 1]);
+               if (!strcmp(argv[i], "--seed")) {
+                 hasSeed = true;
+                 int len = strlen(argv[i + 1]);
+                 seed = 0;
+                 for (int k = 0; k < len; k++) seed = seed * 10 + (argv[i + 1][k] - '0');
+               }
                if (!strcmp(argv[i], "-q")) quiet = true;
        }
        verbose = !quiet;
index 03811afb0c8b79c1ae348ce01883820908e0c4b8..d6cb5f2a3f6b989ced2e1f96b3302c8d720ba9b5 100755 (executable)
@@ -73,6 +73,8 @@ my $bowtie2_mismatch_rate = 0.1;
 my $bowtie2_k = 200;
 my $bowtie2_sensitivity_level = "sensitive"; # must be one of "very_fast", "fast", "sensitive", "very_sensitive"
 
+my $seed = "NULL";
+
 my $version = 0;
 
 my $mTime = 0;
@@ -126,6 +128,7 @@ GetOptions("keep-intermediate-files" => \$keep_intermediate_files,
           "calc-ci" => \$calcCI,
           "ci-memory=i" => \$NMB,
           "samtools-sort-mem=s" => \$SortMem,
+          "seed=i" => \$seed,
           "time" => \$mTime,
           "version" => \$version,
           "q|quiet" => \$quiet,
@@ -160,6 +163,7 @@ pod2usage(-msg => "Number of threads should be at least 1!\n", -exitval => 2, -v
 pod2usage(-msg => "Seed length should be at least 5!\n", -exitval => 2, -verbose => 2) if ($L < 5);
 pod2usage(-msg => "--sampling-for-bam cannot be specified if --no-bam-output is specified!\n", -exitval => 2, -verbose => 2) if ($sampling && !$genBamF);
 pod2usage(-msg => "--output-genome-bam cannot be specified if --no-bam-output is specified!\n", -exitval => 2, -verbose => 2) if ($genGenomeBamF && !$genBamF);
+pod2usage(-msg => "The seed for random number generator must be a non-negative 32bit integer!\n", -exitval => 2, -verbose => 2) if (($seed ne "NULL") && ($seed < 0 || $seed > 0xffffffff));
 
 if ($L < 25) { print "Warning: the seed length set is less than 25! This is only allowed if the references are not added poly(A) tails.\n"; }
 
@@ -335,12 +339,21 @@ print OUTPUT "$mean $sd\n";
 print OUTPUT "$L\n";
 close(OUTPUT);  
 
+my @seeds = ();
+if ($seed ne "NULL") { 
+    srand($seed); 
+    for (my $i = 0; $i < 3; $i++) {
+       push(@seeds, int(rand(1 << 32))); 
+    }
+}
+
 $command = "rsem-run-em $refName $read_type $sampleName $imdName $statName -p $nThreads";
 if ($genBamF) { 
     $command .= " -b $samInpType $inpF";
     if ($fn_list ne "") { $command .= " 1 $fn_list"; }
     else { $command .= " 0"; }
     if ($sampling) { $command .= " --sampling"; }
+    if ($seed ne "NULL") { $command .= " --seed $seeds[0]"; }
 }
 if ($calcPME || $var_opt || $calcCI) { $command .= " --gibbs-out"; }
 if ($quiet) { $command .= " -q"; }
@@ -381,6 +394,7 @@ if ($calcPME || $var_opt || $calcCI ) {
     $command = "rsem-run-gibbs $refName $imdName $statName $BURNIN $NCV $SAMPLEGAP";
     $command .= " -p $nThreads";
     if ($var_opt) { $command .= " --var"; }
+    if ($seed ne "NULL") { $command .= " --seed $seeds[1]"; }
     if ($quiet) { $command .= " -q"; }
     &runCommand($command);
 }
@@ -405,6 +419,7 @@ if ($calcPME || $calcCI) {
 if ($calcCI) {
     $command = "rsem-calculate-credibility-intervals $refName $imdName $statName $CONFIDENCE $NCV $NSPC $NMB";
     $command .= " -p $nThreads";
+    if ($seed ne "NULL") { $command .= " --seed $seeds[2]"; }
     if ($quiet) { $command .= " -q"; }
     &runCommand($command);
 
@@ -526,6 +541,10 @@ Generate a BAM file, 'sample_name.genome.bam', with alignments mapped to genomic
 
 When RSEM generates a BAM file, instead of outputing all alignments a read has with their posterior probabilities, one alignment is sampled according to the posterior probabilities. The sampling procedure includes the alignment to the "noise" transcript, which does not appear in the BAM file. Only the sampled alignment has a weight of 1. All other alignments have weight 0. If the "noise" transcript is sampled, all alignments appeared in the BAM file should have weight 0. (Default: off)
 
+=item B<--seed> <uint32>
+
+Set the seed for the random number generators used in calculating posterior mean estimates and credibility intervals. The seed must be a non-negative 32 bit interger. (Default: off)
+
 =item B<--calc-pme>
 
 Run RSEM's collapsed Gibbs sampler to calculate posterior mean estimates. (Default: off) 
index 88200909d4b6e8bf13854808fa72f835ce20cdcb..8f80b72eca8d4982731d15a366100b5c68563ed3 100644 (file)
@@ -17,22 +17,31 @@ typedef boost::variate_generator<engine_type&, gamma_dist> gamma_generator;
 
 class engineFactory {
 public:
+  static void init() { seedEngine = new engine_type(time(NULL)); }
+  static void init(seedType seed) { seedEngine = new engine_type(seed); }
+
+  static void finish() { if (seedEngine != NULL) delete seedEngine; }
+
        static engine_type *new_engine() {
                seedType seed;
-               static engine_type seedEngine(time(NULL));
                static std::set<seedType> seedSet;                      // empty set of seeds
                std::set<seedType>::iterator iter;
 
                do {
-                       seed = seedEngine();
+                       seed = (*seedEngine)();
                        iter = seedSet.find(seed);
                } while (iter != seedSet.end());
                seedSet.insert(seed);
 
                return new engine_type(seed);
        }
+
+ private:
+       static engine_type *seedEngine;
 };
 
+engine_type* engineFactory::seedEngine = NULL;
+
 // arr should be cumulative!
 // interval : [,)
 // random number should be in [0, arr[len - 1])