]> git.donarmstrong.com Git - rsem.git/blob - Gibbs.cpp
RSEM Source Codes
[rsem.git] / Gibbs.cpp
1 #include<ctime>
2 #include<cstdio>
3 #include<cstring>
4 #include<cstdlib>
5 #include<cassert>
6 #include<fstream>
7 #include<sstream>
8 #include<vector>
9
10 #include "randomc.h"
11 #include "utils.h"
12
13 #include "Model.h"
14 #include "SingleModel.h"
15 #include "SingleQModel.h"
16 #include "PairedEndModel.h"
17 #include "PairedEndQModel.h"
18
19 #include "Refs.h"
20 #include "GroupInfo.h"
21
22 using namespace std;
23
24 struct Item {
25         int sid;
26         double conprb;
27
28         Item(int sid, double conprb) {
29                 this->sid = sid;
30                 this->conprb = conprb;
31         }
32 };
33
34 int model_type;
35 int m, M, N0, N1, nHits;
36 double totc;
37 int BURNIN, CHAINLEN, GAP;
38 char thetaF[STRLEN], ofgF[STRLEN], groupF[STRLEN], refF[STRLEN], modelF[STRLEN];
39 char cvsF[STRLEN];
40
41 Refs refs;
42 GroupInfo gi;
43
44 vector<double> theta, pme_theta, eel;
45
46 vector<int> s, z;
47 vector<Item> hits;
48 vector<int> counts;
49
50 bool quiet;
51
52 vector<double> arr;
53 CRandomMersenne rg(time(NULL));
54
55 void load_data(char* reference_name, char* sample_name, char* imdName) {
56         ifstream fin;
57         string line;
58         int tmpVal;
59
60         //load reference file
61         sprintf(refF, "%s.seq", reference_name);
62         refs.loadRefs(refF, 1);
63         M = refs.getM();
64
65         //load groupF
66         sprintf(groupF, "%s.grp", reference_name);
67         gi.load(groupF);
68         m = gi.getm();
69
70         //load thetaF
71         sprintf(thetaF, "%s.theta",sample_name);
72         fin.open(thetaF);
73         if (!fin.is_open()) {
74                 fprintf(stderr, "Cannot open %s!\n", thetaF);
75                 exit(-1);
76         }
77         fin>>tmpVal;
78         if (tmpVal != M + 1) {
79                 fprintf(stderr, "Number of transcripts is not consistent in %s and %s!\n", refF, thetaF);
80                 exit(-1);
81         }
82         theta.clear(); theta.resize(M + 1);
83         for (int i = 0; i <= M; i++) fin>>theta[i];
84         fin.close();
85
86         //load ofgF;
87         sprintf(ofgF, "%s.ofg", imdName);
88         fin.open(ofgF);
89         if (!fin.is_open()) {
90                 fprintf(stderr, "Cannot open %s!\n", ofgF);
91                 exit(-1);
92         }
93         fin>>tmpVal>>N0;
94         if (tmpVal != M) {
95                 fprintf(stderr, "M in %s is not consistent with %s!\n", ofgF, refF);
96                 exit(-1);
97         }
98         getline(fin, line);
99
100         s.clear(); hits.clear();
101         s.push_back(0);
102         while (getline(fin, line)) {
103                 istringstream strin(line);
104                 int sid;
105                 double conprb;
106
107                 while (strin>>sid>>conprb) {
108                         hits.push_back(Item(sid, conprb));
109                 }
110                 s.push_back(hits.size());
111         }
112         fin.close();
113
114         N1 = s.size() - 1;
115         nHits = hits.size();
116
117         if (verbose) { printf("Loading Data is finished!\n"); }
118 }
119
120 // arr should be cumulative!
121 // interval : [,)
122 // random number should be in [0, arr[len - 1])
123 // If by chance arr[len - 1] == 0.0, one possibility is to sample uniformly from 0...len-1
124 int sample(vector<double>& arr, int len) {
125   int l, r, mid;
126   double prb = rg.Random() * arr[len - 1];
127
128   l = 0; r = len - 1;
129   while (l <= r) {
130     mid = (l + r) / 2;
131     if (arr[mid] <= prb) l = mid + 1;
132     else r = mid - 1;
133   }
134
135   if (l >= len) { printf("%d %lf %lf\n", len, arr[len - 1], prb); }
136   assert(l < len);
137
138   return l;
139 }
140
141 void init() {
142         int len, fr, to;
143
144         arr.clear();
145         z.clear();
146         counts.clear();
147
148         z.resize(N1);
149         counts.resize(M + 1, 1); // 1 pseudo count
150         counts[0] += N0;
151
152         for (int i = 0; i < N1; i++) {
153                 fr = s[i]; to = s[i + 1];
154                 len = to - fr;
155                 arr.resize(len);
156                 for (int j = fr; j < to; j++) {
157                         arr[j - fr] = theta[hits[j].sid] * hits[j].conprb;
158                         if (j > fr) arr[j - fr] += arr[j - fr - 1];  // cumulative
159                 }
160                 z[i] = hits[fr + sample(arr, len)].sid;
161                 ++counts[z[i]];
162         }
163
164         totc = N0 + N1 + (M + 1);
165
166         if (verbose) { printf("Initialization is finished!\n"); }
167 }
168
169 void writeCountVector(FILE* fo) {
170         for (int i = 0; i < M; i++) {
171                 fprintf(fo, "%d ", counts[i]);
172         }
173         fprintf(fo, "%d\n", counts[M]);
174 }
175
176 void Gibbs(char* imdName) {
177         FILE *fo;
178         int fr, to, len;
179
180         sprintf(cvsF, "%s.countvectors", imdName);
181         fo = fopen(cvsF, "w");
182         assert(CHAINLEN % GAP == 0);
183         fprintf(fo, "%d %d\n", CHAINLEN / GAP, M + 1);
184         //fprintf(fo, "%d %d\n", CHAINLEN, M + 1);
185
186         pme_theta.clear(); pme_theta.resize(M + 1, 0.0);
187         for (int ROUND = 1; ROUND <= BURNIN + CHAINLEN; ROUND++) {
188
189                 for (int i = 0; i < N1; i++) {
190                         --counts[z[i]];
191                         fr = s[i]; to = s[i + 1]; len = to - fr;
192                         arr.resize(len);
193                         for (int j = fr; j < to; j++) {
194                                 arr[j - fr] = counts[hits[j].sid] * hits[j].conprb;
195                                 if (j > fr) arr[j - fr] += arr[j - fr - 1]; //cumulative
196                         }
197                         z[i] = hits[fr + sample(arr, len)].sid;
198                         ++counts[z[i]];
199                 }
200
201                 if (ROUND > BURNIN) {
202                         if ((ROUND - BURNIN -1) % GAP == 0) writeCountVector(fo);
203                         writeCountVector(fo);
204                         for (int i = 0; i <= M; i++) pme_theta[i] += counts[i] / totc;
205                 }
206
207                 if (verbose) { printf("ROUND %d is finished!\n", ROUND); }
208         }
209         fclose(fo);
210
211         for (int i = 0; i <= M; i++) pme_theta[i] /= CHAINLEN;
212
213         if (verbose) { printf("Gibbs is finished!\n"); }
214 }
215
216 template<class ModelType>
217 void calcExpectedEffectiveLengths(ModelType& model) {
218   int lb, ub, span;
219   double *pdf = NULL, *cdf = NULL, *clen = NULL; // clen[i] = sigma_{j=1}^{i}pdf[i]*(lb+i)
220   
221   model.getGLD().copyTo(pdf, cdf, lb, ub, span);
222   clen = new double[span + 1];
223   clen[0] = 0.0;
224   for (int i = 1; i <= span; i++) {
225     clen[i] = clen[i - 1] + pdf[i] * (lb + i);
226   }
227
228   eel.clear();
229   eel.resize(M + 1, 0.0);
230   for (int i = 1; i <= M; i++) {
231     int totLen = refs.getRef(i).getTotLen();
232     int fullLen = refs.getRef(i).getFullLen();
233     int pos1 = max(min(totLen - fullLen + 1, ub) - lb, 0);
234     int pos2 = max(min(totLen, ub) - lb, 0);
235
236     if (pos2 == 0) { eel[i] = 0.0; continue; }
237     
238     eel[i] = fullLen * cdf[pos1] + ((cdf[pos2] - cdf[pos1]) * (totLen + 1) - (clen[pos2] - clen[pos1]));
239     assert(eel[i] >= 0);
240     if (eel[i] < MINEEL) { eel[i] = 0.0; }
241   }
242   
243   delete[] pdf;
244   delete[] cdf;
245   delete[] clen;
246 }
247
248 template<class ModelType>
249 void writeEstimatedParameters(char* modelF, char* imdName) {
250         ModelType model;
251         double denom;
252         char outF[STRLEN];
253         FILE *fo;
254
255         model.read(modelF);
256
257         calcExpectedEffectiveLengths<ModelType>(model);
258
259         denom = pme_theta[0];
260         for (int i = 1; i <= M; i++)
261           if (eel[i] < EPSILON) pme_theta[i] = 0.0;
262           else denom += pme_theta[i];
263         if (denom <= 0) { fprintf(stderr, "No Expected Effective Length is no less than %.6g?!\n", MINEEL); exit(-1); }
264         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
265
266         denom = 0.0;
267         double *mw = model.getMW();
268         for (int i = 0; i <= M; i++) {
269           pme_theta[i] = (mw[i] < EPSILON ? 0.0 : pme_theta[i] / mw[i]);
270           denom += pme_theta[i];
271         }
272         assert(denom >= EPSILON);
273         for (int i = 0; i <= M; i++) pme_theta[i] /= denom;
274
275         //calculate normalized read fraction
276         double *nrf = new double[M + 1];
277         memset(nrf, 0, sizeof(double) * (M + 1));
278
279         denom = 1.0 - pme_theta[0];
280         if (denom <= 0) { fprintf(stderr, "No alignable reads?!\n"); exit(-1); }
281         for (int i = 1; i <= M; i++) nrf[i] = pme_theta[i] / denom;
282
283         //calculate tau values
284         double *tau = new double[M + 1];
285         memset(tau, 0, sizeof(double) * (M + 1));
286
287         denom = 0.0;
288         for (int i = 1; i <= M; i++) 
289           if (eel[i] > EPSILON) {
290             tau[i] = pme_theta[i] / eel[i];
291             denom += tau[i];
292           }
293         if (denom <= 0) { fprintf(stderr, "No alignable reads?!\n"); exit(-1); }
294         //assert(denom > 0);
295         for (int i = 1; i <= M; i++) {
296                 tau[i] /= denom;
297         }
298
299         //isoform level results
300         sprintf(outF, "%s.iso_res", imdName);
301         fo = fopen(outF, "a");
302         if (fo == NULL) { fprintf(stderr, "Cannot open %s!\n", outF); exit(-1); }
303         for (int i = 1; i <= M; i++)
304                 fprintf(fo, "%.15g%c", nrf[i], (i < M ? '\t' : '\n'));
305         for (int i = 1; i <= M; i++)
306                 fprintf(fo, "%.15g%c", tau[i], (i < M ? '\t' : '\n'));
307         fclose(fo);
308
309         //gene level results
310         sprintf(outF, "%s.gene_res", imdName);
311         fo = fopen(outF, "a");
312         if (fo == NULL) { fprintf(stderr, "Cannot open %s!\n", outF); exit(-1); }
313         for (int i = 0; i < m; i++) {
314                 double sumN = 0.0; //  sum of normalized read fraction
315                 int b = gi.spAt(i), e = gi.spAt(i + 1);
316                 for (int j = b; j < e; j++) {
317                         sumN += nrf[j];
318                 }
319                 fprintf(fo, "%.15g%c", sumN, (i < m - 1 ? '\t' : '\n'));
320         }
321         for (int i = 0; i < m; i++) {
322                 double sumT = 0.0; //  sum of tau values
323                 int b = gi.spAt(i), e = gi.spAt(i + 1);
324                 for (int j = b; j < e; j++) {
325                         sumT += tau[j];
326                 }
327                 fprintf(fo, "%.15g%c", sumT, (i < m - 1 ? '\t' : '\n'));
328         }
329         fclose(fo);
330
331         delete[] nrf;
332         delete[] tau;
333
334         if (verbose) { printf("Gibbs based expression values are written!\n"); }
335 }
336
337
338 int main(int argc, char* argv[]) {
339         if (argc < 7) {
340                 printf("Usage: rsem-run-gibbs reference_name sample_name imdName BURNIN CHAINLEN GAP [-q]\n");
341                 exit(-1);
342         }
343
344         BURNIN = atoi(argv[4]);
345         CHAINLEN = atoi(argv[5]);
346         GAP = atoi(argv[6]);
347         load_data(argv[1], argv[2], argv[3]);
348
349         quiet = false;
350         if (argc > 7 && !strcmp(argv[7], "-q")) {
351                 quiet = true;
352         }
353         verbose = !quiet;
354
355         init();
356         Gibbs(argv[3]);
357
358         sprintf(modelF, "%s.model", argv[2]);
359         FILE *fi = fopen(modelF, "r");
360         if (fi == NULL) { fprintf(stderr, "Cannot open %s!\n", modelF); exit(-1); }
361         fscanf(fi, "%d", &model_type);
362         fclose(fi);
363
364         switch(model_type) {
365         case 0 : writeEstimatedParameters<SingleModel>(modelF, argv[3]); break;
366         case 1 : writeEstimatedParameters<SingleQModel>(modelF, argv[3]); break;
367         case 2 : writeEstimatedParameters<PairedEndModel>(modelF, argv[3]); break;
368         case 3 : writeEstimatedParameters<PairedEndQModel>(modelF, argv[3]); break;
369         }
370
371         return 0;
372 }