]> git.donarmstrong.com Git - rsem.git/blob - BamWriter.h
rsem v1.1.11, allow spaces appear in field seqname and attributes gene_id, transcript...
[rsem.git] / BamWriter.h
1 #ifndef BAMWRITER_H_
2 #define BAMWRITER_H_
3
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cassert>
#include<string>
#include<vector>
#include<map>
#include<sstream>

#include "sam/bam.h"
#include "sam/sam.h"

#include "utils.h"
#include "SingleHit.h"
#include "PairedEndHit.h"

#include "HitWrapper.h"
#include "Transcript.h"
#include "Transcripts.h"
22
23 class BamWriter {
24 public:
25         BamWriter(char, const char*, const char*, const char*, const char*);
26         ~BamWriter();
27
28         void work(HitWrapper<SingleHit>, Transcripts&);
29         void work(HitWrapper<PairedEndHit>, Transcripts&);
30 private:
31         samfile_t *in, *out;
32
33         std::map<std::string, int> refmap;
34         std::map<std::string, int>::iterator iter;
35
36         struct SingleEndT {
37                 bam1_t *b;
38
39                 SingleEndT(bam1_t *b = NULL) {
40                         this->b = b;
41                 }
42
43                 bool operator< (const SingleEndT& o) const {
44                         int strand1, strand2;
45                         uint32_t *p1, *p2;
46
47                         if (b->core.tid != o.b->core.tid) return b->core.tid < o.b->core.tid;
48                         if (b->core.pos != o.b->core.pos) return b->core.pos < o.b->core.pos;
49                         strand1 = b->core.flag & 0x0010; strand2 = o.b->core.flag & 0x0010;
50                         if (strand1 != strand2) return strand1 < strand2;
51                         if (b->core.n_cigar != o.b->core.n_cigar) return b->core.n_cigar < o.b->core.n_cigar;
52                         p1 = bam1_cigar(b); p2 = bam1_cigar(o.b);
53                         for (int i = 0; i < (int)b->core.n_cigar; i++) {
54                                 if (*p1 != *p2) return *p1 < *p2;
55                                 ++p1; ++p2;
56                         }
57                         return false;
58                 }
59         };
60
61         //b is mate 1, b2 is mate 2
62         struct PairedEndT {
63                 bam1_t *b, *b2;
64
65                 PairedEndT() { b = NULL; b2 = NULL;}
66
67                 PairedEndT(bam1_t *b, bam1_t *b2) {
68                         this->b = b;
69                         this->b2 = b2;
70                 }
71
72                 bool operator< (const PairedEndT& o) const {
73                         int strand1, strand2;
74                         uint32_t *p1, *p2;
75
76                         //compare b
77                         if (b->core.tid != o.b->core.tid) return b->core.tid < o.b->core.tid;
78                         if (b->core.pos != o.b->core.pos) return b->core.pos < o.b->core.pos;
79                         strand1 = b->core.flag & 0x0010; strand2 = o.b->core.flag & 0x0010;
80                         if (strand1 != strand2) return strand1 < strand2;
81                         if (b->core.n_cigar != o.b->core.n_cigar) return b->core.n_cigar < o.b->core.n_cigar;
82                         p1 = bam1_cigar(b); p2 = bam1_cigar(o.b);
83                         for (int i = 0; i < (int)b->core.n_cigar; i++) {
84                                 if (*p1 != *p2) return *p1 < *p2;
85                                 ++p1; ++p2;
86                         }
87
88                         //compare b2
89                         if (b2->core.tid != o.b2->core.tid) return b2->core.tid < o.b2->core.tid;
90                         if (b2->core.pos != o.b2->core.pos) return b2->core.pos < o.b2->core.pos;
91                         strand1 = b2->core.flag & 0x0010; strand2 = o.b2->core.flag & 0x0010;
92                         if (strand1 != strand2) return strand1 < strand2;
93                         if (b2->core.n_cigar != o.b2->core.n_cigar) return b2->core.n_cigar < o.b2->core.n_cigar;
94                         p1 = bam1_cigar(b2); p2 = bam1_cigar(o.b2);
95                         for (int i = 0; i < (int)b2->core.n_cigar; i++) {
96                                 if (*p1 != *p2) return *p1 < *p2;
97                                 ++p1; ++p2;
98                         }
99
100                         return false;
101                 }
102         };
103
104         uint8_t getMAPQ(double val) {
105                 double err = 1.0 - val;
106                 if (err <= 1e-10) return 100;
107                 return (uint8_t)(-10 * log10(err) + .5); // round it
108         }
109
110         void push_qname(const uint8_t* qname, int l_qname, std::vector<uint8_t>& data) {
111                 for (int i = 0; i < l_qname; i++) data.push_back(*(qname + i));
112         }
113
114         void push_seq(const uint8_t* seq, int readlen, char strand, std::vector<uint8_t>& data) {
115                 int seq_len = (readlen + 1) / 2;
116
117                 switch (strand) {
118                 case '+': for (int i = 0; i < seq_len; i++) data.push_back(*(seq + i)); break;
119                 case '-':
120                         uint8_t code, base;
121                         code = 0; base = 0;
122                         for (int i = 0; i < readlen; i++) {
123                                 switch (bam1_seqi(seq, readlen - i - 1)) {
124                                 case 1: base = 8; break;
125                                 case 2: base = 4; break;
126                                 case 4: base = 2; break;
127                                 case 8: base = 1; break;
128                                 case 15: base = 15; break;
129                                 default: assert(false);
130                                 }
131                                 code |=  base << (4 * (1 - i % 2));
132                                 if (i % 2 == 1) { data.push_back(code); code = 0; }
133                         }
134
135                         if (readlen % 2 == 1) { data.push_back(code); }
136                         break;
137                 default: assert(false);
138                 }
139         }
140
141         void push_qual(const uint8_t* qual, int readlen, char strand, std::vector<uint8_t>& data) {
142                 switch (strand) {
143                 case '+': for (int i = 0; i < readlen; i++) data.push_back(*(qual + i)); break;
144                 case '-': for (int i = readlen - 1; i >= 0; i--) data.push_back(*(qual + i)); break;
145                 default: assert(false);
146                 }
147         }
148
149         //convert transcript coordinate to chromosome coordinate and generate CIGAR string
150         void tr2chr(const Transcript&, int, int, int&, int&, std::vector<uint8_t>&);
151 };
152
153 //fn_list can be NULL
154 BamWriter::BamWriter(char inpType, const char* inpF, const char* fn_list, const char* outF, const char* chr_list) {
155         switch(inpType) {
156         case 's': in = samopen(inpF, "r", fn_list); break;
157         case 'b': in = samopen(inpF, "rb", fn_list); break;
158         default: assert(false);
159         }
160         assert(in != 0);
161
162         //generate output's header
163         bam_header_t *out_header = NULL;
164         refmap.clear();
165
166         if (chr_list == NULL) {
167                 out_header = in->header;
168         }
169         else {
170                 out_header = sam_header_read2(chr_list);
171
172                 for (int i = 0; i < out_header->n_targets; i++) {
173                         refmap[out_header->target_name[i]] = i;
174                 }
175         }
176
177
178         out = samopen(outF, "wb", out_header);
179         assert(out != 0);
180
181         if (chr_list != NULL) { bam_header_destroy(out_header); }
182 }
183
184 BamWriter::~BamWriter() {
185         samclose(in);
186         samclose(out);
187 }
188
189 void BamWriter::work(HitWrapper<SingleHit> wrapper, Transcripts& transcripts) {
190         bam1_t *b;
191         std::string cqname; // cqname : current query name
192         std::map<SingleEndT, double> hmap;
193         std::map<SingleEndT, double>::iterator hmapIter;
194         SingleHit *hit;
195
196         int cnt = 0;
197
198         cqname = "";
199         b = bam_init1();
200         hmap.clear();
201
202         while (samread(in, b) >= 0) {
203
204                 if (verbose && cnt > 0 && cnt % 1000000 == 0) { printf("%d entries are finished!\n", cnt); }
205                 ++cnt;
206
207                 if (b->core.flag & 0x0004) continue;
208
209                 hit = wrapper.getNextHit();
210                 assert(hit != NULL);
211
212                 int sid = b->core.tid + 1;
213                 assert(sid == hit->getSid());
214                 const Transcript& transcript = transcripts.getTranscriptAt(sid);
215
216                 if (transcripts.getType() == 0) {
217                         int pos = b->core.pos;
218                         int readlen = b->core.l_qseq;
219                         uint8_t *qname = b->data, *seq = bam1_seq(b), *qual = bam1_qual(b);
220                         std::vector<uint8_t> data;
221                         data.clear();
222
223                         iter = refmap.find(transcript.getSeqName());
224                         assert(iter != refmap.end());
225                         b->core.tid = iter->second;
226                         b->core.qual = 255;
227
228                         uint16_t rstrand = b->core.flag & 0x0010; // read strand
229                         b->core.flag -= rstrand;
230                         rstrand = (((!rstrand && transcript.getStrand() == '+') || (rstrand && transcript.getStrand() == '-')) ? 0 : 0x0010);
231                         b->core.flag += rstrand;
232
233                         push_qname(qname, b->core.l_qname, data);
234                         int core_pos, core_n_cigar;
235                         tr2chr(transcript, pos + 1, pos + readlen, core_pos, core_n_cigar, data);
236                         if (core_pos < 0) b->core.tid = -1;
237                         b->core.pos = core_pos;
238                         b->core.n_cigar = core_n_cigar;
239                         push_seq(seq, readlen, transcript.getStrand(), data);
240                         push_qual(qual, readlen, transcript.getStrand(), data);
241
242                         free(b->data);
243                         b->m_data = b->data_len = data.size() + 7; // 7 extra bytes for ZW tag
244                         b->l_aux = 7;
245                         b->data = (uint8_t*)malloc(b->m_data);
246                         for (int i = 0; i < b->data_len; i++) b->data[i] = data[i];
247
248                         b->core.bin = bam_reg2bin(b->core.pos, bam_calend(&(b->core), bam1_cigar(b)));
249                 }
250                 else {
251                         b->m_data = b->data_len = b->data_len - b->l_aux + 7; // 7 extra bytes for ZW tag
252                         b->l_aux = 7;
253                         b->data = (uint8_t*)realloc(b->data, b->m_data);
254                 }
255
256
257                 if (cqname != bam1_qname(b)) {
258                         if (!hmap.empty()) {
259                                 for (hmapIter = hmap.begin(); hmapIter != hmap.end(); hmapIter++) {
260                                         bam1_t *tmp_b = hmapIter->first.b;
261                                         tmp_b->core.qual = getMAPQ(hmapIter->second);
262                                         uint8_t *p = bam1_aux(tmp_b);
263                                         *p = 'Z'; ++p; *p = 'W'; ++p; *p = 'f'; ++p;
264                                         float val = (float)hmapIter->second;
265                                         memcpy(p, &val, 4);
266                                         samwrite(out, tmp_b);
267                                         bam_destroy1(tmp_b); // now hmapIter->b makes no sense
268                                 }
269                                 hmap.clear();
270                         }
271                         cqname = bam1_qname(b);
272                 }
273
274                 hmapIter = hmap.find(SingleEndT(b));
275                 if (hmapIter == hmap.end()) {
276                         hmap[SingleEndT(bam_dup1(b))] = hit->getConPrb();
277                 }
278                 else {
279                         hmapIter->second += hit->getConPrb();
280                 }
281         }
282
283         assert(wrapper.getNextHit() == NULL);
284
285         if (!hmap.empty()) {
286                 for (hmapIter = hmap.begin(); hmapIter != hmap.end(); hmapIter++) {
287                         bam1_t *tmp_b = hmapIter->first.b;
288                         tmp_b->core.qual = getMAPQ(hmapIter->second);
289                         uint8_t *p = bam1_aux(tmp_b);
290                         *p = 'Z'; ++p; *p = 'W'; ++p; *p = 'f'; ++p;
291                         float val = (float)hmapIter->second;
292                         memcpy(p, &val, 4);
293                         samwrite(out, tmp_b);
294                         bam_destroy1(tmp_b); // now hmapIter->b makes no sense
295                 }
296                 hmap.clear();
297         }
298
299         bam_destroy1(b);
300         if (verbose) { printf("Bam output file is generated!\n"); }
301 }
302
303 void BamWriter::work(HitWrapper<PairedEndHit> wrapper, Transcripts& transcripts) {
304         bam1_t *b, *b2;
305         std::string cqname; // cqname : current query name
306         std::map<PairedEndT, double> hmap;
307         std::map<PairedEndT, double>::iterator hmapIter;
308         PairedEndHit *hit;
309
310         int cnt = 0;
311
312         cqname = "";
313         b = bam_init1();
314         b2 = bam_init1();
315         hmap.clear();
316
317         while (samread(in, b) >= 0 && samread(in, b2) >= 0) {
318
319                 if (verbose && cnt > 0 && cnt % 1000000 == 0) { printf("%d entries are finished!\n", cnt); }
320                 ++cnt;
321
322                 if (!((b->core.flag & 0x0002) && (b2->core.flag & 0x0002))) continue;
323
324                 //swap if b is mate 2
325                 if (b->core.flag & 0x0080) {
326                         assert(b2->core.flag & 0x0040);
327                         bam1_t *tmp = b;
328                         b = b2; b2 = tmp;
329                 }
330
331                 hit = wrapper.getNextHit();
332                 assert(hit != NULL);
333
334                 int sid = b->core.tid + 1;
335                 assert(sid == hit->getSid());
336                 assert(sid == b2->core.tid + 1);
337                 const Transcript& transcript = transcripts.getTranscriptAt(sid);
338
339                 if (transcripts.getType() == 0) {
340                         int pos = b->core.pos, pos2 = b2->core.pos;
341                         int readlen = b->core.l_qseq, readlen2 = b2->core.l_qseq;
342                         uint8_t *qname = b->data, *seq = bam1_seq(b), *qual = bam1_qual(b);
343                         uint8_t *qname2 = b2->data, *seq2 = bam1_seq(b2), *qual2 = bam1_qual(b2);
344                         std::vector<uint8_t> data, data2;
345
346                         data.clear();
347                         data2.clear();
348
349                         iter = refmap.find(transcript.getSeqName());
350                         assert(iter != refmap.end());
351                         b->core.tid = iter->second; b->core.mtid = iter->second;
352                         b2->core.tid = iter->second; b2->core.mtid = iter->second;
353
354                         uint16_t rstrand = b->core.flag & 0x0010;
355                         b->core.flag = b->core.flag - (b->core.flag & 0x0010) - (b->core.flag & 0x0020);
356                         b2->core.flag = b2->core.flag - (b2->core.flag & 0x0010) - (b2->core.flag & 0x0020);
357
358                         uint16_t add, add2;
359                         if ((!rstrand && transcript.getStrand() == '+') || (rstrand && transcript.getStrand() == '-')) {
360                                 add = 0x0020; add2 = 0x0010;
361                         }
362                         else {
363                                 add = 0x0010; add2 = 0x0020;
364                         }
365                         b->core.flag += add;
366                         b2->core.flag += add2;
367
368                         b->core.qual = b2->core.qual = 255;
369
370                         //Do I really need this? The insert size uses transcript coordinates
371                         if (transcript.getStrand() == '-') {
372                                 b->core.isize = -b->core.isize;
373                                 b2->core.isize = -b2->core.isize;
374                         }
375
376                         push_qname(qname, b->core.l_qname, data);
377                         push_qname(qname2, b2->core.l_qname, data2);
378                         int core_pos, core_n_cigar;
379                         tr2chr(transcript, pos + 1, pos + readlen, core_pos, core_n_cigar, data);
380                         if (core_pos < 0) b->core.tid = -1;
381                         b->core.pos = core_pos; b->core.n_cigar = core_n_cigar;
382                         tr2chr(transcript, pos2 + 1, pos2 + readlen2, core_pos, core_n_cigar, data2);
383                         if (core_pos < 0) b2->core.tid = -1;
384                         b2->core.pos = core_pos; b2->core.n_cigar = core_n_cigar;
385                         b->core.mpos = b2->core.pos;
386                         b2->core.mpos = b->core.pos;
387                         push_seq(seq, readlen, transcript.getStrand(), data);
388                         push_seq(seq2, readlen2, transcript.getStrand(), data2);
389                         push_qual(qual, readlen, transcript.getStrand(), data);
390                         push_qual(qual2, readlen2, transcript.getStrand(), data2);
391
392                         free(b->data);
393                         b->m_data = b->data_len = data.size() + 7; // 7 extra bytes for ZW tag
394                         b->l_aux = 7;
395                         b->data = (uint8_t*)malloc(b->m_data);
396                         for (int i = 0; i < b->data_len; i++) b->data[i] = data[i];
397
398                         free(b2->data);
399                         b2->m_data = b2->data_len = data2.size() + 7; // 7 extra bytes for ZW tag
400                         b2->l_aux = 7;
401                         b2->data = (uint8_t*)malloc(b2->m_data);
402                         for (int i = 0; i < b2->data_len; i++) b2->data[i] = data2[i];
403
404                         b->core.bin = bam_reg2bin(b->core.pos, bam_calend(&(b->core), bam1_cigar(b)));
405                         b2->core.bin = bam_reg2bin(b2->core.pos, bam_calend(&(b2->core), bam1_cigar(b2)));
406                 }
407                 else {
408                         b->m_data = b->data_len = b->data_len - b->l_aux + 7; // 7 extra bytes for ZW tag
409                         b->l_aux = 7;
410                         b->data = (uint8_t*)realloc(b->data, b->m_data);
411
412                         b2->m_data = b2->data_len = b2->data_len - b2->l_aux + 7; // 7 extra bytes for ZW tag
413                         b2->l_aux = 7;
414                         b2->data = (uint8_t*)realloc(b2->data, b2->m_data);
415                 }
416
417                 if (cqname != bam1_qname(b)) {
418                         if (!hmap.empty()) {
419                                 for (hmapIter = hmap.begin(); hmapIter != hmap.end(); hmapIter++) {
420                                         bam1_t *tmp_b = hmapIter->first.b;
421                                         bam1_t *tmp_b2 = hmapIter->first.b2;
422
423                                         tmp_b->core.qual = tmp_b2->core.qual = getMAPQ(hmapIter->second);
424
425                                         uint8_t *p = bam1_aux(tmp_b), *p2 = bam1_aux(tmp_b2);
426                                         *p = 'Z'; ++p; *p = 'W'; ++p; *p = 'f'; ++p;
427                                         *p2 = 'Z'; ++p2; *p2 = 'W'; ++p2; *p2 = 'f'; ++p2;
428
429                                         float val = (float)hmapIter->second;
430                                         memcpy(p, &val, 4);
431                                         memcpy(p2, &val, 4);
432
433                                         samwrite(out, tmp_b);
434                                         samwrite(out, tmp_b2);
435
436                                         bam_destroy1(tmp_b);
437                                         bam_destroy1(tmp_b2);
438                                 }
439                                 hmap.clear();
440                         }
441                         cqname = bam1_qname(b);
442                 }
443
444                 hmapIter = hmap.find(PairedEndT(b, b2));
445                 if (hmapIter == hmap.end()) {
446                         hmap[PairedEndT(bam_dup1(b), bam_dup1(b2))] = hit->getConPrb();
447                 }
448                 else {
449                         hmapIter->second += hit->getConPrb();
450                 }
451         }
452
453         assert(wrapper.getNextHit() == NULL);
454
455         if (!hmap.empty()) {
456                 for (hmapIter = hmap.begin(); hmapIter != hmap.end(); hmapIter++) {
457                         bam1_t *tmp_b = hmapIter->first.b;
458                         bam1_t *tmp_b2 = hmapIter->first.b2;
459
460                         tmp_b->core.qual = tmp_b2->core.qual = getMAPQ(hmapIter->second);
461
462                         uint8_t *p = bam1_aux(tmp_b), *p2 = bam1_aux(tmp_b2);
463                         *p = 'Z'; ++p; *p = 'W'; ++p; *p = 'f'; ++p;
464                         *p2 = 'Z'; ++p2; *p2 = 'W'; ++p2; *p2 = 'f'; ++p2;
465
466                         float val = (float)hmapIter->second;
467                         memcpy(p, &val, 4);
468                         memcpy(p2, &val, 4);
469
470                         samwrite(out, tmp_b);
471                         samwrite(out, tmp_b2);
472
473                         bam_destroy1(tmp_b);
474                         bam_destroy1(tmp_b2);
475                 }
476                 hmap.clear();
477         }
478
479         bam_destroy1(b);
480         bam_destroy1(b2);
481
482         if (verbose) { printf("Bam output file is generated!\n"); }
483 }
484
485 void BamWriter::tr2chr(const Transcript& transcript, int sp, int ep, int& pos, int& n_cigar, std::vector<uint8_t>& data) {
486         int length = transcript.getLength();
487         char strand = transcript.getStrand();
488         const std::vector<Interval>& structure = transcript.getStructure();
489
490         int s, i;
491         int oldlen, curlen;
492
493         uint32_t operation;
494         uint8_t *p;
495
496         n_cigar = 0;
497         s = structure.size();
498
499         if (strand == '-') {
500                 int tmp = sp;
501                 sp = length - ep + 1;
502                 ep = length - tmp + 1;
503         }
504
505         if (ep < 1 || sp > length) { // a read which align to polyA tails totally! 
506           pos = (sp > length ? structure[s - 1].end : structure[0].start - 1); // 0 based
507
508           n_cigar = 1;
509           operation = (ep - sp + 1) << BAM_CIGAR_SHIFT | BAM_CINS; //BAM_CSOFT_CLIP;
510           p = (uint8_t*)(&operation);
511           for (int j = 0; j < 4; j++) data.push_back(*(p + j));
512
513           return;
514         }
515
516         if (sp < 1) {
517                 n_cigar++;
518                 operation = (1 - sp) << BAM_CIGAR_SHIFT | BAM_CINS; //BAM_CSOFT_CLIP;
519                 p = (uint8_t*)(&operation);
520                 for (int j = 0; j < 4; j++) data.push_back(*(p + j));
521                 sp = 1;
522         }
523
524         oldlen = curlen = 0;
525
526         for (i = 0; i < s; i++) {
527                 oldlen = curlen;
528                 curlen += structure[i].end - structure[i].start + 1;
529                 if (curlen >= sp) break;
530         }
531         assert(i < s);
532         pos = structure[i].start + (sp - oldlen - 1) - 1; // 0 based
533
534         while (curlen < ep && i < s) {
535                 n_cigar++;
536                 operation = (curlen - sp + 1) << BAM_CIGAR_SHIFT | BAM_CMATCH;
537                 p = (uint8_t*)(&operation);
538                 for (int j = 0; j < 4; j++) data.push_back(*(p + j));
539
540                 ++i;
541                 if (i >= s) continue;
542                 n_cigar++;
543                 operation = (structure[i].start - structure[i - 1].end - 1) << BAM_CIGAR_SHIFT | BAM_CREF_SKIP;
544                 p = (uint8_t*)(&operation);
545                 for (int j = 0; j < 4; j++) data.push_back(*(p + j));
546
547                 oldlen = curlen;
548                 sp = oldlen + 1;
549                 curlen += structure[i].end - structure[i].start + 1;
550         }
551
552         if (i >= s) {
553                 n_cigar++;
554                 operation = (ep - length) << BAM_CIGAR_SHIFT | BAM_CINS; //BAM_CSOFT_CLIP;
555                 p = (uint8_t*)(&operation);
556                 for (int j = 0; j < 4; j++) data.push_back(*(p + j));
557         }
558         else {
559                 n_cigar++;
560                 operation = (ep - sp + 1) << BAM_CIGAR_SHIFT | BAM_CMATCH;
561                 p = (uint8_t*)(&operation);
562                 for (int j = 0; j < 4; j++) data.push_back(*(p + j));
563         }
564 }
565
566 #endif /* BAMWRITER_H_ */