git.donarmstrong.com Git - rsem.git/blob - SingleQModel.h
Allowed > 2^31 hits
[rsem.git] / SingleQModel.h
1 #ifndef SINGLEQMODEL_H_
2 #define SINGLEQMODEL_H_
3
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cassert>
#include<cstring>
#include<string>
#include<vector>
#include<algorithm>
#include<sstream>
#include<iostream>

#include "utils.h"
#include "my_assert.h"
#include "Orientation.h"
#include "LenDist.h"
#include "RSPD.h"
#include "QualDist.h"
#include "QProfile.h"
#include "NoiseQProfile.h"

#include "ModelParams.h"
#include "RefSeq.h"
#include "Refs.h"
#include "SingleReadQ.h"
#include "SingleHit.h"
#include "ReadReader.h"

#include "simul.h"
30
31 class SingleQModel {
32 public:
33         SingleQModel(Refs* refs = NULL) {
34                 this->refs = refs;
35                 M = (refs != NULL ? refs->getM() : 0);
36                 memset(N, 0, sizeof(N));
37                 estRSPD = false;
38                 needCalcConPrb = true;
39                 
40                 ori = new Orientation();
41                 gld = new LenDist();
42                 mld = NULL;
43                 rspd = new RSPD(estRSPD);
44                 qd = new QualDist();
45                 qpro = new QProfile();
46                 nqpro = new NoiseQProfile();
47
48                 mean = -1.0; sd = 0.0;
49                 mw = NULL;
50
51                 seedLen = 0;
52         }
53
54         //If it is not a master node, only init & update can be used!
55         SingleQModel(ModelParams& params, bool isMaster = true) {
56                 M = params.M;
57                 memcpy(N, params.N, sizeof(params.N));
58                 refs = params.refs;
59                 estRSPD = params.estRSPD;
60                 mean = params.mean; sd = params.sd;
61                 seedLen = params.seedLen;
62                 needCalcConPrb = true;
63
64                 ori = NULL; gld = NULL; mld = NULL; rspd = NULL; qd = NULL; qpro = NULL; nqpro = NULL;
65                 mw = NULL;
66
67                 if (isMaster) {
68                         gld = new LenDist(params.minL, params.maxL);                    
69                         if (mean >= EPSILON) {
70                                 mld = new LenDist(params.mate_minL, params.mate_maxL);
71                         }
72                         if (!estRSPD) { rspd = new RSPD(estRSPD); }
73                         qd = new QualDist();
74                 }
75
76                 ori = new Orientation(params.probF);
77                 if (estRSPD) { rspd = new RSPD(estRSPD, params.B); }
78                 qpro = new QProfile();
79                 nqpro = new NoiseQProfile();
80         }
81
82         ~SingleQModel() {
83                 refs = NULL;
84                 if (ori != NULL) delete ori;
85                 if (gld != NULL) delete gld;
86                 if (mld != NULL) delete mld;
87                 if (rspd != NULL) delete rspd;
88                 if (qd != NULL) delete qd;
89                 if (qpro != NULL) delete qpro;
90                 if (nqpro != NULL) delete nqpro;
91                 if (mw != NULL) delete[] mw;
92                 /*delete[] p1, p2;*/
93         }
94
95         //SingleQModel& operator=(const SingleQModel&);
96
97         void estimateFromReads(const char*);
98
99         //if prob is too small, just make it 0
100         double getConPrb(const SingleReadQ& read, const SingleHit& hit) const {
101                 if (read.isLowQuality()) return 0.0;
102
103                 double prob;
104                 int sid = hit.getSid();
105                 RefSeq &ref = refs->getRef(sid);
106                 int fullLen = ref.getFullLen();
107                 int totLen = ref.getTotLen();
108                 int dir = hit.getDir();
109                 int pos = hit.getPos();
110                 int readLen = read.getReadLength();
111                 int fpos = (dir == 0 ? pos : totLen - pos - readLen); // the aligned position reported in SAM file, should be a coordinate in forward strand
112
113                 general_assert(fpos >= 0, "The alignment of read " + read.getName() + " to transcript " + itos(sid) + " starts at " + itos(fpos) + \
114                                 " from the forward direction, which should be a non-negative number! " + \
115                                 "It is possible that the aligner you use gave different read lengths for a same read in SAM file.");
116                 general_assert(fpos + readLen <= totLen,"Read " + read.getName() + " is hung over the end of transcript " + itos(sid) + "! " \
117                                 + "It is possible that the aligner you use gave different read lengths for a same read in SAM file.");
118                 general_assert(readLen <= totLen, "Read " + read.getName() + " has length " + itos(readLen) + ", but it is aligned to transcript " \
119                                 + itos(sid) + ", whose length (" + itos(totLen) + ") is shorter than the read's length!");
120
121                 int seedPos = (dir == 0 ? pos : totLen - pos - seedLen); // the aligned position of the seed in forward strand coordinates
122                 if (seedPos >= fullLen || ref.getMask(seedPos)) return 0.0;
123
124                 int effL;
125                 double value;
126
127                 if (mld != NULL) {
128                         int minL = std::max(readLen, gld->getMinL());
129                         int maxL = std::min(totLen - pos, gld->getMaxL());
130                         int pfpos; // possible fpos for fragment
131                         value = 0.0;
132                         for (int fragLen = minL; fragLen <= maxL; fragLen++) {
133                                 pfpos = (dir == 0 ? pos : totLen - pos - fragLen);
134                                 effL = std::min(fullLen, totLen - fragLen + 1);
135                                 value += gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen) * mld->getAdjustedProb(readLen, fragLen);
136                         }
137                 }
138                 else {
139                         effL = std::min(fullLen, totLen - readLen + 1);
140                         value = gld->getAdjustedProb(readLen, totLen) * rspd->getAdjustedProb(fpos, effL, fullLen);
141                 }
142
143                 prob = ori->getProb(dir) * value * qpro->getProb(read.getReadSeq(), read.getQScore(), ref, pos, dir);
144
145                 if (prob < EPSILON) { prob = 0.0; }
146
147                 prob = (mw[sid] < EPSILON ? 0.0 : prob / mw[sid]);
148
149                 return prob;
150         }
151
152         double getNoiseConPrb(const SingleReadQ& read) {
153                 if (read.isLowQuality()) return 0.0;
154                 double prob = mld != NULL ? mld->getProb(read.getReadLength()) : gld->getProb(read.getReadLength());
155                 prob *= nqpro->getProb(read.getReadSeq(), read.getQScore());
156                 if (prob < EPSILON) { prob = 0.0; }
157
158                 prob = (mw[0] < EPSILON ? 0.0 : prob / mw[0]);
159
160                 return prob;
161         }
162
163         double getLogP() { return nqpro->getLogP(); }
164
165         void init();
166
167         void update(const SingleReadQ& read, const SingleHit& hit, double frac) {
168                 if (read.isLowQuality() || frac < EPSILON) return;
169
170                 const RefSeq& ref = refs->getRef(hit.getSid());
171
172                 int dir = hit.getDir();
173                 int pos = hit.getPos();
174
175                 if (estRSPD) {
176                         int fullLen = ref.getFullLen();
177
178                         // Only use one strand to estimate RSPD
179                         if (ori->getProb(0) >= ORIVALVE && dir == 0) {
180                                 rspd->update(pos, fullLen, frac);
181                         }
182
183                         if (ori->getProb(0) < ORIVALVE && dir == 1) {
184                                 int totLen = ref.getTotLen();                     
185                                 int readLen = read.getReadLength();
186                                 
187                                 int pfpos, effL; 
188
189                                 if (mld != NULL) {
190                                         int minL = std::max(readLen, gld->getMinL());
191                                         int maxL = std::min(totLen - pos, gld->getMaxL());
192                                         double sum = 0.0;
193                                         assert(maxL >= minL);
194                                         std::vector<double> frag_vec(maxL - minL + 1, 0.0);
195
196                                         for (int fragLen = minL; fragLen <= maxL; fragLen++) {
197                                                 pfpos = totLen - pos - fragLen;
198                                                 effL = std::min(fullLen, totLen - fragLen + 1);
199                                                 frag_vec[fragLen - minL] = gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen) * mld->getAdjustedProb(readLen, fragLen);
200                                                 sum += frag_vec[fragLen - minL];
201                                         }
202                                         assert(sum >= EPSILON);
203                                         for (int fragLen = minL; fragLen <= maxL; fragLen++) {
204                                                 pfpos = totLen - pos - fragLen;
205                                                 rspd->update(pfpos, fullLen, frac * (frag_vec[fragLen - minL] / sum));
206                                         }
207                                 }
208                                 else {
209                                         rspd->update(totLen - pos - readLen, fullLen, frac);
210                                 }
211                         }
212                 }
213                 qpro->update(read.getReadSeq(), read.getQScore(), ref, pos, dir, frac);
214         }
215
216         void updateNoise(const SingleReadQ& read, double frac) {
217                 if (read.isLowQuality() || frac < EPSILON) return;
218
219                 nqpro->update(read.getReadSeq(), read.getQScore(), frac);
220         }
221
222         void finish();
223
224         void collect(const SingleQModel&);
225
226         //void copy(const SingleQModel&);
227
228         bool getNeedCalcConPrb() { return needCalcConPrb; }
229         void setNeedCalcConPrb(bool value) { needCalcConPrb = value; }
230
231         //void calcP1();
232         //void calcP2();
233         //double* getP1() { return p1; }
234         //double* getP2() { return p2; }
235
236         void read(const char*);
237         void write(const char*);
238
239         const LenDist& getGLD() { return *gld; }
240
241         void startSimulation(simul*, double*);
242         bool simulate(READ_INT_TYPE, SingleReadQ&, int&);
243         void finishSimulation();
244
245         //Use it after function 'read' or 'estimateFromReads'
246         double* getMW() { 
247           assert(mw != NULL);
248           return mw;
249         }
250
251         int getModelType() const { return model_type; }
252
253 private:
254         static const int model_type = 1;
255         static const int read_type = 1;
256
257         int M;
258         READ_INT_TYPE N[3];
259         Refs *refs;
260         double mean, sd;
261         int seedLen;
262         //double *p1, *p2; P_i' & P_i'';
263
264         bool estRSPD; // true if estimate RSPD
265         bool needCalcConPrb; //true need, false does not need
266
267         Orientation *ori;
268         LenDist *gld, *mld;
269         RSPD *rspd;
270         QualDist *qd;
271         QProfile *qpro;
272         NoiseQProfile *nqpro;
273
274         simul *sampler; // for simulation
275         double *theta_cdf; // for simulation
276
277         double *mw; // for masking
278
279         void calcMW();
280 };
281
282 void SingleQModel::estimateFromReads(const char* readFN) {
283         int s;
284         char readFs[2][STRLEN];
285         SingleReadQ read;
286
287         mld != NULL ? mld->init() : gld->init();
288         for (int i = 0; i < 3; i++)
289                 if (N[i] > 0) {
290                         genReadFileNames(readFN, i, read_type, s, readFs);
291                         ReadReader<SingleReadQ> reader(s, readFs, refs->hasPolyA(), seedLen); // allow calculation of calc_lq() function
292
293                         READ_INT_TYPE cnt = 0;
294                         while (reader.next(read)) {
295                                 if (!read.isLowQuality()) {
296                                         mld != NULL ? mld->update(read.getReadLength(), 1.0) : gld->update(read.getReadLength(), 1.0);
297                                         qd->update(read.getQScore());
298                                         if (i == 0) { nqpro->updateC(read.getReadSeq(), read.getQScore()); }
299                                 }
300                                 else if (verbose && read.getReadLength() < seedLen) {
301                                         std::cout<< "Warning: Read "<< read.getName()<< " is ignored due to read length "<< read.getReadLength()<< " < seed length "<< seedLen<< "!"<< std::endl;
302                                 }
303
304                                 ++cnt;
305                                 if (verbose && cnt % 1000000 == 0) { std::cout<< cnt<< " READS PROCESSED"<< std::endl; }
306                         }
307
308                         if (verbose) { std::cout<< "estimateFromReads, N"<< i<< " finished."<< std::endl; }
309                 }
310
311         mld != NULL ? mld->finish() : gld->finish();
312
313         //mean should be > 0
314         if (mean >= EPSILON) { 
315           assert(mld->getMaxL() <= gld->getMaxL());
316           gld->setAsNormal(mean, sd, std::max(mld->getMinL(), gld->getMinL()), gld->getMaxL());
317         }
318         qd->finish();
319         nqpro->calcInitParams();
320
321         mw = new double[M + 1];
322         calcMW();
323 }
324
325 void SingleQModel::init() {
326         if (estRSPD) rspd->init();
327         qpro->init();
328         nqpro->init();
329 }
330
331 void SingleQModel::finish() {
332         if (estRSPD) rspd->finish();
333         qpro->finish();
334         nqpro->finish();
335         needCalcConPrb = true;
336         if (estRSPD) calcMW();
337 }
338
339 void SingleQModel::collect(const SingleQModel& o) {
340         if (estRSPD) rspd->collect(*(o.rspd));
341         qpro->collect(*(o.qpro));
342         nqpro->collect(*(o.nqpro));
343 }
344
345 //Only master node can call
346 void SingleQModel::read(const char* inpF) {
347         int val;
348         FILE *fi = fopen(inpF, "r");
349         if (fi == NULL) { fprintf(stderr, "Cannot open %s! It may not exist.\n", inpF); exit(-1); }
350
351         assert(fscanf(fi, "%d", &val) == 1);
352         assert(val == model_type);
353
354         ori->read(fi);
355         gld->read(fi);
356         assert(fscanf(fi, "%d", &val) == 1);
357         if (val > 0) {
358                 if (mld == NULL) mld = new LenDist();
359                 mld->read(fi);
360         }
361         rspd->read(fi);
362         qd->read(fi);
363         qpro->read(fi);
364         nqpro->read(fi);
365
366         if (fscanf(fi, "%d", &val) == 1) {
367                 if (M == 0) M = val;
368                 if (M == val) {
369                         mw = new double[M + 1];
370                         for (int i = 0; i <= M; i++) assert(fscanf(fi, "%lf", &mw[i]) == 1);
371                 }
372         }
373
374         fclose(fi);
375 }
376
377 //Only master node can call. Only be called at EM.cpp
378 void SingleQModel::write(const char* outF) {
379         FILE *fo = fopen(outF, "w");
380
381         fprintf(fo, "%d\n", model_type);
382         fprintf(fo, "\n");
383
384         ori->write(fo);  fprintf(fo, "\n");
385         gld->write(fo);  fprintf(fo, "\n");
386         if (mld != NULL) {
387                 fprintf(fo, "1\n");
388                 mld->write(fo);
389         }
390         else { fprintf(fo, "0\n"); }
391         fprintf(fo, "\n");
392         rspd->write(fo); fprintf(fo, "\n");
393         qd->write(fo);   fprintf(fo, "\n");
394         qpro->write(fo); fprintf(fo, "\n");
395         nqpro->write(fo);
396
397         if (mw != NULL) {
398           fprintf(fo, "\n%d\n", M);
399           for (int i = 0; i < M; i++) {
400             fprintf(fo, "%.15g ", mw[i]);
401           }
402           fprintf(fo, "%.15g\n", mw[M]);
403         }
404
405         fclose(fo);
406 }
407
408 void SingleQModel::startSimulation(simul* sampler, double* theta) {
409         this->sampler = sampler;
410
411         theta_cdf = new double[M + 1];
412         for (int i = 0; i <= M; i++) {
413                 theta_cdf[i] = theta[i];
414                 if (i > 0) theta_cdf[i] += theta_cdf[i - 1];
415         }
416
417         rspd->startSimulation(M, refs);
418         qd->startSimulation();
419         qpro->startSimulation();
420         nqpro->startSimulation();
421 }
422
423 bool SingleQModel::simulate(READ_INT_TYPE rid, SingleReadQ& read, int& sid) {
424         int dir, pos, readLen, fragLen;
425         std::string name;
426         std::string qual, readseq;
427         std::ostringstream strout;
428
429         sid = sampler->sample(theta_cdf, M + 1);
430
431         if (sid == 0) {
432                 dir = pos = 0;
433                 readLen = (mld != NULL ? mld->simulate(sampler, -1) : gld->simulate(sampler, -1));
434                 qual = qd->simulate(sampler, readLen);
435                 readseq = nqpro->simulate(sampler, readLen, qual);
436         }
437         else {
438                 RefSeq &ref = refs->getRef(sid);
439                 dir = ori->simulate(sampler);
440                 fragLen = gld->simulate(sampler, ref.getTotLen());
441                 if (fragLen < 0) return false;
442
443                 int effL = std::min(ref.getFullLen(), ref.getTotLen() - fragLen + 1);
444                 pos = rspd->simulate(sampler, sid, effL);
445                 if (pos < 0) return false;
446                 if (dir > 0) pos = ref.getTotLen() - pos - fragLen;
447
448                 if (mld != NULL) {
449                         readLen = mld->simulate(sampler, fragLen);
450                         if (readLen < 0) return false;
451                         qual = qd->simulate(sampler, readLen);
452                         readseq = qpro->simulate(sampler, readLen, pos, dir, qual, ref);
453                 }
454                 else {
455                         qual = qd->simulate(sampler, fragLen);
456                         readseq = qpro->simulate(sampler, fragLen, pos, dir, qual, ref);
457                 }
458         }
459
460         strout<<rid<<"_"<<dir<<"_"<<sid<<"_"<<pos;
461         name = strout.str();
462
463         read = SingleReadQ(name, readseq, qual);
464
465         return true;
466 }
467
468 void SingleQModel::finishSimulation() {
469         delete[] theta_cdf;
470
471         rspd->finishSimulation();
472         qd->finishSimulation();
473         qpro->finishSimulation();
474         nqpro->finishSimulation();
475 }
476
477 void SingleQModel::calcMW() {
478         double probF, probR;
479
480         assert((mld == NULL ? gld->getMinL() : mld->getMinL()) >= seedLen);
481
482         memset(mw, 0, sizeof(double) * (M + 1));
483         mw[0] = 1.0;
484
485         probF = ori->getProb(0);
486         probR = ori->getProb(1);
487
488         for (int i = 1; i <= M; i++) {
489                 RefSeq& ref = refs->getRef(i);
490                 int totLen = ref.getTotLen();
491                 int fullLen = ref.getFullLen();
492                 double value = 0.0;
493                 int minL, maxL;
494                 int effL, pfpos;
495                 int end = std::min(fullLen, totLen - seedLen + 1);
496                 double factor;
497
498                 for (int seedPos = 0; seedPos < end; seedPos++)
499                         if (ref.getMask(seedPos)) {
500                                 //forward
501                                 minL = gld->getMinL();
502                                 maxL = std::min(gld->getMaxL(), totLen - seedPos);
503                                 pfpos = seedPos;
504                                 for (int fragLen = minL; fragLen <= maxL; fragLen++) {
505                                         effL = std::min(fullLen, totLen - fragLen + 1);
506                                         factor = (mld == NULL ? 1.0 : mld->getAdjustedCumulativeProb(std::min(mld->getMaxL(), fragLen), fragLen));
507                                         value += probF * gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen) * factor;
508                                 }
509                                 //reverse
510                                 minL = gld->getMinL();
511                                 maxL = std::min(gld->getMaxL(), seedPos + seedLen);
512                                 for (int fragLen = minL; fragLen <= maxL; fragLen++) {
513                                         pfpos = seedPos - (fragLen - seedLen);
514                                         effL = std::min(fullLen, totLen - fragLen + 1);
515                                         factor = (mld == NULL ? 1.0 : mld->getAdjustedCumulativeProb(std::min(mld->getMaxL(), fragLen), fragLen));
516                                         value += probR * gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen) * factor;
517                                 }
518                         }
519
520                 //for reverse strand masking
521                 for (int seedPos = end; seedPos <= totLen - seedLen; seedPos++) {
522                         minL = std::max(gld->getMinL(), seedPos + seedLen - fullLen + 1);
523                         maxL = std::min(gld->getMaxL(), seedPos + seedLen);
524                         for (int fragLen = minL; fragLen <= maxL; fragLen++) {
525                                 pfpos = seedPos - (fragLen - seedLen);
526                                 effL = std::min(fullLen, totLen - fragLen + 1);
527                                 factor = (mld == NULL ? 1.0 : mld->getAdjustedCumulativeProb(std::min(mld->getMaxL(), fragLen), fragLen));
528                                 value += probR * gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen) * factor;
529                         }
530                 }
531
532                 mw[i] = 1.0 - value;
533
534                 if (mw[i] < 1e-8) {
535                         //      fprintf(stderr, "Warning: %dth reference sequence is masked for almost all positions!\n", i);
536                         mw[i] = 0.0;
537                 }
538         }
539 }
540
541 #endif /* SINGLEQMODEL_H_ */