]> git.donarmstrong.com Git - rsem.git/blob - SingleModel.h
Merge remote-tracking branch 'origin/master'
[rsem.git] / SingleModel.h
1 #ifndef SINGLEMODEL_H_
2 #define SINGLEMODEL_H_
3
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cassert>
#include<cstring>
#include<string>
#include<vector>
#include<algorithm>
#include<sstream>
11
12 #include "utils.h"
13 #include "Orientation.h"
14 #include "LenDist.h"
15 #include "RSPD.h"
16 #include "Profile.h"
17 #include "NoiseProfile.h"
18
19 #include "ModelParams.h"
20 #include "RefSeq.h"
21 #include "Refs.h"
22 #include "SingleRead.h"
23 #include "SingleHit.h"
24 #include "ReadReader.h"
25
26 #include "simul.h"
27
28 class SingleModel {
29 public:
30         SingleModel(Refs* refs = NULL) {
31                 this->refs = refs;
32                 M = (refs != NULL ? refs->getM() : 0);
33                 memset(N, 0, sizeof(N));
34                 estRSPD = false;
35                 needCalcConPrb = true;
36
37                 ori = new Orientation();
38                 gld = new LenDist();
39                 mld = NULL;
40                 rspd = new RSPD(estRSPD);
41                 pro = new Profile();
42                 npro = new NoiseProfile();
43
44                 mean = -1.0; sd = 0.0;
45                 mw = NULL;
46
47                 seedLen = 0;
48         }
49
50         //If it is not a master node, only init & update can be used!
51         SingleModel(ModelParams& params, bool isMaster = true) {
52                 M = params.M;
53                 memcpy(N, params.N, sizeof(params.N));
54                 refs = params.refs;
55                 estRSPD = params.estRSPD;
56                 mean = params.mean; sd = params.sd;
57                 seedLen = params.seedLen;
58                 needCalcConPrb = true;
59
60                 ori = NULL; gld = NULL; mld = NULL; rspd = NULL; pro = NULL; npro = NULL;
61                 mw = NULL;
62
63                 if (isMaster) {
64                         gld = new LenDist(params.minL, params.maxL);
65                         if (mean >= EPSILON) {
66                                 mld = new LenDist(params.mate_minL, params.mate_maxL);
67                         }
68                         if (!estRSPD) { rspd = new RSPD(estRSPD); }
69                 }
70
71                 ori = new Orientation(params.probF);
72                 if (estRSPD) { rspd = new RSPD(estRSPD, params.B); }
73                 pro = new Profile(params.maxL);
74                 npro = new NoiseProfile();
75         }
76
77         ~SingleModel() {
78                 refs = NULL;
79                 if (ori != NULL) delete ori;
80                 if (gld != NULL) delete gld;
81                 if (mld != NULL) delete mld;
82                 if (rspd != NULL) delete rspd;
83                 if (pro != NULL) delete pro;
84                 if (npro != NULL) delete npro;
85                 if (mw != NULL) delete[] mw;
86                 /* delete[] p1, p2 */
87         }
88
89         void estimateFromReads(const char*);
90
91         //if prob is too small, just make it 0
92         double getConPrb(const SingleRead& read, const SingleHit& hit) {
93                 if (read.isLowQuality()) return 0.0;
94
95                 double prob;
96                 int sid = hit.getSid();
97                 RefSeq &ref = refs->getRef(sid);
98                 int fullLen = ref.getFullLen();
99                 int totLen = ref.getTotLen();
100                 int dir = hit.getDir();
101                 int pos = hit.getPos();
102                 int readLen = read.getReadLength();
103                 int fpos = (dir == 0 ? pos : totLen - pos - readLen); // the aligned position reported in SAM file, should be a coordinate in forward strand
104
105                 assert(fpos >= 0 && fpos + readLen <= totLen && readLen <= totLen);
106                 int seedPos = (dir == 0 ? pos : totLen - pos - seedLen); // the aligned position of the seed in forward strand coordinates
107                 if (seedPos >= fullLen || ref.getMask(seedPos)) return 0.0;
108
109                 int effL;
110                 double value;
111
112                 if (mld != NULL) {
113                         int minL = std::max(readLen, gld->getMinL());
114                         int maxL = std::min(totLen - pos, gld->getMaxL());
115                         int pfpos; // possible fpos for fragment
116                         value = 0.0;
117                         for (int fragLen = minL; fragLen <= maxL; fragLen++) {
118                                 pfpos = (dir == 0 ? pos : totLen - pos - fragLen);
119                                 effL = std::min(fullLen, totLen - fragLen + 1);
120                                 value += gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen) * mld->getAdjustedProb(readLen, fragLen);
121                         }
122                 }
123                 else {
124                         effL = std::min(fullLen, totLen - readLen + 1);
125                         value = gld->getAdjustedProb(readLen, totLen) * rspd->getAdjustedProb(fpos, effL, fullLen);
126                 }
127
128                 prob = ori->getProb(dir) * value * pro->getProb(read.getReadSeq(), ref, pos, dir);
129
130                 if (prob < EPSILON) { prob = 0.0; }
131
132
133                 prob = (mw[sid] < EPSILON ? 0.0 : prob / mw[sid]);
134
135                 return prob;
136         }
137
138         double getNoiseConPrb(const SingleRead& read) {
139                 if (read.isLowQuality()) return 0.0;
140                 double prob = mld != NULL ? mld->getProb(read.getReadLength()) : gld->getProb(read.getReadLength());
141                 prob *= npro->getProb(read.getReadSeq());
142                 if (prob < EPSILON) { prob = 0.0; }
143
144                 prob = (mw[0] < EPSILON ? 0.0 : prob / mw[0]);
145
146                 return prob;
147         }
148
149         double getLogP() { return npro->getLogP(); }
150
151         void init();
152
153         void update(const SingleRead& read, const SingleHit& hit, double frac) {
154                 if (read.isLowQuality() || frac < EPSILON) return;
155
156                 RefSeq& ref = refs->getRef(hit.getSid());
157                 int dir = hit.getDir();
158                 int pos = hit.getPos();
159
160                 if (estRSPD) {
161                         int fullLen = ref.getFullLen();
162
163                         // Only use one strand to estimate RSPD
164                         if (ori->getProb(0) >= ORIVALVE && dir == 0) {
165                                 rspd->update(pos, fullLen, frac);
166                         }
167
168                         if (ori->getProb(0) < ORIVALVE && dir == 1) {
169                                 int totLen = ref.getTotLen();
170                                 int readLen = read.getReadLength();
171
172                                 int pfpos, effL; 
173
174                                 if (mld != NULL) {
175                                         int minL = std::max(readLen, gld->getMinL());
176                                         int maxL = std::min(totLen - pos, gld->getMaxL());
177                                         double sum = 0.0;
178                                         assert(maxL >= minL);
179                                         std::vector<double> frag_vec(maxL - minL + 1, 0.0);
180
181                                         for (int fragLen = minL; fragLen <= maxL; fragLen++) {
182                                                 pfpos = totLen - pos - fragLen;
183                                                 effL = std::min(fullLen, totLen - fragLen + 1);
184                                                 frag_vec[fragLen - minL] = gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen) * mld->getAdjustedProb(readLen, fragLen);
185                                                 sum += frag_vec[fragLen - minL];
186                                         }
187                                         assert(sum >= EPSILON);
188                                         for (int fragLen = minL; fragLen <= maxL; fragLen++) {
189                                                 pfpos = totLen - pos - fragLen;
190                                                 rspd->update(pfpos, fullLen, frac * (frag_vec[fragLen - minL] / sum));
191                                         }
192                                 }
193                                 else {
194                                         rspd->update(totLen - pos - readLen, fullLen, frac);
195                                 }
196                         }
197                 }
198                 pro->update(read.getReadSeq(), ref, pos, dir, frac);
199         }
200
201         void updateNoise(const SingleRead& read, double frac) {
202                 if (read.isLowQuality() || frac < EPSILON) return;
203
204                 npro->update(read.getReadSeq(), frac);
205         }
206
207         void finish();
208
209         void collect(const SingleModel&);
210
211         bool getNeedCalcConPrb() { return needCalcConPrb; }
212         void setNeedCalcConPrb(bool value) { needCalcConPrb = value; }
213
214         //void calcP1();
215         //void calcP2();
216         //double* getP1() { return p1; }
217         //double* getP2() { return p2; }
218
219         void read(const char*);
220         void write(const char*);
221
222         const LenDist& getGLD() { return *gld; }
223
224         void startSimulation(simul*, double*);
225         bool simulate(int, SingleRead&, int&);
226         void finishSimulation();
227
228         double* getMW() { 
229           assert(mw != NULL);
230           return mw;
231         }
232
233         int getModelType() const { return model_type; }
234
235 private:
236         static const int model_type = 0;
237         static const int read_type = 0;
238
239         int M;
240         int N[3];
241         Refs *refs;
242         double mean, sd;
243         int seedLen;
244         //double *p1, *p2; P_i' & P_i''
245
246         bool estRSPD; // true if estimate RSPD
247         bool needCalcConPrb; // true need, false does not need
248
249         Orientation *ori;
250         LenDist *gld, *mld;
251         RSPD *rspd;
252         Profile *pro;
253         NoiseProfile *npro;
254
255         simul *sampler; // for simulation
256         double *theta_cdf; // for simulation
257
258         double *mw; // for masking
259
260         void calcMW();
261 };
262
263 void SingleModel::estimateFromReads(const char* readFN) {
264         int s;
265         char readFs[2][STRLEN];
266         SingleRead read;
267
268         mld != NULL ? mld->init() : gld->init();
269         for (int i = 0; i < 3; i++)
270                 if (N[i] > 0) {
271                         genReadFileNames(readFN, i, read_type, s, readFs);
272                         ReadReader<SingleRead> reader(s, readFs, refs->hasPolyA(), seedLen); // allow calculation of calc_lq() function
273
274                         int cnt = 0;
275                         while (reader.next(read)) {
276                                 if (!read.isLowQuality()) {
277                                         mld != NULL ? mld->update(read.getReadLength(), 1.0) : gld->update(read.getReadLength(), 1.0);
278                                         if (i == 0) { npro->updateC(read.getReadSeq()); }
279                                 }
280                                 else if (verbose && read.getReadLength() < seedLen) {
281                                         printf("Warning: Read %s is ignored due to read length %d < seed length %d!\n", read.getName().c_str(), read.getReadLength(), seedLen);
282                                 }
283
284                                 ++cnt;
285                                 if (verbose && cnt % 1000000 == 0) { printf("%d READS PROCESSED\n", cnt); }
286                         }
287
288                         if (verbose) { printf("estimateFromReads, N%d finished.\n", i); }
289                 }
290
291         mld != NULL ? mld->finish() : gld->finish();
292         //mean should be > 0
293         if (mean >= EPSILON) { 
294           assert(mld->getMaxL() <= gld->getMaxL());
295           gld->setAsNormal(mean, sd, std::max(mld->getMinL(), gld->getMinL()), gld->getMaxL());
296         }
297         npro->calcInitParams();
298
299         mw = new double[M + 1];
300         calcMW();
301 }
302
303 void SingleModel::init() {
304         if (estRSPD) rspd->init();
305         pro->init();
306         npro->init();
307 }
308
309 void SingleModel::finish() {
310         if (estRSPD) rspd->finish();
311         pro->finish();
312         npro->finish();
313         needCalcConPrb = true;
314         if (estRSPD) calcMW();
315 }
316
317 void SingleModel::collect(const SingleModel& o) {
318         if (estRSPD) rspd->collect(*(o.rspd));
319         pro->collect(*(o.pro));
320         npro->collect(*(o.npro));
321 }
322
323 //Only master node can call
324 void SingleModel::read(const char* inpF) {
325         int val;
326         FILE *fi = fopen(inpF, "r");
327         if (fi == NULL) { fprintf(stderr, "Cannot open %s! It may not exist.\n", inpF); exit(-1); }
328
329         assert(fscanf(fi, "%d", &val) == 1);
330         assert(val == model_type);
331
332         ori->read(fi);
333         gld->read(fi);
334         assert(fscanf(fi, "%d", &val) == 1);
335         if (val > 0) {
336                 if (mld == NULL) mld = new LenDist();
337                 mld->read(fi);
338         }
339         rspd->read(fi);
340         pro->read(fi);
341         npro->read(fi);
342
343         if (fscanf(fi, "%d", &val) == 1) {
344                 if (M == 0) M = val;
345                 if (M == val) {
346                         mw = new double[M + 1];
347                         for (int i = 0; i <= M; i++) assert(fscanf(fi, "%lf", &mw[i]) == 1);
348                 }
349         }
350
351         fclose(fi);
352 }
353
354 //Only master node can call. Only be called at EM.cpp
355 void SingleModel::write(const char* outF) {
356         FILE *fo = fopen(outF, "w");
357
358         fprintf(fo, "%d\n", model_type);
359         fprintf(fo, "\n");
360
361         ori->write(fo);  fprintf(fo, "\n");
362         gld->write(fo);  fprintf(fo, "\n");
363         if (mld != NULL) {
364                 fprintf(fo, "1\n");
365                 mld->write(fo);
366         }
367         else { fprintf(fo, "0\n"); }
368         fprintf(fo, "\n");
369         rspd->write(fo); fprintf(fo, "\n");
370         pro->write(fo);  fprintf(fo, "\n");
371         npro->write(fo);
372
373         if (mw != NULL) {
374           fprintf(fo, "\n%d\n", M);
375           for (int i = 0; i < M; i++) {
376             fprintf(fo, "%.15g ", mw[i]);
377           }
378           fprintf(fo, "%.15g\n", mw[M]);
379         }
380
381         fclose(fo);
382 }
383
384 void SingleModel::startSimulation(simul* sampler, double* theta) {
385         this->sampler = sampler;
386
387         theta_cdf = new double[M + 1];
388         for (int i = 0; i <= M; i++) {
389                 theta_cdf[i] = theta[i];
390                 if (i > 0) theta_cdf[i] += theta_cdf[i - 1];
391         }
392
393         rspd->startSimulation(M, refs);
394         pro->startSimulation();
395         npro->startSimulation();
396 }
397
398 bool SingleModel::simulate(int rid, SingleRead& read, int& sid) {
399         int dir, pos, readLen, fragLen;
400         std::string name;
401         std::string readseq;
402         std::ostringstream strout;
403
404         sid = sampler->sample(theta_cdf, M + 1);
405
406         if (sid == 0) {
407                 dir = pos = 0;
408                 readLen = (mld != NULL ? mld->simulate(sampler, -1) : gld->simulate(sampler, -1));
409                 readseq = npro->simulate(sampler, readLen);
410         }
411         else {
412                 RefSeq &ref = refs->getRef(sid);
413                 dir = ori->simulate(sampler);
414                 fragLen = gld->simulate(sampler, ref.getTotLen());
415                 if (fragLen < 0) return false;
416                 int effL = std::min(ref.getFullLen(), ref.getTotLen() - fragLen + 1);
417                 pos = rspd->simulate(sampler, sid, effL);
418                 if (pos < 0) return false;
419                 if (dir > 0) pos = ref.getTotLen() - pos - fragLen;
420
421                 if (mld != NULL) {
422                         readLen = mld->simulate(sampler, fragLen);
423                         if (readLen < 0) return false;
424                         readseq = pro->simulate(sampler, readLen, pos, dir, ref);
425                 }
426                 else {
427                         readseq = pro->simulate(sampler, fragLen, pos, dir, ref);
428                 }
429         }
430
431         strout<<rid<<"_"<<dir<<"_"<<sid<<"_"<<pos;
432         name = strout.str();
433
434         read = SingleRead(name, readseq);
435
436         return true;
437 }
438
439 void SingleModel::finishSimulation() {
440         delete[] theta_cdf;
441
442         rspd->finishSimulation();
443         pro->finishSimulation();
444         npro->finishSimulation();
445 }
446
447 void SingleModel::calcMW() {
448         double probF, probR;
449
450         assert((mld == NULL ? gld->getMinL() : mld->getMinL()) >= seedLen);
451   
452         memset(mw, 0, sizeof(double) * (M + 1));
453         mw[0] = 1.0;
454
455         probF = ori->getProb(0);
456         probR = ori->getProb(1);
457
458         for (int i = 1; i <= M; i++) {
459                 RefSeq& ref = refs->getRef(i);
460                 int totLen = ref.getTotLen();
461                 int fullLen = ref.getFullLen();
462                 double value = 0.0;
463                 int minL, maxL;
464                 int effL, pfpos;
465                 int end = std::min(fullLen, totLen - seedLen + 1);
466                 double factor;
467
468                 for (int seedPos = 0; seedPos < end; seedPos++)
469                         if (ref.getMask(seedPos)) {
470                                 //forward
471                                 minL = gld->getMinL();
472                                 maxL = std::min(gld->getMaxL(), totLen - seedPos);
473                                 pfpos = seedPos;
474                                 for (int fragLen = minL; fragLen <= maxL; fragLen++) {
475                                         effL = std::min(fullLen, totLen - fragLen + 1);
476                                         factor = (mld == NULL ? 1.0 : mld->getAdjustedCumulativeProb(std::min(mld->getMaxL(), fragLen), fragLen));
477                                         value += probF * gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen) * factor;
478                                 }
479                                 //reverse
480                                 minL = gld->getMinL();
481                                 maxL = std::min(gld->getMaxL(), seedPos + seedLen);
482                                 for (int fragLen = minL; fragLen <= maxL; fragLen++) {
483                                         pfpos = seedPos - (fragLen - seedLen);
484                                         effL = std::min(fullLen, totLen - fragLen + 1);
485                                         factor = (mld == NULL ? 1.0 : mld->getAdjustedCumulativeProb(std::min(mld->getMaxL(), fragLen), fragLen));
486                                         value += probR * gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen) * factor;
487                                 }
488                         }
489     
490                 //for reverse strand masking
491                 for (int seedPos = end; seedPos <= totLen - seedLen; seedPos++) {
492                         minL = std::max(gld->getMinL(), seedPos + seedLen - fullLen + 1);
493                         maxL = std::min(gld->getMaxL(), seedPos + seedLen);
494                         for (int fragLen = minL; fragLen <= maxL; fragLen++) {
495                                 pfpos = seedPos - (fragLen - seedLen);
496                                 effL = std::min(fullLen, totLen - fragLen + 1);
497                                 factor = (mld == NULL ? 1.0 : mld->getAdjustedCumulativeProb(std::min(mld->getMaxL(), fragLen), fragLen));
498                                 value += probR * gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen) * factor;
499                         }
500                 }
501
502                 mw[i] = 1.0 - value;
503
504                 if (mw[i] < 1e-8) {
505                         //      fprintf(stderr, "Warning: %dth reference sequence is masked for almost all positions!\n", i);
506                         mw[i] = 0.0;
507                 }
508         }
509 }
510
511 #endif /* SINGLEMODEL_H_ */