]> git.donarmstrong.com Git - rsem.git/blobdiff - sampling.h
Added .gitignore file back
[rsem.git] / sampling.h
index 88200909d4b6e8bf13854808fa72f835ce20cdcb..7e445cf8995e871183b5b18b66ee23b8efd91b8e 100644 (file)
 #include "boost/random.hpp"
 
 typedef unsigned int seedType;
-typedef boost::mt19937 engine_type;
-typedef boost::gamma_distribution<> gamma_dist;
-typedef boost::uniform_01<engine_type> uniform01;
-typedef boost::variate_generator<engine_type&, gamma_dist> gamma_generator;
+typedef boost::random::mt19937 engine_type;
+typedef boost::random::uniform_01<> uniform_01_dist;
+typedef boost::random::gamma_distribution<> gamma_dist;
+typedef boost::random::variate_generator<engine_type&, uniform_01_dist> uniform_01_generator;
+typedef boost::random::variate_generator<engine_type&, gamma_dist> gamma_generator;
 
 class engineFactory {
 public:
+  static void init() { seedEngine = new engine_type(time(NULL)); }
+  static void init(seedType seed) { seedEngine = new engine_type(seed); }
+
+  static void finish() { if (seedEngine != NULL) delete seedEngine; }
+
        static engine_type *new_engine() {
                seedType seed;
-               static engine_type seedEngine(time(NULL));
                static std::set<seedType> seedSet;                      // empty set of seeds
                std::set<seedType>::iterator iter;
 
                do {
-                       seed = seedEngine();
+                       seed = (*seedEngine)();
                        iter = seedSet.find(seed);
                } while (iter != seedSet.end());
                seedSet.insert(seed);
 
                return new engine_type(seed);
        }
+
+ private:
+       static engine_type *seedEngine;
 };
 
+engine_type* engineFactory::seedEngine = NULL;
+
 // arr should be cumulative!
 // interval : [,)
 // random number should be in [0, arr[len - 1])
 // If by chance arr[len - 1] == 0.0, one possibility is to sample uniformly from 0...len-1
-int sample(uniform01& rg, std::vector<double>& arr, int len) {
+int sample(uniform_01_generator& rg, std::vector<double>& arr, int len) {
   int l, r, mid;
   double prb = rg() * arr[len - 1];