]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
rsem v1.1.14, add --sampling-for-bam option, modify rsem-bam2wig to handle BAM files...
[rsem.git] / Gibbs.cpp
1 #include<cstdio>
2 #include<cstring>
3 #include<cstdlib>
4 #include<cassert>
5 #include<fstream>
6 #include<sstream>
7 #include<vector>
8
9 #include "utils.h"
10 #include "sampling.h"
11
12 #include "Model.h"
13 #include "SingleModel.h"
14 #include "SingleQModel.h"
15 #include "PairedEndModel.h"
16 #include "PairedEndQModel.h"
17
18 #include "Refs.h"
19 #include "GroupInfo.h"
20
21 using namespace std;
22
23 struct Item {
24         int sid;
25         double conprb;
26
27         Item(int sid, double conprb) {
28                 this->sid = sid;
29                 this->conprb = conprb;
30         }
31 };
32
33 int model_type;
34 int m, M, N0, N1, nHits;
35 double totc;
36 int BURNIN, CHAINLEN, GAP;
37 char imdName[STRLEN], statName[STRLEN];
38 char thetaF[STRLEN], ofgF[STRLEN], groupF[STRLEN], refF[STRLEN], modelF[STRLEN];
39 char cvsF[STRLEN];
40
41 Refs refs;
42 GroupInfo gi;
43
44 vector<double> theta, pme_theta, pme_c, eel;
45
46 vector<int> s, z;
47 vector<Item> hits;
48 vector<int> counts;
49
50 bool quiet;
51
52 vector<double> arr;
53
54 void load_data(char* reference_name, char* statName, char* imdName) {
55         ifstream fin;
56         string line;
57         int tmpVal;
58
59         //load reference file
60         sprintf(refF, "%s.seq", reference_name);
61         refs.loadRefs(refF, 1);
62         M = refs.getM();
63
64         //load groupF
65         sprintf(groupF, "%s.grp", reference_name);
66         gi.load(groupF);
67         m = gi.getm();
68
69         //load thetaF
70         sprintf(thetaF, "%s.theta",statName);
71         fin.open(thetaF);
72         if (!fin.is_open()) {
73                 fprintf(stderr, "Cannot open %s!\n", thetaF);
74                 exit(-1);
75         }
76         fin>>tmpVal;
77         if (tmpVal != M + 1) {
78                 fprintf(stderr, "Number of transcripts is not consistent in %s and %s!\n", refF, thetaF);
79                 exit(-1);
80         }
81         theta.clear(); theta.resize(M + 1);
82         for (int i = 0; i <= M; i++) fin>>theta[i];
83         fin.close();
84
85         //load ofgF;
86         sprintf(ofgF, "%s.ofg", imdName);
87         fin.open(ofgF);
88         if (!fin.is_open()) {
89                 fprintf(stderr, "Cannot open %s!\n", ofgF);
90                 exit(-1);
91         }
92         fin>>tmpVal>>N0;
93         if (tmpVal != M) {
94                 fprintf(stderr, "M in %s is not consistent with %s!\n", ofgF, refF);
95                 exit(-1);
96         }
97         getline(fin, line);
98
99         s.clear(); hits.clear();
100         s.push_back(0);
101         while (getline(fin, line)) {
102                 istringstream strin(line);
103                 int sid;
104                 double conprb;
105
106                 while (strin>>sid>>conprb) {
107                         hits.push_back(Item(sid, conprb));
108                 }
109                 s.push_back(hits.size());
110         }
111         fin.close();
112
113         N1 = s.size() - 1;
114         nHits = hits.size();
115
116         if (verbose) { printf("Loading Data is finished!\n"); }
117 }
118
119 void init() {
120         int len, fr, to;
121
122         arr.clear();
123         z.clear();
124         counts.clear();
125
126         z.resize(N1);
127         counts.resize(M + 1, 1); // 1 pseudo count
128         counts[0] += N0;
129
130         for (int i = 0; i < N1; i++) {
131                 fr = s[i]; to = s[i + 1];
132                 len = to - fr;
133                 arr.resize(len);
134                 for (int j = fr; j < to; j++) {
135                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
136                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
137                 }
138                 z[i] = hits[fr + sample(arr, len)].sid;
139                 ++counts[z[i]];
140         }
141
142         totc = N0 + N1 + (M + 1);
143
144         if (verbose) { printf("Initialization is finished!\n"); }
145 }
146
147 void writeCountVector(FILE* fo) {
148         for (int i = 0; i < M; i++) {
149                 fprintf(fo, "%d ", counts[i]);
150         }
151         fprintf(fo, "%d\n", counts[M]);
152 }
153
154 void Gibbs(char* imdName) {
155         FILE *fo;
156         int fr, to, len;
157
158         sprintf(cvsF, "%s.countvectors", imdName);
159         fo = fopen(cvsF, "w");
160         assert(CHAINLEN % GAP == 0);
161         fprintf(fo, "%d %d\n", CHAINLEN / GAP, M + 1);
162         //fprintf(fo, "%d %d\n", CHAINLEN, M + 1);
163
164         pme_c.clear(); pme_c.resize(M + 1, 0.0);
165         pme_theta.clear(); pme_theta.resize(M + 1, 0.0);
166         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
167
168                 for (int i = 0; i < N1; i++) {
169                         --counts[z[i]];
170                         fr = s[i]; to = s[i + 1]; len = to - fr;
171                         arr.resize(len);
172                         for (int j = fr; j < to; j++) {
173                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
174                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
175                         }
176                         z[i] = hits[fr + sample(arr, len)].sid;
177                         ++counts[z[i]];
178                 }
179
180                 if (ROUND > BURNIN) {
181                         if ((ROUND - BURNIN - 1) % GAP == 0) writeCountVector(fo);
182                         for (int i = 0; i <= M; i++) { 
183                           pme_c[i] += counts[i] - 1;
184                           pme_theta[i] += counts[i] / totc;
185                         }
186                 }
187
188                 if (verbose) { printf("ROUND %d is finished!\n", ROUND); }
189         }
190         fclose(fo);
191
192         for (int i = 0; i <= M; i++) {
193           pme_c[i] /= CHAINLEN;
194           pme_theta[i] /= CHAINLEN;
195         }
196
197         if (verbose) { printf("Gibbs is finished!\n"); }
198 }
199
200 template<class ModelType>
201 void calcExpectedEffectiveLengths(ModelType& model) {
202   int lb, ub, span;
203   double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = sigma_{j=1}^{i}pdf[i]*(lb+i)
204   
205   model.getGLD().copyTo(pdf, cdf, lb, ub, span);
206   clen = new double[span + 1];
207   clen[0] = 0.0;
208   for (int i = 1; i <= span; i++) {
209     clen[i] = clen[i - 1] + pdf[i] * (lb + i);
210   }
211
212   eel.clear();
213   eel.resize(M + 1, 0.0);
214   for (int i = 1; i <= M; i++) {
215     int totLen = refs.getRef(i).getTotLen();
216     int fullLen = refs.getRef(i).getFullLen();
217     int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
218     int pos2 = max(min(totLen, ub) - lb, 0);
219
220     if (pos2 == 0) { eel[i] = 0.0; continue; }
221     
222     eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
223     assert(eel[i] >= 0);
224     if (eel[i] < MINEEL) { eel[i] = 0.0; }
225   }
226   
227   delete[] pdf;
228   delete[] cdf;
229   delete[] clen;
230 }
231
232 template<class ModelType>
233 void writeEstimatedParameters(char* modelF, char* imdName) {
234         ModelType model;
235         double denom;
236         char outF[STRLEN];
237         FILE *fo;
238
239         model.read(modelF);
240
241         calcExpectedEffectiveLengths<ModelType>(model);
242
243         denom = pme_theta[0];
244         for (int i = 1; i <= M; i++)
245           if (eel[i] < EPSILON) pme_theta[i] = 0.0;
246           else denom += pme_theta[i];
247         if (denom <= 0) { fprintf(stderr, "No Expected Effective Length is no less than %.6g?!\n", MINEEL); exit(-1); }
248         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
249
250         denom = 0.0;
251         double *mw = model.getMW();
252         for (int i = 0; i <= M; i++) {
253           pme_theta[i] = (mw[i] < EPSILON ? 0.0 : pme_theta[i] / mw[i]);
254           denom += pme_theta[i];
255         }
256         assert(denom >= EPSILON);
257         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
258
259         //calculate tau values
260         double *tau = new double[M + 1];
261         memset(tau, 0, sizeof(double) * (M + 1));
262
263         denom = 0.0;
264         for (int i = 1; i <= M; i++) 
265           if (eel[i] > EPSILON) {
266             tau[i] = pme_theta[i] / eel[i];
267             denom += tau[i];
268           }
269         if (denom <= 0) { fprintf(stderr, "No alignable reads?!\n"); exit(-1); }
270         //assert(denom > 0);
271         for (int i = 1; i <= M; i++) {
272                 tau[i] /= denom;
273         }
274
275         //isoform level results
276         sprintf(outF, "%s.iso_res", imdName);
277         fo = fopen(outF, "a");
278         if (fo == NULL) { fprintf(stderr, "Cannot open %s!\n", outF); exit(-1); }
279         for (int i = 1; i <= M; i++)
280                 fprintf(fo, "%.2f%c", pme_c[i], (i < M ? '\t' : '\n'));
281         for (int i = 1; i <= M; i++)
282                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
283         fclose(fo);
284
285         //gene level results
286         sprintf(outF, "%s.gene_res", imdName);
287         fo = fopen(outF, "a");
288         if (fo == NULL) { fprintf(stderr, "Cannot open %s!\n", outF); exit(-1); }
289         for (int i = 0; i < m; i++) {
290                 double sumC = 0.0; //  sum of pme counts
291                 int b = gi.spAt(i), e = gi.spAt(i + 1);
292                 for (int j = b; j < e; j++) {
293                         sumC += pme_c[j];
294                 }
295                 fprintf(fo, "%.15g%c", sumC, (i < m - 1 ? '\t' : '\n'));
296         }
297         for (int i = 0; i < m; i++) {
298                 double sumT = 0.0; //  sum of tau values
299                 int b = gi.spAt(i), e = gi.spAt(i + 1);
300                 for (int j = b; j < e; j++) {
301                         sumT += tau[j];
302                 }
303                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
304         }
305         fclose(fo);
306
307         delete[] tau;
308
309         if (verbose) { printf("Gibbs based expression values are written!\n"); }
310 }
311
312
313 int main(int argc, char* argv[]) {
314         if (argc < 7) {
315                 printf("Usage: rsem-run-gibbs reference_name sample_name sampleToken BURNIN CHAINLEN GAP [-q]\n");
316                 exit(-1);
317         }
318
319         BURNIN = atoi(argv[4]);
320         CHAINLEN = atoi(argv[5]);
321         GAP = atoi(argv[6]);
322         sprintf(imdName, "%s.temp/%s", argv[2], argv[3]);
323         sprintf(statName, "%s.stat/%s", argv[2], argv[3]);
324         load_data(argv[1], statName, imdName);
325
326         quiet = false;
327         if (argc > 7 && !strcmp(argv[7], "-q")) {
328                 quiet = true;
329         }
330         verbose = !quiet;
331
332         init();
333         Gibbs(imdName);
334
335         sprintf(modelF, "%s.model", statName);
336         FILE *fi = fopen(modelF, "r");
337         if (fi == NULL) { fprintf(stderr, "Cannot open %s!\n", modelF); exit(-1); }
338         fscanf(fi, "%d", &model_type);
339         fclose(fi);
340
341         switch(model_type) {
342         case 0 : writeEstimatedParameters<SingleModel>(modelF, imdName); break;
343         case 1 : writeEstimatedParameters<SingleQModel>(modelF, imdName); break;
344         case 2 : writeEstimatedParameters<PairedEndModel>(modelF, imdName); break;
345         case 3 : writeEstimatedParameters<PairedEndQModel>(modelF, imdName); break;
346         }
347
348         return 0;
349 }