]> git.donarmstrong.com Git - samtools.git/blob - kprobaln.c
works
[samtools.git] / kprobaln.c
1 /* The MIT License
2
3    Copyright (c) 2003-2006, 2008-2010, by Heng Li <lh3lh3@live.co.uk>
4
5    Permission is hereby granted, free of charge, to any person obtaining
6    a copy of this software and associated documentation files (the
7    "Software"), to deal in the Software without restriction, including
8    without limitation the rights to use, copy, modify, merge, publish,
9    distribute, sublicense, and/or sell copies of the Software, and to
10    permit persons to whom the Software is furnished to do so, subject to
11    the following conditions:
12
13    The above copyright notice and this permission notice shall be
14    included in all copies or substantial portions of the Software.
15
16    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
20    BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
21    ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
22    CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23    SOFTWARE.
24 */
25
26 #include <stdlib.h>
27 #include <stdio.h>
28 #include <string.h>
29 #include <stdint.h>
30 #include <math.h>
31 #include "kprobaln.h"
32
33 /*****************************************
34  * Probabilistic banded glocal alignment *
35  *****************************************/
36
37 #define EI .25
38 #define EM .33333333333
39
40 static float g_qual2prob[256];
41
42 #define set_u(u, b, i, k) { int x=(i)-(b); x=x>0?x:0; (u)=((k)-x+1)*3; }
43
44 kpa_par_t kpa_par_def = { 0.001, 0.1, 10 };
45 kpa_par_t kpa_par_alt = { 0.0001, 0.01, 10 };
46
47 /*
48   The topology of the profile HMM:
49
50            /\             /\        /\             /\
51            I[1]           I[k-1]    I[k]           I[L]
52             ^   \      \    ^    \   ^   \      \   ^
53             |    \      \   |     \  |    \      \  |
54     M[0]   M[1] -> ... -> M[k-1] -> M[k] -> ... -> M[L]   M[L+1]
55                 \      \/        \/      \/      /
56                  \     /\        /\      /\     /
57                        -> D[k-1] -> D[k] ->
58
59    M[0] points to every {M,I}[k] and every {M,I}[k] points M[L+1].
60
61    On input, _ref is the reference sequence and _query is the query
62    sequence. Both are sequences of 0/1/2/3/4 where 4 stands for an
63    ambiguous residue. iqual is the base quality. c sets the gap open
64    probability, gap extension probability and band width.
65
66    On output, state and q are arrays of length l_query. The higher 30
67    bits give the reference position the query base is matched to and the
68    lower two bits can be 0 (an alignment match) or 1 (an
69    insertion). q[i] gives the phred scaled posterior probability of
70    state[i] being wrong.
71  */
72 int kpa_glocal(const uint8_t *_ref, int l_ref, const uint8_t *_query, int l_query, const uint8_t *iqual,
73                            const kpa_par_t *c, int *state, uint8_t *q)
74 {
75         double **f, **b = 0, *s, m[9], sI, sM, bI, bM, pb;
76         float *qual, *_qual;
77         const uint8_t *ref, *query;
78         int bw, bw2, i, k, is_diff = 0, is_backward = 1, Pr;
79
80     if ( l_ref<=0 || l_query<=0 ) return 0; // FIXME: this may not be an ideal fix, just prevents sefgault
81
82         /*** initialization ***/
83         is_backward = state && q? 1 : 0;
84         ref = _ref - 1; query = _query - 1; // change to 1-based coordinate
85         bw = l_ref > l_query? l_ref : l_query;
86         if (bw > c->bw) bw = c->bw;
87         if (bw < abs(l_ref - l_query)) bw = abs(l_ref - l_query);
88         bw2 = bw * 2 + 1;
89         // allocate the forward and backward matrices f[][] and b[][] and the scaling array s[]
90         f = calloc(l_query+1, sizeof(void*));
91         if (is_backward) b = calloc(l_query+1, sizeof(void*));
92         for (i = 0; i <= l_query; ++i) {    // FIXME: this will lead in segfault for l_query==0
93                 f[i] = calloc(bw2 * 3 + 6, sizeof(double)); // FIXME: this is over-allocated for very short seqs
94                 if (is_backward) b[i] = calloc(bw2 * 3 + 6, sizeof(double));
95         }
96         s = calloc(l_query+2, sizeof(double)); // s[] is the scaling factor to avoid underflow
97         // initialize qual
98         _qual = calloc(l_query, sizeof(float));
99         if (g_qual2prob[0] == 0)
100                 for (i = 0; i < 256; ++i)
101                         g_qual2prob[i] = pow(10, -i/10.);
102         for (i = 0; i < l_query; ++i) _qual[i] = g_qual2prob[iqual? iqual[i] : 30];
103         qual = _qual - 1;
104         // initialize transition probability
105         sM = sI = 1. / (2 * l_query + 2); // the value here seems not to affect results; FIXME: need proof
106         m[0*3+0] = (1 - c->d - c->d) * (1 - sM); m[0*3+1] = m[0*3+2] = c->d * (1 - sM);
107         m[1*3+0] = (1 - c->e) * (1 - sI); m[1*3+1] = c->e * (1 - sI); m[1*3+2] = 0.;
108         m[2*3+0] = 1 - c->e; m[2*3+1] = 0.; m[2*3+2] = c->e;
109         bM = (1 - c->d) / l_ref; bI = c->d / l_ref; // (bM+bI)*l_ref==1
110         /*** forward ***/
111         // f[0]
112         set_u(k, bw, 0, 0);
113         f[0][k] = s[0] = 1.;
114         { // f[1]
115                 double *fi = f[1], sum;
116                 int beg = 1, end = l_ref < bw + 1? l_ref : bw + 1, _beg, _end;
117                 for (k = beg, sum = 0.; k <= end; ++k) {
118                         int u;
119                         double e = (ref[k] > 3 || query[1] > 3)? 1. : ref[k] == query[1]? 1. - qual[1] : qual[1] * EM;
120                         set_u(u, bw, 1, k);
121                         fi[u+0] = e * bM; fi[u+1] = EI * bI;
122                         sum += fi[u] + fi[u+1];
123                 }
124                 // rescale
125                 s[1] = sum;
126                 set_u(_beg, bw, 1, beg); set_u(_end, bw, 1, end); _end += 2;
127                 for (k = _beg; k <= _end; ++k) fi[k] /= sum;
128         }
129         // f[2..l_query]
130         for (i = 2; i <= l_query; ++i) {
131                 double *fi = f[i], *fi1 = f[i-1], sum, qli = qual[i];
132                 int beg = 1, end = l_ref, x, _beg, _end;
133                 uint8_t qyi = query[i];
134                 x = i - bw; beg = beg > x? beg : x; // band start
135                 x = i + bw; end = end < x? end : x; // band end
136                 for (k = beg, sum = 0.; k <= end; ++k) {
137                         int u, v11, v01, v10;
138                         double e;
139                         e = (ref[k] > 3 || qyi > 3)? 1. : ref[k] == qyi? 1. - qli : qli * EM;
140                         set_u(u, bw, i, k); set_u(v11, bw, i-1, k-1); set_u(v10, bw, i-1, k); set_u(v01, bw, i, k-1);
141                         fi[u+0] = e * (m[0] * fi1[v11+0] + m[3] * fi1[v11+1] + m[6] * fi1[v11+2]);
142                         fi[u+1] = EI * (m[1] * fi1[v10+0] + m[4] * fi1[v10+1]);
143                         fi[u+2] = m[2] * fi[v01+0] + m[8] * fi[v01+2];
144                         sum += fi[u] + fi[u+1] + fi[u+2];
145 //                      fprintf(stderr, "F (%d,%d;%d): %lg,%lg,%lg\n", i, k, u, fi[u], fi[u+1], fi[u+2]); // DEBUG
146                 }
147                 // rescale
148                 s[i] = sum;
149                 set_u(_beg, bw, i, beg); set_u(_end, bw, i, end); _end += 2;
150                 for (k = _beg, sum = 1./sum; k <= _end; ++k) fi[k] *= sum;
151         }
152         { // f[l_query+1]
153                 double sum;
154                 for (k = 1, sum = 0.; k <= l_ref; ++k) {
155                         int u;
156                         set_u(u, bw, l_query, k);
157                         if (u < 3 || u >= bw2*3+3) continue;
158                     sum += f[l_query][u+0] * sM + f[l_query][u+1] * sI;
159                 }
160                 s[l_query+1] = sum; // the last scaling factor
161         }
162         { // compute likelihood
163                 double p = 1., Pr1 = 0.;
164                 for (i = 0; i <= l_query + 1; ++i) {
165                         p *= s[i];
166                         if (p < 1e-100) Pr1 += -4.343 * log(p), p = 1.;
167                 }
168                 Pr1 += -4.343 * log(p * l_ref * l_query);
169                 Pr = (int)(Pr1 + .499);
170                 if (!is_backward) { // skip backward and MAP
171                         for (i = 0; i <= l_query; ++i) free(f[i]);
172                         free(f); free(s); free(_qual);
173                         return Pr;
174                 }
175         }
176         /*** backward ***/
177         // b[l_query] (b[l_query+1][0]=1 and thus \tilde{b}[][]=1/s[l_query+1]; this is where s[l_query+1] comes from)
178         for (k = 1; k <= l_ref; ++k) {
179                 int u;
180                 double *bi = b[l_query];
181                 set_u(u, bw, l_query, k);
182                 if (u < 3 || u >= bw2*3+3) continue;
183                 bi[u+0] = sM / s[l_query] / s[l_query+1]; bi[u+1] = sI / s[l_query] / s[l_query+1];
184         }
185         // b[l_query-1..1]
186         for (i = l_query - 1; i >= 1; --i) {
187                 int beg = 1, end = l_ref, x, _beg, _end;
188                 double *bi = b[i], *bi1 = b[i+1], y = (i > 1), qli1 = qual[i+1];
189                 uint8_t qyi1 = query[i+1];
190                 x = i - bw; beg = beg > x? beg : x;
191                 x = i + bw; end = end < x? end : x;
192                 for (k = end; k >= beg; --k) {
193                         int u, v11, v01, v10;
194                         double e;
195                         set_u(u, bw, i, k); set_u(v11, bw, i+1, k+1); set_u(v10, bw, i+1, k); set_u(v01, bw, i, k+1);
196                         e = (k >= l_ref? 0 : (ref[k+1] > 3 || qyi1 > 3)? 1. : ref[k+1] == qyi1? 1. - qli1 : qli1 * EM) * bi1[v11];
197                         bi[u+0] = e * m[0] + EI * m[1] * bi1[v10+1] + m[2] * bi[v01+2]; // bi1[v11] has been foled into e.
198                         bi[u+1] = e * m[3] + EI * m[4] * bi1[v10+1];
199                         bi[u+2] = (e * m[6] + m[8] * bi[v01+2]) * y;
200 //                      fprintf(stderr, "B (%d,%d;%d): %lg,%lg,%lg\n", i, k, u, bi[u], bi[u+1], bi[u+2]); // DEBUG
201                 }
202                 // rescale
203                 set_u(_beg, bw, i, beg); set_u(_end, bw, i, end); _end += 2;
204                 for (k = _beg, y = 1./s[i]; k <= _end; ++k) bi[k] *= y;
205         }
206         { // b[0]
207                 int beg = 1, end = l_ref < bw + 1? l_ref : bw + 1;
208                 double sum = 0.;
209                 for (k = end; k >= beg; --k) {
210                         int u;
211                         double e = (ref[k] > 3 || query[1] > 3)? 1. : ref[k] == query[1]? 1. - qual[1] : qual[1] * EM;
212                         set_u(u, bw, 1, k);
213                         if (u < 3 || u >= bw2*3+3) continue;
214                     sum += e * b[1][u+0] * bM + EI * b[1][u+1] * bI;
215                 }
216                 set_u(k, bw, 0, 0);
217                 pb = b[0][k] = sum / s[0]; // if everything works as is expected, pb == 1.0
218         }
219         is_diff = fabs(pb - 1.) > 1e-7? 1 : 0;
220         /*** MAP ***/
221         for (i = 1; i <= l_query; ++i) {
222                 double sum = 0., *fi = f[i], *bi = b[i], max = 0.;
223                 int beg = 1, end = l_ref, x, max_k = -1;
224                 x = i - bw; beg = beg > x? beg : x;
225                 x = i + bw; end = end < x? end : x;
226                 for (k = beg; k <= end; ++k) {
227                         int u;
228                         double z;
229                         set_u(u, bw, i, k);
230                         z = fi[u+0] * bi[u+0]; if (z > max) max = z, max_k = (k-1)<<2 | 0; sum += z;
231                         z = fi[u+1] * bi[u+1]; if (z > max) max = z, max_k = (k-1)<<2 | 1; sum += z;
232                 }
233                 max /= sum; sum *= s[i]; // if everything works as is expected, sum == 1.0
234                 if (state) state[i-1] = max_k;
235                 if (q) k = (int)(-4.343 * log(1. - max) + .499), q[i-1] = k > 100? 99 : k;
236 #ifdef _MAIN
237                 fprintf(stderr, "(%.10lg,%.10lg) (%d,%d:%c,%c:%d) %lg\n", pb, sum, i-1, max_k>>2,
238                                 "ACGT"[query[i]], "ACGT"[ref[(max_k>>2)+1]], max_k&3, max); // DEBUG
239 #endif
240         }
241         /*** free ***/
242         for (i = 0; i <= l_query; ++i) {
243                 free(f[i]); free(b[i]);
244         }
245         free(f); free(b); free(s); free(_qual);
246         return Pr;
247 }
248
249 #ifdef _MAIN
250 #include <unistd.h>
251 int main(int argc, char *argv[])
252 {
253         uint8_t conv[256], *iqual, *ref, *query;
254         int c, l_ref, l_query, i, q = 30, b = 10, P;
255         while ((c = getopt(argc, argv, "b:q:")) >= 0) {
256                 switch (c) {
257                 case 'b': b = atoi(optarg); break;
258                 case 'q': q = atoi(optarg); break;
259                 }
260         }
261         if (optind + 2 > argc) {
262                 fprintf(stderr, "Usage: %s [-q %d] [-b %d] <ref> <query>\n", argv[0], q, b); // example: acttc attc
263                 return 1;
264         }
265         memset(conv, 4, 256);
266         conv['a'] = conv['A'] = 0; conv['c'] = conv['C'] = 1;
267         conv['g'] = conv['G'] = 2; conv['t'] = conv['T'] = 3;
268         ref = (uint8_t*)argv[optind]; query = (uint8_t*)argv[optind+1];
269         l_ref = strlen((char*)ref); l_query = strlen((char*)query);
270         for (i = 0; i < l_ref; ++i) ref[i] = conv[ref[i]];
271         for (i = 0; i < l_query; ++i) query[i] = conv[query[i]];
272         iqual = malloc(l_query);
273         memset(iqual, q, l_query);
274         kpa_par_def.bw = b;
275         P = kpa_glocal(ref, l_ref, query, l_query, iqual, &kpa_par_alt, 0, 0);
276         fprintf(stderr, "%d\n", P);
277         free(iqual);
278         return 0;
279 }
280 #endif