]> git.donarmstrong.com Git - rsem.git/blob - PairedEndModel.h
rsem-v1.1.14b3
[rsem.git] / PairedEndModel.h
1 #ifndef PAIREDENDMODEL_H_
2 #define PAIREDENDMODEL_H_
3
4 #include<cmath>
5 #include<cstdio>
6 #include<cassert>
7 #include<cstring>
8 #include<string>
9 #include<algorithm>
10 #include<sstream>
11
12 #include "utils.h"
13 #include "Orientation.h"
14 #include "LenDist.h"
15 #include "RSPD.h"
16 #include "Profile.h"
17 #include "NoiseProfile.h"
18
19 #include "ModelParams.h"
20 #include "RefSeq.h"
21 #include "Refs.h"
22 #include "SingleRead.h"
23 #include "PairedEndRead.h"
24 #include "PairedEndHit.h"
25 #include "ReadReader.h"
26
27 #include "simul.h"
28
29 class PairedEndModel {
30 public:
31         PairedEndModel(Refs* refs = NULL) {
32                 this->refs = refs;
33                 M = (refs != NULL ? refs->getM() : 0);
34                 memset(N, 0, sizeof(N));
35                 estRSPD = false;
36                 needCalcConPrb = true;
37
38                 ori = new Orientation();
39                 gld = new LenDist();
40                 rspd = new RSPD(estRSPD);
41                 pro = new Profile();
42                 npro = new NoiseProfile();
43                 mld = new LenDist();
44
45                 mw = NULL;
46                 seedLen = 0;
47         }
48
49         //If it is not a master node, only init & update can be used!
50         PairedEndModel(ModelParams& params, bool isMaster = true) {
51                 M = params.M;
52                 memcpy(N, params.N, sizeof(params.N));
53                 refs = params.refs;
54                 estRSPD = params.estRSPD;
55                 seedLen = params.seedLen;
56                 needCalcConPrb = true;
57
58                 ori = NULL; gld = NULL; rspd = NULL; pro = NULL; npro = NULL; mld = NULL;
59                 mw = NULL;
60
61                 if (isMaster) {
62                         if (!estRSPD) rspd = new RSPD(estRSPD);
63                         mld = new LenDist(params.mate_minL, params.mate_maxL);
64                 }
65
66                 ori = new Orientation(params.probF);
67                 gld = new LenDist(params.minL, params.maxL);
68                 if (estRSPD) rspd = new RSPD(estRSPD, params.B);
69                 pro = new Profile(params.maxL);
70                 npro = new NoiseProfile();
71         }
72
73         ~PairedEndModel() {
74                 refs = NULL;
75                 if (ori != NULL) delete ori;
76                 if (gld != NULL) delete gld;
77                 if (rspd != NULL) delete rspd;
78                 if (pro != NULL) delete pro;
79                 if (npro != NULL) delete npro;
80                 if (mld != NULL) delete mld;
81                 if (mw != NULL) delete mw;
82         }
83
84         void estimateFromReads(const char*);
85
86         //if prob is too small, just make it 0
87         double getConPrb(const PairedEndRead& read, const PairedEndHit& hit) {
88                 if (read.isLowQuality()) return 0.0;
89
90                 double prob;
91                 int sid = hit.getSid();
92                 RefSeq &ref = refs->getRef(sid);
93                 int dir = hit.getDir();
94                 int pos = hit.getPos();
95                 int fullLen = ref.getFullLen();
96                 int totLen = ref.getTotLen();
97                 int insertLen = hit.getInsertL();
98
99                 int fpos = (dir == 0 ? pos : totLen - pos - insertLen); // the aligned position reported in SAM file, should be a coordinate in forward strand
100                 int effL = std::min(fullLen, totLen - insertLen + 1);
101                 
102                 assert(fpos >= 0 && fpos + insertLen <= totLen && insertLen <= totLen);
103                 if (fpos >= fullLen || ref.getMask(fpos)) return 0.0; // For paired-end model, fpos is the seedPos
104
105                 prob = ori->getProb(dir) * gld->getAdjustedProb(insertLen, totLen) *
106                        rspd->getAdjustedProb(fpos, effL, fullLen);
107
108                 const SingleRead& mate1 = read.getMate1();
109                 prob *= mld->getAdjustedProb(mate1.getReadLength(), insertLen) *
110                         pro->getProb(mate1.getReadSeq(), ref, pos, dir);
111
112                 const SingleRead& mate2 = read.getMate2();
113                 int m2pos = totLen - pos - insertLen;
114                 int m2dir = !dir;
115                 prob *= mld->getAdjustedProb(mate2.getReadLength(), insertLen) *
116                         pro->getProb(mate2.getReadSeq(), ref, m2pos, m2dir);
117
118                 if (prob < EPSILON) { prob = 0.0; }
119
120                 prob = (mw[sid] < EPSILON ? 0.0 : prob / mw[sid]);
121
122                 return prob;
123         }
124
125         double getNoiseConPrb(const PairedEndRead& read) {
126                 if (read.isLowQuality()) return 0.0;
127                 double prob;
128                 const SingleRead& mate1 = read.getMate1();
129                 const SingleRead& mate2 = read.getMate2();
130
131                 prob = mld->getProb(mate1.getReadLength()) * npro->getProb(mate1.getReadSeq());
132                 prob *= mld->getProb(mate2.getReadLength()) * npro->getProb(mate2.getReadSeq());
133
134                 if (prob < EPSILON) { prob = 0.0; }
135
136                 prob = (mw[0] < EPSILON ? 0.0: prob / mw[0]);
137
138                 return prob;
139         }
140
141         double getLogP() { return npro->getLogP(); }
142
143         void init();
144
145         void update(const PairedEndRead& read, const PairedEndHit& hit, double frac) {
146                 if (read.isLowQuality() || frac < EPSILON) return;
147
148                 RefSeq& ref = refs->getRef(hit.getSid());
149                 const SingleRead& mate1 = read.getMate1();
150                 const SingleRead& mate2 = read.getMate2();
151
152                 gld->update(hit.getInsertL(), frac);
153                 if (estRSPD) {
154                         int fpos = (hit.getDir() == 0 ? hit.getPos() : ref.getTotLen() - hit.getPos() - hit.getInsertL());
155                         rspd->update(fpos, ref.getFullLen(), frac);
156                 }
157                 pro->update(mate1.getReadSeq(), ref, hit.getPos(), hit.getDir(), frac);
158
159                 int m2pos = ref.getTotLen() - hit.getPos() - hit.getInsertL();
160                 int m2dir = !hit.getDir();
161                 pro->update(mate2.getReadSeq(), ref, m2pos, m2dir, frac);
162         }
163
164         void updateNoise(const PairedEndRead& read, double frac) {
165                 if (read.isLowQuality() || frac < EPSILON) return;
166
167                 const SingleRead& mate1 = read.getMate1();
168                 const SingleRead& mate2 = read.getMate2();
169
170                 npro->update(mate1.getReadSeq(), frac);
171                 npro->update(mate2.getReadSeq(), frac);
172         }
173
174         void finish();
175
176         void collect(const PairedEndModel&);
177
178         bool getNeedCalcConPrb() { return needCalcConPrb; }
179         void setNeedCalcConPrb(bool value) { needCalcConPrb = value; }
180
181         void read(const char*);
182         void write(const char*);
183
184         const LenDist& getGLD() { return *gld; }
185
186         void startSimulation(simul*, double*);
187         bool simulate(int, PairedEndRead&, int&);
188         void finishSimulation();
189
190         //Use it after function 'read' or 'estimateFromReads'
191         double* getMW() { 
192           assert(mw != NULL);
193           return mw;
194         }
195
196         int getModelType() const { return model_type; }
197
198 private:
199         static const int model_type = 2;
200         static const int read_type = 2;
201
202         int M;
203         int N[3];
204         Refs *refs;
205         int seedLen;
206
207         bool estRSPD;
208         bool needCalcConPrb; //true need, false does not need
209
210         Orientation *ori;
211         LenDist *gld, *mld; //mld1 mate_length_dist
212         RSPD *rspd;
213         Profile *pro;
214         NoiseProfile *npro;
215
216         simul *sampler; // for simulation
217         double *theta_cdf; // for simulation
218
219         double *mw; // for masking
220
221         void calcMW();
222 };
223
224 void PairedEndModel::estimateFromReads(const char* readFN) {
225         int s;
226         char readFs[2][STRLEN];
227     PairedEndRead read;
228
229     mld->init();
230     for (int i = 0; i < 3; i++)
231         if (N[i] > 0) {
232                 genReadFileNames(readFN, i, read_type, s, readFs);
233                 ReadReader<PairedEndRead> reader(s, readFs, refs->hasPolyA(), seedLen); // allow calculation of calc_lq() function
234
235                 int cnt = 0;
236                 while (reader.next(read)) {
237                         SingleRead mate1 = read.getMate1();
238                         SingleRead mate2 = read.getMate2();
239                         
240                         if (!read.isLowQuality()) {
241                                 mld->update(mate1.getReadLength(), 1.0);
242                                 mld->update(mate2.getReadLength(), 1.0);
243                           
244                                 if (i == 0) {
245                                         npro->updateC(mate1.getReadSeq());
246                                         npro->updateC(mate2.getReadSeq());
247                                 }
248                         }
249                         else if (verbose && (mate1.getReadLength() < seedLen || mate2.getReadLength() < seedLen)) {
250                                 printf("Warning: Read %s is ignored due to at least one of the mates' length < seed length %d!\n", read.getName().c_str(), seedLen);
251                         }
252
253                         ++cnt;
254                         if (verbose && cnt % 1000000 == 0) { printf("%d READS PROCESSED\n", cnt); }
255                 }
256
257                 if (verbose) { printf("estimateFromReads, N%d finished.\n", i); }
258         }
259
260     mld->finish();
261     npro->calcInitParams();
262
263     mw = new double[M + 1];
264     calcMW();
265 }
266
267 void PairedEndModel::init() {
268         gld->init();
269         if (estRSPD) rspd->init();
270         pro->init();
271         npro->init();
272 }
273
274 void PairedEndModel::finish() {
275         gld->finish();
276         if (estRSPD) rspd->finish();
277         pro->finish();
278         npro->finish();
279         needCalcConPrb = true;
280         calcMW();
281 }
282
283 void PairedEndModel::collect(const PairedEndModel& o) {
284         gld->collect(*(o.gld));
285         if (estRSPD) rspd->collect(*(o.rspd));
286         pro->collect(*(o.pro));
287         npro->collect(*(o.npro));
288 }
289
290 //Only master node can call
291 void PairedEndModel::read(const char* inpF) {
292         int val;
293         FILE *fi = fopen(inpF, "r");
294         if (fi == NULL) { fprintf(stderr, "Cannot open %s! It may not exist.\n", inpF); exit(-1); }
295
296         assert(fscanf(fi, "%d", &val) == 1);
297         assert(val == model_type);
298
299         ori->read(fi);
300         gld->read(fi);
301         mld->read(fi);
302         rspd->read(fi);
303         pro->read(fi);
304         npro->read(fi);
305
306         if (fscanf(fi, "%d", &val) == 1) {
307                 if (M == 0) M = val;
308                 if (M == val) {
309                         mw = new double[M + 1];
310                         for (int i = 0; i <= M; i++) assert(fscanf(fi, "%lf", &mw[i]) == 1);
311                 }
312         }
313
314         fclose(fi);
315 }
316
317 //Only master node can call. Only be called at EM.cpp
318 void PairedEndModel::write(const char* outF) {
319         FILE *fo = fopen(outF, "w");
320
321         fprintf(fo, "%d\n", model_type);
322         fprintf(fo, "\n");
323
324         ori->write(fo);  fprintf(fo, "\n");
325         gld->write(fo);  fprintf(fo, "\n");
326         mld->write(fo);  fprintf(fo, "\n");
327         rspd->write(fo); fprintf(fo, "\n");
328         pro->write(fo);  fprintf(fo, "\n");
329         npro->write(fo);
330
331         if (mw != NULL) {
332           fprintf(fo, "\n%d\n", M);
333           for (int i = 0; i < M; i++) {
334             fprintf(fo, "%.15g ", mw[i]);
335           }
336           fprintf(fo, "%.15g\n", mw[M]);
337         }
338
339         fclose(fo);
340 }
341
342 void PairedEndModel::startSimulation(simul* sampler, double* theta) {
343         this->sampler = sampler;
344
345         theta_cdf = new double[M + 1];
346         for (int i = 0; i <= M; i++) {
347                 theta_cdf[i] = theta[i];
348                 if (i > 0) theta_cdf[i] += theta_cdf[i - 1];
349         }
350
351         rspd->startSimulation(M, refs);
352         pro->startSimulation();
353         npro->startSimulation();
354 }
355
356 bool PairedEndModel::simulate(int rid, PairedEndRead& read, int& sid) {
357         int dir, pos;
358         int insertL, mateL1, mateL2;
359         std::string name;
360         std::string readseq1, readseq2;
361         std::ostringstream strout;
362
363         sid = sampler->sample(theta_cdf, M + 1);
364
365         if (sid == 0) {
366                 dir = pos = insertL = 0;
367                 mateL1 = mld->simulate(sampler, -1);
368                 readseq1 = npro->simulate(sampler, mateL1);
369
370                 mateL2 = mld->simulate(sampler, -1);
371                 readseq2 = npro->simulate(sampler, mateL2);
372         }
373         else {
374                 RefSeq &ref = refs->getRef(sid);
375                 dir = ori->simulate(sampler);
376                 insertL = gld->simulate(sampler, ref.getTotLen());
377                 if (insertL < 0) return false;
378                 int effL = std::min(ref.getFullLen(), ref.getTotLen() - insertL + 1);
379                 pos = rspd->simulate(sampler, sid, effL);
380                 if (pos < 0) return false;
381                 if (dir > 0) pos = ref.getTotLen() - pos - insertL;
382
383                 mateL1 = mld->simulate(sampler, insertL);
384                 readseq1 = pro->simulate(sampler, mateL1, pos, dir, ref);
385
386                 int m2pos = ref.getTotLen() - pos - insertL;
387                 int m2dir = !dir;
388
389                 mateL2 = mld->simulate(sampler, insertL);
390                 readseq2 = pro->simulate(sampler, mateL2, m2pos, m2dir, ref);
391         }
392
393         strout<<rid<<"_"<<dir<<"_"<<sid<<"_"<<pos<<"_"<<insertL;
394         name = strout.str();
395
396         read = PairedEndRead(SingleRead(name + "/1", readseq1), SingleRead(name + "/2", readseq2));
397
398         return true;
399 }
400
401 void PairedEndModel::finishSimulation() {
402         delete[] theta_cdf;
403
404         rspd->finishSimulation();
405         pro->finishSimulation();
406         npro->finishSimulation();
407 }
408
409 void PairedEndModel::calcMW() {
410         assert(mld->getMinL() >= seedLen);
411
412         memset(mw, 0, sizeof(double) * (M + 1));
413         mw[0] = 1.0;
414
415         for (int i = 1; i <= M; i++) {
416                 RefSeq& ref = refs->getRef(i);
417                 int totLen = ref.getTotLen();
418                 int fullLen = ref.getFullLen();
419                 int end = std::min(fullLen, totLen - gld->getMinL() + 1);
420                 double value = 0.0;
421                 int minL, maxL;
422                 int effL, pfpos;
423
424                 //seedPos is fpos here
425                 for (int seedPos = 0; seedPos < end; seedPos++)
426                         if (ref.getMask(seedPos)) {
427                                 minL = gld->getMinL();
428                                 maxL = std::min(gld->getMaxL(), totLen - seedPos);
429                                 pfpos = seedPos;
430                                 for (int fragLen = minL; fragLen <= maxL; fragLen++) {
431                                         effL = std::min(fullLen, totLen - fragLen + 1);
432                                         value += gld->getAdjustedProb(fragLen, totLen) * rspd->getAdjustedProb(pfpos, effL, fullLen);
433                                 }
434                         }
435
436                 mw[i] = 1.0 - value;
437
438                 if (mw[i] < 1e-8) {
439                         //fprintf(stderr, "Warning: %dth reference sequence is masked for almost all positions!\n", i);
440                         mw[i] = 0.0;
441                 }
442         }
443 }
444
445 #endif /* PAIREDENDMODEL_H_ */