]> git.donarmstrong.com Git - rsem.git/blob - sampling.h
Modified the use of uniform_01
[rsem.git] / sampling.h
1 #ifndef SAMPLING
2 #define SAMPLING
3
4 #include<ctime>
5 #include<cstdio>
6 #include<cassert>
7 #include<vector>
8 #include<set>
9
10 #include "boost/random.hpp"
11
12 typedef unsigned int seedType;
13 typedef boost::random::mt19937 engine_type;
14 typedef boost::random::uniform_01<> uniform_01_dist;
15 typedef boost::random::gamma_distribution<> gamma_dist;
16 typedef boost::random::variate_generator<engine_type&, uniform_01_dist> uniform_01_generator;
17 typedef boost::random::variate_generator<engine_type&, gamma_dist> gamma_generator;
18
19 class engineFactory {
20 public:
21   static void init() { seedEngine = new engine_type(time(NULL)); }
22   static void init(seedType seed) { seedEngine = new engine_type(seed); }
23
24   static void finish() { if (seedEngine != NULL) delete seedEngine; }
25
26         static engine_type *new_engine() {
27                 seedType seed;
28                 static std::set<seedType> seedSet;                      // empty set of seeds
29                 std::set<seedType>::iterator iter;
30
31                 do {
32                         seed = (*seedEngine)();
33                         iter = seedSet.find(seed);
34                 } while (iter != seedSet.end());
35                 seedSet.insert(seed);
36
37                 return new engine_type(seed);
38         }
39
40  private:
41         static engine_type *seedEngine;
42 };
43
44 engine_type* engineFactory::seedEngine = NULL;
45
46 // arr should be cumulative!
47 // interval : [,)
48 // random number should be in [0, arr[len - 1])
49 // If by chance arr[len - 1] == 0.0, one possibility is to sample uniformly from 0...len-1
50 int sample(uniform_01_generator& rg, std::vector<double>& arr, int len) {
51   int l, r, mid;
52   double prb = rg() * arr[len - 1];
53
54   l = 0; r = len - 1;
55   while (l <= r) {
56     mid = (l + r) / 2;
57     if (arr[mid] <= prb) l = mid + 1;
58     else r = mid - 1;
59   }
60
61   if (l >= len) { printf("%d %lf %lf\n", len, arr[len - 1], prb); }
62   assert(l < len);
63
64   return l;
65 }
66
67 #endif