]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
Added check in rsem-plot-model for .stat output directory
[rsem.git] / Gibbs.cpp
1 #include<ctime>
2 #include<cstdio>
3 #include<cstring>
4 #include<cstdlib>
5 #include<cassert>
6 #include<fstream>
7 #include<sstream>
8 #include<vector>
9
10 #include "boost/random.hpp"
11
12 #include "utils.h"
13
14 #include "Model.h"
15 #include "SingleModel.h"
16 #include "SingleQModel.h"
17 #include "PairedEndModel.h"
18 #include "PairedEndQModel.h"
19
20 #include "Refs.h"
21 #include "GroupInfo.h"
22
23 using namespace std;
24
25 struct Item {
26         int sid;
27         double conprb;
28
29         Item(int sid, double conprb) {
30                 this->sid = sid;
31                 this->conprb = conprb;
32         }
33 };
34
35 int model_type;
36 int m, M, N0, N1, nHits;
37 double totc;
38 int BURNIN, CHAINLEN, GAP;
39 char imdName[STRLEN], statName[STRLEN];
40 char thetaF[STRLEN], ofgF[STRLEN], groupF[STRLEN], refF[STRLEN], modelF[STRLEN];
41 char cvsF[STRLEN];
42
43 Refs refs;
44 GroupInfo gi;
45
46 vector<double> theta, pme_theta, pme_c, eel;
47
48 vector<int> s, z;
49 vector<Item> hits;
50 vector<int> counts;
51
52 bool quiet;
53
54 vector<double> arr;
55
56 boost::mt19937 rng(time(NULL));
57 boost::uniform_01<boost::mt19937> rg(rng);
58
59 void load_data(char* reference_name, char* statName, char* imdName) {
60         ifstream fin;
61         string line;
62         int tmpVal;
63
64         //load reference file
65         sprintf(refF, "%s.seq", reference_name);
66         refs.loadRefs(refF, 1);
67         M = refs.getM();
68
69         //load groupF
70         sprintf(groupF, "%s.grp", reference_name);
71         gi.load(groupF);
72         m = gi.getm();
73
74         //load thetaF
75         sprintf(thetaF, "%s.theta",statName);
76         fin.open(thetaF);
77         if (!fin.is_open()) {
78                 fprintf(stderr, "Cannot open %s!\n", thetaF);
79                 exit(-1);
80         }
81         fin>>tmpVal;
82         if (tmpVal != M + 1) {
83                 fprintf(stderr, "Number of transcripts is not consistent in %s and %s!\n", refF, thetaF);
84                 exit(-1);
85         }
86         theta.clear(); theta.resize(M + 1);
87         for (int i = 0; i <= M; i++) fin>>theta[i];
88         fin.close();
89
90         //load ofgF;
91         sprintf(ofgF, "%s.ofg", imdName);
92         fin.open(ofgF);
93         if (!fin.is_open()) {
94                 fprintf(stderr, "Cannot open %s!\n", ofgF);
95                 exit(-1);
96         }
97         fin>>tmpVal>>N0;
98         if (tmpVal != M) {
99                 fprintf(stderr, "M in %s is not consistent with %s!\n", ofgF, refF);
100                 exit(-1);
101         }
102         getline(fin, line);
103
104         s.clear(); hits.clear();
105         s.push_back(0);
106         while (getline(fin, line)) {
107                 istringstream strin(line);
108                 int sid;
109                 double conprb;
110
111                 while (strin>>sid>>conprb) {
112                         hits.push_back(Item(sid, conprb));
113                 }
114                 s.push_back(hits.size());
115         }
116         fin.close();
117
118         N1 = s.size() - 1;
119         nHits = hits.size();
120
121         if (verbose) { printf("Loading Data is finished!\n"); }
122 }
123
124 // arr should be cumulative!
125 // interval : [,)
126 // random number should be in [0, arr[len - 1])
127 // If by chance arr[len - 1] == 0.0, one possibility is to sample uniformly from 0...len-1
128 int sample(vector<double>& arr, int len) {
129   int l, r, mid;
130   double prb = rg() * arr[len - 1];
131
132   l = 0; r = len - 1;
133   while (l <= r) {
134     mid = (l + r) / 2;
135     if (arr[mid] <= prb) l = mid + 1;
136     else r = mid - 1;
137   }
138
139   if (l >= len) { printf("%d %lf %lf\n", len, arr[len - 1], prb); }
140   assert(l < len);
141
142   return l;
143 }
144
145 void init() {
146         int len, fr, to;
147
148         arr.clear();
149         z.clear();
150         counts.clear();
151
152         z.resize(N1);
153         counts.resize(M + 1, 1); // 1 pseudo count
154         counts[0] += N0;
155
156         for (int i = 0; i < N1; i++) {
157                 fr = s[i]; to = s[i + 1];
158                 len = to - fr;
159                 arr.resize(len);
160                 for (int j = fr; j < to; j++) {
161                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
162                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
163                 }
164                 z[i] = hits[fr + sample(arr, len)].sid;
165                 ++counts[z[i]];
166         }
167
168         totc = N0 + N1 + (M + 1);
169
170         if (verbose) { printf("Initialization is finished!\n"); }
171 }
172
173 void writeCountVector(FILE* fo) {
174         for (int i = 0; i < M; i++) {
175                 fprintf(fo, "%d ", counts[i]);
176         }
177         fprintf(fo, "%d\n", counts[M]);
178 }
179
180 void Gibbs(char* imdName) {
181         FILE *fo;
182         int fr, to, len;
183
184         sprintf(cvsF, "%s.countvectors", imdName);
185         fo = fopen(cvsF, "w");
186         assert(CHAINLEN % GAP == 0);
187         fprintf(fo, "%d %d\n", CHAINLEN / GAP, M + 1);
188         //fprintf(fo, "%d %d\n", CHAINLEN, M + 1);
189
190         pme_c.clear(); pme_c.resize(M + 1, 0.0);
191         pme_theta.clear(); pme_theta.resize(M + 1, 0.0);
192         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
193
194                 for (int i = 0; i < N1; i++) {
195                         --counts[z[i]];
196                         fr = s[i]; to = s[i + 1]; len = to - fr;
197                         arr.resize(len);
198                         for (int j = fr; j < to; j++) {
199                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
200                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
201                         }
202                         z[i] = hits[fr + sample(arr, len)].sid;
203                         ++counts[z[i]];
204                 }
205
206                 if (ROUND > BURNIN) {
207                         if ((ROUND - BURNIN - 1) % GAP == 0) writeCountVector(fo);
208                         for (int i = 0; i <= M; i++) { 
209                           pme_c[i] += counts[i] - 1;
210                           pme_theta[i] += counts[i] / totc;
211                         }
212                 }
213
214                 if (verbose) { printf("ROUND %d is finished!\n", ROUND); }
215         }
216         fclose(fo);
217
218         for (int i = 0; i <= M; i++) {
219           pme_c[i] /= CHAINLEN;
220           pme_theta[i] /= CHAINLEN;
221         }
222
223         if (verbose) { printf("Gibbs is finished!\n"); }
224 }
225
226 template<class ModelType>
227 void calcExpectedEffectiveLengths(ModelType& model) {
228   int lb, ub, span;
229   double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = sigma_{j=1}^{i}pdf[i]*(lb+i)
230   
231   model.getGLD().copyTo(pdf, cdf, lb, ub, span);
232   clen = new double[span + 1];
233   clen[0] = 0.0;
234   for (int i = 1; i <= span; i++) {
235     clen[i] = clen[i - 1] + pdf[i] * (lb + i);
236   }
237
238   eel.clear();
239   eel.resize(M + 1, 0.0);
240   for (int i = 1; i <= M; i++) {
241     int totLen = refs.getRef(i).getTotLen();
242     int fullLen = refs.getRef(i).getFullLen();
243     int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
244     int pos2 = max(min(totLen, ub) - lb, 0);
245
246     if (pos2 == 0) { eel[i] = 0.0; continue; }
247     
248     eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
249     assert(eel[i] >= 0);
250     if (eel[i] < MINEEL) { eel[i] = 0.0; }
251   }
252   
253   delete[] pdf;
254   delete[] cdf;
255   delete[] clen;
256 }
257
258 template<class ModelType>
259 void writeEstimatedParameters(char* modelF, char* imdName) {
260         ModelType model;
261         double denom;
262         char outF[STRLEN];
263         FILE *fo;
264
265         model.read(modelF);
266
267         calcExpectedEffectiveLengths<ModelType>(model);
268
269         denom = pme_theta[0];
270         for (int i = 1; i <= M; i++)
271           if (eel[i] < EPSILON) pme_theta[i] = 0.0;
272           else denom += pme_theta[i];
273         if (denom <= 0) { fprintf(stderr, "No Expected Effective Length is no less than %.6g?!\n", MINEEL); exit(-1); }
274         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
275
276         denom = 0.0;
277         double *mw = model.getMW();
278         for (int i = 0; i <= M; i++) {
279           pme_theta[i] = (mw[i] < EPSILON ? 0.0 : pme_theta[i] / mw[i]);
280           denom += pme_theta[i];
281         }
282         assert(denom >= EPSILON);
283         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
284
285         //calculate tau values
286         double *tau = new double[M + 1];
287         memset(tau, 0, sizeof(double) * (M + 1));
288
289         denom = 0.0;
290         for (int i = 1; i <= M; i++) 
291           if (eel[i] > EPSILON) {
292             tau[i] = pme_theta[i] / eel[i];
293             denom += tau[i];
294           }
295         if (denom <= 0) { fprintf(stderr, "No alignable reads?!\n"); exit(-1); }
296         //assert(denom > 0);
297         for (int i = 1; i <= M; i++) {
298                 tau[i] /= denom;
299         }
300
301         //isoform level results
302         sprintf(outF, "%s.iso_res", imdName);
303         fo = fopen(outF, "a");
304         if (fo == NULL) { fprintf(stderr, "Cannot open %s!\n", outF); exit(-1); }
305         for (int i = 1; i <= M; i++)
306                 fprintf(fo, "%.2f%c", pme_c[i], (i < M ? '\t' : '\n'));
307         for (int i = 1; i <= M; i++)
308                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
309         fclose(fo);
310
311         //gene level results
312         sprintf(outF, "%s.gene_res", imdName);
313         fo = fopen(outF, "a");
314         if (fo == NULL) { fprintf(stderr, "Cannot open %s!\n", outF); exit(-1); }
315         for (int i = 0; i < m; i++) {
316                 double sumC = 0.0; //  sum of pme counts
317                 int b = gi.spAt(i), e = gi.spAt(i + 1);
318                 for (int j = b; j < e; j++) {
319                         sumC += pme_c[j];
320                 }
321                 fprintf(fo, "%.15g%c", sumC, (i < m - 1 ? '\t' : '\n'));
322         }
323         for (int i = 0; i < m; i++) {
324                 double sumT = 0.0; //  sum of tau values
325                 int b = gi.spAt(i), e = gi.spAt(i + 1);
326                 for (int j = b; j < e; j++) {
327                         sumT += tau[j];
328                 }
329                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
330         }
331         fclose(fo);
332
333         delete[] tau;
334
335         if (verbose) { printf("Gibbs based expression values are written!\n"); }
336 }
337
338
339 int main(int argc, char* argv[]) {
340         if (argc < 7) {
341                 printf("Usage: rsem-run-gibbs reference_name sample_name sampleToken BURNIN CHAINLEN GAP [-q]\n");
342                 exit(-1);
343         }
344
345         BURNIN = atoi(argv[4]);
346         CHAINLEN = atoi(argv[5]);
347         GAP = atoi(argv[6]);
348         sprintf(imdName, "%s.temp/%s", argv[2], argv[3]);
349         sprintf(statName, "%s.stat/%s", argv[2], argv[3]);
350         load_data(argv[1], statName, imdName);
351
352         quiet = false;
353         if (argc > 7 && !strcmp(argv[7], "-q")) {
354                 quiet = true;
355         }
356         verbose = !quiet;
357
358         init();
359         Gibbs(imdName);
360
361         sprintf(modelF, "%s.model", statName);
362         FILE *fi = fopen(modelF, "r");
363         if (fi == NULL) { fprintf(stderr, "Cannot open %s!\n", modelF); exit(-1); }
364         fscanf(fi, "%d", &model_type);
365         fclose(fi);
366
367         switch(model_type) {
368         case 0 : writeEstimatedParameters<SingleModel>(modelF, imdName); break;
369         case 1 : writeEstimatedParameters<SingleQModel>(modelF, imdName); break;
370         case 2 : writeEstimatedParameters<PairedEndModel>(modelF, imdName); break;
371         case 3 : writeEstimatedParameters<PairedEndQModel>(modelF, imdName); break;
372         }
373
374         return 0;
375 }