]> git.donarmstrong.com Git - samtools.git/blob - kprobaln.c
fixed a minor problem in the example coming with kprobaln.c
[samtools.git] / kprobaln.c
1 /* The MIT License
2
3    Copyright (c) 2003-2006, 2008-2010, by Heng Li <lh3lh3@live.co.uk>
4
5    Permission is hereby granted, free of charge, to any person obtaining
6    a copy of this software and associated documentation files (the
7    "Software"), to deal in the Software without restriction, including
8    without limitation the rights to use, copy, modify, merge, publish,
9    distribute, sublicense, and/or sell copies of the Software, and to
10    permit persons to whom the Software is furnished to do so, subject to
11    the following conditions:
12
13    The above copyright notice and this permission notice shall be
14    included in all copies or substantial portions of the Software.
15
16    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
20    BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
21    ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
22    CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23    SOFTWARE.
24 */
25
26 #include <stdlib.h>
27 #include <stdio.h>
28 #include <string.h>
29 #include <stdint.h>
30 #include <math.h>
31 #include "kprobaln.h"
32
33 /*****************************************
34  * Probabilistic banded glocal alignment *
35  *****************************************/
36
37 #define EI .25
38 #define EM .33333333333
39
40 static float g_qual2prob[256];
41
42 #define set_u(u, b, i, k) { int x=(i)-(b); x=x>0?x:0; (u)=((k)-x+1)*3; }
43
44 kpa_par_t kpa_par_def = { 0.001, 0.1, 10 };
45
46 /*
47   The topology of the profile HMM:
48
49            /\             /\        /\             /\
50            I[1]           I[k-1]    I[k]           I[L]
51             ^   \      \    ^    \   ^   \      \   ^
52             |    \      \   |     \  |    \      \  |
53     M[0]   M[1] -> ... -> M[k-1] -> M[k] -> ... -> M[L]   M[L+1]
54                 \      \/        \/      \/      /
55                  \     /\        /\      /\     /
56                        -> D[k-1] -> D[k] ->
57
58    M[0] points to every {M,I}[k] and every {M,I}[k] points M[L+1].
59
60    On input, _ref is the reference sequence and _query is the query
61    sequence. Both are sequences of 0/1/2/3/4 where 4 stands for an
62    ambiguous residue. iqual is the base quality. c sets the gap open
63    probability, gap extension probability and band width.
64
65    On output, state and q are arrays of length l_query. The higher 30
66    bits give the reference position the query base is matched to and the
67    lower two bits can be 0 (an alignment match) or 1 (an
68    insertion). q[i] gives the phred scaled posterior probability of
69    state[i] being wrong.
70  */
71 int kpa_glocal(const uint8_t *_ref, int l_ref, const uint8_t *_query, int l_query, const uint8_t *iqual,
72                            const kpa_par_t *c, int *state, uint8_t *q)
73 {
74         double **f, **b, *s, m[9], sI, sM, bI, bM, pb;
75         float *qual, *_qual;
76         const uint8_t *ref, *query;
77         int bw, bw2, i, k, is_diff = 0;
78
79         /*** initialization ***/
80         ref = _ref - 1; query = _query - 1; // change to 1-based coordinate
81         bw = l_ref > l_query? l_ref : l_query;
82         if (bw > c->bw) bw = c->bw;
83         if (bw < abs(l_ref - l_query)) bw = abs(l_ref - l_query);
84         bw2 = bw * 2 + 1;
85         // allocate the forward and backward matrices f[][] and b[][] and the scaling array s[]
86         f = calloc(l_query+1, sizeof(void*));
87         b = calloc(l_query+1, sizeof(void*));
88         for (i = 0; i <= l_query; ++i) {
89                 f[i] = calloc(bw2 * 3 + 6, sizeof(double)); // FIXME: this is over-allocated for very short seqs
90                 b[i] = calloc(bw2 * 3 + 6, sizeof(double));
91         }
92         s = calloc(l_query+2, sizeof(double)); // s[] is the scaling factor to avoid underflow
93         // initialize qual
94         _qual = calloc(l_query, sizeof(float));
95         if (g_qual2prob[0] == 0)
96                 for (i = 0; i < 256; ++i)
97                         g_qual2prob[i] = pow(10, -i/10.);
98         for (i = 0; i < l_query; ++i) _qual[i] = g_qual2prob[iqual? iqual[i] : 30];
99         qual = _qual - 1;
100         // initialize transition probability
101         sM = sI = 1. / (2 * l_query + 2); // the value here seems not to affect results; FIXME: need proof
102         m[0*3+0] = (1 - c->d - c->d) * (1 - sM); m[0*3+1] = m[0*3+2] = c->d * (1 - sM);
103         m[1*3+0] = (1 - c->e) * (1 - sI); m[1*3+1] = c->e * (1 - sI); m[1*3+2] = 0.;
104         m[2*3+0] = 1 - c->e; m[2*3+1] = 0.; m[2*3+2] = c->e;
105         bM = (1 - c->d) / l_query; bI = c->d / l_query; // (bM+bI)*l_query==1
106         /*** forward ***/
107         // f[0]
108         set_u(k, bw, 0, 0);
109         f[0][k] = s[0] = 1.;
110         { // f[1]
111                 double *fi = f[1], sum;
112                 int beg = 1, end = l_ref < bw + 1? l_ref : bw + 1, _beg, _end;
113                 for (k = beg, sum = 0.; k <= end; ++k) {
114                         int u;
115                         double e = (ref[k] > 3 || query[1] > 3)? 1. : ref[k] == query[1]? 1. - qual[1] : qual[1] * EM;
116                         set_u(u, bw, 1, k);
117                         fi[u+0] = e * bM; fi[u+1] = EI * bI;
118                         sum += fi[u] + fi[u+1];
119                 }
120                 // rescale
121                 s[1] = sum;
122                 set_u(_beg, bw, 1, beg); set_u(_end, bw, 1, end); _end += 2;
123                 for (k = _beg; k <= _end; ++k) fi[k] /= sum;
124         }
125         // f[2..l_query]
126         for (i = 2; i <= l_query; ++i) {
127                 double *fi = f[i], *fi1 = f[i-1], sum, qli = qual[i];
128                 int beg = 1, end = l_ref, x, _beg, _end;
129                 uint8_t qyi = query[i];
130                 x = i - bw; beg = beg > x? beg : x; // band start
131                 x = i + bw; end = end < x? end : x; // band end
132                 for (k = beg, sum = 0.; k <= end; ++k) {
133                         int u, v11, v01, v10;
134                         double e;
135                         e = (ref[k] > 3 || qyi > 3)? 1. : ref[k] == qyi? 1. - qli : qli * EM;
136                         set_u(u, bw, i, k); set_u(v11, bw, i-1, k-1); set_u(v10, bw, i-1, k); set_u(v01, bw, i, k-1);
137                         fi[u+0] = e * (m[0] * fi1[v11+0] + m[3] * fi1[v11+1] + m[6] * fi1[v11+2]);
138                         fi[u+1] = EI * (m[1] * fi1[v10+0] + m[4] * fi1[v10+1]);
139                         fi[u+2] = m[2] * fi[v01+0] + m[8] * fi[v01+2];
140                         sum += fi[u] + fi[u+1] + fi[u+2];
141 //                      fprintf(stderr, "F (%d,%d;%d): %lg,%lg,%lg\n", i, k, u, fi[u], fi[u+1], fi[u+2]); // DEBUG
142                 }
143                 // rescale
144                 s[i] = sum;
145                 set_u(_beg, bw, i, beg); set_u(_end, bw, i, end); _end += 2;
146                 for (k = _beg, sum = 1./sum; k <= _end; ++k) fi[k] *= sum;
147         }
148         { // f[l_query+1]
149                 double sum;
150                 for (k = 1, sum = 0.; k <= l_ref; ++k) {
151                         int u;
152                         set_u(u, bw, l_query, k);
153                         if (u < 3 || u >= bw2*3+3) continue;
154                     sum += f[l_query][u+0] * sM + f[l_query][u+1] * sI;
155                 }
156                 s[l_query+1] = sum; // the last scaling factor
157         }
158         /*** backward ***/
159         // b[l_query] (b[l_query+1][0]=1 and thus \tilde{b}[][]=1/s[l_query+1]; this is where s[l_query+1] comes from)
160         for (k = 1; k <= l_ref; ++k) {
161                 int u;
162                 double *bi = b[l_query];
163                 set_u(u, bw, l_query, k);
164                 if (u < 3 || u >= bw2*3+3) continue;
165                 bi[u+0] = sM / s[l_query] / s[l_query+1]; bi[u+1] = sI / s[l_query] / s[l_query+1];
166         }
167         // b[l_query-1..1]
168         for (i = l_query - 1; i >= 1; --i) {
169                 int beg = 1, end = l_ref, x, _beg, _end;
170                 double *bi = b[i], *bi1 = b[i+1], y = (i > 1), qli1 = qual[i+1];
171                 uint8_t qyi1 = query[i+1];
172                 x = i - bw; beg = beg > x? beg : x;
173                 x = i + bw; end = end < x? end : x;
174                 for (k = end; k >= beg; --k) {
175                         int u, v11, v01, v10;
176                         double e;
177                         set_u(u, bw, i, k); set_u(v11, bw, i+1, k+1); set_u(v10, bw, i+1, k); set_u(v01, bw, i, k+1);
178                         e = (k >= l_ref? 0 : (ref[k+1] > 3 || qyi1 > 3)? 1. : ref[k+1] == qyi1? 1. - qli1 : qli1 * EM) * bi1[v11];
179                         bi[u+0] = e * m[0] + EI * m[1] * bi1[v10+1] + m[2] * bi[v01+2]; // bi1[v11] has been foled into e.
180                         bi[u+1] = e * m[3] + EI * m[4] * bi1[v10+1];
181                         bi[u+2] = (e * m[6] + m[8] * bi[v01+2]) * y;
182 //                      fprintf(stderr, "B (%d,%d;%d): %lg,%lg,%lg\n", i, k, u, bi[u], bi[u+1], bi[u+2]); // DEBUG
183                 }
184                 // rescale
185                 set_u(_beg, bw, i, beg); set_u(_end, bw, i, end); _end += 2;
186                 for (k = _beg, y = 1./s[i]; k <= _end; ++k) bi[k] *= y;
187         }
188         { // b[0]
189                 int beg = 1, end = l_ref < bw + 1? l_ref : bw + 1;
190                 double sum = 0.;
191                 for (k = end; k >= beg; --k) {
192                         int u;
193                         double e = (ref[k] > 3 || query[1] > 3)? 1. : ref[k] == query[1]? 1. - qual[1] : qual[1] * EM;
194                         set_u(u, bw, 1, k);
195                         if (u < 3 || u >= bw2*3+3) continue;
196                     sum += e * b[1][u+0] * bM + EI * b[1][u+1] * bI;
197                 }
198                 set_u(k, bw, 0, 0);
199                 pb = b[0][k] = sum / s[0]; // if everything works as is expected, pb == 1.0
200         }
201         is_diff = fabs(pb - 1.) > 1e-7? 1 : 0;
202         /*** MAP ***/
203         for (i = 1; i <= l_query; ++i) {
204                 double sum = 0., *fi = f[i], *bi = b[i], max = 0.;
205                 int beg = 1, end = l_ref, x, max_k = -1;
206                 x = i - bw; beg = beg > x? beg : x;
207                 x = i + bw; end = end < x? end : x;
208                 for (k = beg; k <= end; ++k) {
209                         int u;
210                         double z;
211                         set_u(u, bw, i, k);
212                         z = fi[u+0] * bi[u+0]; if (z > max) max = z, max_k = (k-1)<<2 | 0; sum += z;
213                         z = fi[u+1] * bi[u+1]; if (z > max) max = z, max_k = (k-1)<<2 | 1; sum += z;
214                 }
215                 max /= sum; sum *= s[i]; // if everything works as is expected, sum == 1.0
216                 if (state) state[i-1] = max_k;
217                 if (q) k = (int)(-4.343 * log(1. - max) + .499), q[i-1] = k > 100? 99 : k;
218 #ifdef _MAIN
219                 fprintf(stderr, "(%.10lg,%.10lg) (%d,%d:%c,%c:%d) %lg\n", pb, sum, i-1, max_k>>2,
220                                 "ACGT"[query[i]], "ACGT"[ref[(max_k>>2)+1]], max_k&3, max); // DEBUG
221 #endif
222         }
223         /*** Compute A ***/
224         /* // compute the posterior of a gap, but I do not know how to use it...
225         if (1) {
226                 double *a;
227                 a = calloc(3 * (l_ref + 2), sizeof(double));
228                 for (i = 1; i < l_query; ++i) {
229                         double sum = 0., *fi = f[i], *bi1 = b[i+1], qli1 = qual[i+1];
230                         int beg = 1, end = l_ref, x;
231                         uint8_t qyi1 = query[i+1];
232                         x = i - bw; beg = beg > x? beg : x;
233                         x = i + bw; end = end < x? end : x;
234                         for (k = beg; k <= end; ++k) {
235                                 double *ak = a + 3 * k;
236                                 int u, v11, v01, v10;
237                                 double e;
238                                 set_u(u, bw, i, k); set_u(v11, bw, i+1, k+1); set_u(v10, bw, i+1, k); set_u(v01, bw, i, k+1);
239                                 e = k >= l_ref? 0 : (ref[k+1] > 3 || qyi1 > 3)? 1. : ref[k+1] == qyi1? 1. - qli1 : qli1 * EM;
240                                 ak[0] += fi[u] * bi1[v11] * m[0] * e;
241                                 ak[1] += fi[u] * bi1[v10+1] * m[1] * EI;
242                                 ak[2] += fi[u] * bi1[v01+2] * m[2];
243                         }
244                 }
245                 for (k = 1; k < l_ref; ++k) {
246                         double sum = 0., *ak = a + 3 * k;
247                         sum += 1. / (ak[0] + ak[1] + ak[2]);
248                         ak[0] *= sum; ak[1] *= sum; ak[2] *= sum;
249                         fprintf(stderr, "%d: %lf, %lf, %lf\n", k, ak[0], ak[1], ak[2]);
250                 }
251                 free(a);
252         }
253         */
254         /*** free ***/
255         for (i = 0; i <= l_query; ++i) {
256                 free(f[i]); free(b[i]);
257         }
258         free(f); free(b); free(s); free(_qual);
259         return 0;
260 }
261
262 #ifdef _MAIN
263 #include <unistd.h>
264 int main(int argc, char *argv[])
265 {
266         uint8_t conv[256], *iqual, *ref, *query;
267         int c, l_ref, l_query, i, q = 30, b = 10;
268         while ((c = getopt(argc, argv, "b:q:")) >= 0) {
269                 switch (c) {
270                 case 'b': b = atoi(optarg); break;
271                 case 'q': q = atoi(optarg); break;
272                 }
273         }
274         if (optind + 2 > argc) {
275                 fprintf(stderr, "Usage: %s [-q %d] [-b %d] <ref> <query>\n", argv[0], q, b); // example: acttc attc
276                 return 1;
277         }
278         memset(conv, 4, 256);
279         conv['a'] = conv['A'] = 0; conv['c'] = conv['C'] = 1;
280         conv['g'] = conv['G'] = 2; conv['t'] = conv['T'] = 3;
281         ref = (uint8_t*)argv[optind]; query = (uint8_t*)argv[optind+1];
282         l_ref = strlen((char*)ref); l_query = strlen((char*)query);
283         for (i = 0; i < l_ref; ++i) ref[i] = conv[ref[i]];
284         for (i = 0; i < l_query; ++i) query[i] = conv[query[i]];
285         iqual = malloc(l_query);
286         memset(iqual, q, l_query);
287         kpa_par_def.bw = b;
288         kpa_glocal(ref, l_ref, query, l_query, iqual, &kpa_par_def, 0, 0);
289         free(iqual);
290         return 0;
291 }
292 #endif