]> git.donarmstrong.com Git - mothur.git/blob - viterbifast.cpp
3f3fc970a2075a248bdace6b599773e04a84db01
[mothur.git] / viterbifast.cpp
1 //uchime by Robert C. Edgar http://drive5.com/uchime This code is donated to the public domain.
2
3 #include "dp.h"
4 #include "out.h"
5 #include "evalue.h"
6
// Debug switch, only consulted inside SAVE_FAST builds: when non-zero, every
// saved fast-path DP cell is checked against the simple DP matrices
// (asserta/Die on mismatch).
#define CMP_SIMPLE	0
8
9 #if     SAVE_FAST
10 static Mx<float> g_MxDPM;
11 static Mx<float> g_MxDPD;
12 static Mx<float> g_MxDPI;
13
14 static Mx<char> g_MxTBM;
15 static Mx<char> g_MxTBD;
16 static Mx<char> g_MxTBI;
17
18 static float **g_DPM;
19 static float **g_DPD;
20 static float **g_DPI;
21
22 static char **g_TBM;
23 static char **g_TBD;
24 static char **g_TBI;
25
#if	CMP_SIMPLE
static Mx<float> *g_DPMSimpleMx;
static Mx<float> *g_DPDSimpleMx;
static Mx<float> *g_DPISimpleMx;
static float **g_DPMSimple;
static float **g_DPDSimple;
static float **g_DPISimple;

// Debug cross-checks against the simple DP matrices: Die() with file/line if
// the fast-path value x differs (per feq) from the matching simple cell [i][j].
// Each expands to a single statement (do/while(0)), evaluates x once, and its
// message names the matrix it actually compares against.
#define cmpm(i, j, x)	do { \
	const float cmp_x_ = (x); \
	if (!feq(cmp_x_, g_DPMSimple[i][j])) \
		Die("%s:%d %.1f != DPMSimple[%u][%u] = %.1f", \
		  __FILE__, __LINE__, cmp_x_, (i), (j), g_DPMSimple[i][j]); \
	} while (0)

#define cmpd(i, j, x)	do { \
	const float cmp_x_ = (x); \
	if (!feq(cmp_x_, g_DPDSimple[i][j])) \
		Die("%s:%d %.1f != DPDSimple[%u][%u] = %.1f", \
		  __FILE__, __LINE__, cmp_x_, (i), (j), g_DPDSimple[i][j]); \
	} while (0)

#define cmpi(i, j, x)	do { \
	const float cmp_x_ = (x); \
	if (!feq(cmp_x_, g_DPISimple[i][j])) \
		Die("%s:%d %.1f != DPISimple[%u][%u] = %.1f", \
		  __FILE__, __LINE__, cmp_x_, (i), (j), g_DPISimple[i][j]); \
	} while (0)

#else

// Cross-checks disabled: expand to nothing (arguments are not evaluated).
#define cmpm(i, j, x)	/* empty */
#define cmpd(i, j, x)	/* empty */
#define cmpi(i, j, x)	/* empty */

#endif
62
63 static void AllocSave(unsigned LA, unsigned LB)
64         {
65 #if     CMP_SIMPLE
66         GetSimpleDPMxs(&g_DPMSimpleMx, &g_DPDSimpleMx, &g_DPISimpleMx);
67         g_DPMSimple = g_DPMSimpleMx->GetData();
68         g_DPDSimple = g_DPDSimpleMx->GetData();
69         g_DPISimple = g_DPISimpleMx->GetData();
70 #endif
71         g_MxDPM.Alloc("FastM", LA+1, LB+1);
72         g_MxDPD.Alloc("FastD", LA+1, LB+1);
73         g_MxDPI.Alloc("FastI", LA+1, LB+1);
74
75         g_MxTBM.Alloc("FastTBM", LA+1, LB+1);
76         g_MxTBD.Alloc("FastTBD", LA+1, LB+1);
77         g_MxTBI.Alloc("FastTBI", LA+1, LB+1);
78
79         g_DPM = g_MxDPM.GetData();
80         g_DPD = g_MxDPD.GetData();
81         g_DPI = g_MxDPI.GetData();
82
83         g_TBM = g_MxTBM.GetData();
84         g_TBD = g_MxTBD.GetData();
85         g_TBI = g_MxTBI.GetData();
86         }
87
88 static void SAVE_DPM(unsigned i, unsigned j, float x)
89         {
90         g_DPM[i][j] = x;
91 #if     CMP_SIMPLE
92         if (i > 0 && j > 0)
93         asserta(feq(x, g_DPMSimple[i][j]));
94 #endif
95         }
96
97 static void SAVE_DPD(unsigned i, unsigned j, float x)
98         {
99         g_DPD[i][j] = x;
100 #if     CMP_SIMPLE
101         if (i > 0 && j > 0)
102         asserta(feq(x, g_DPDSimple[i][j]));
103 #endif
104         }
105
106 static void SAVE_DPI(unsigned i, unsigned j, float x)
107         {
108         g_DPI[i][j] = x;
109 #if     CMP_SIMPLE
110         if (i > 0 && j > 0)
111         asserta(feq(x, g_DPISimple[i][j]));
112 #endif
113         }
114
115 static void SAVE_TBM(unsigned i, unsigned j, char x)
116         {
117         g_TBM[i][j] = x;
118         }
119
120 static void SAVE_TBD(unsigned i, unsigned j, char x)
121         {
122         g_TBD[i][j] = x;
123         }
124
125 static void SAVE_TBI(unsigned i, unsigned j, char x)
126         {
127         g_TBI[i][j] = x;
128         }
129
130 void GetFastMxs(Mx<float> **M, Mx<float> **D, Mx<float> **I)
131         {
132         *M = &g_MxDPM;
133         *D = &g_MxDPD;
134         *I = &g_MxDPI;
135         }
136
137 #else   // SAVE_FAST
138
// SAVE_FAST disabled: all debug-matrix bookkeeping compiles to a no-op
// statement. Macro arguments are never evaluated.
#define SAVE_DPM(i, j, x)	((void) 0)
#define SAVE_DPD(i, j, x)	((void) 0)
#define SAVE_DPI(i, j, x)	((void) 0)

#define SAVE_TBM(i, j, x)	((void) 0)
#define SAVE_TBD(i, j, x)	((void) 0)
#define SAVE_TBI(i, j, x)	((void) 0)

#define AllocSave(LA, LB)	((void) 0)

#define cmpm(i, j, x)	((void) 0)
#define cmpd(i, j, x)	((void) 0)
#define cmpi(i, j, x)	((void) 0)
152
153 #endif  // SAVE_FAST
154
155 float ViterbiFast(const byte *A, unsigned LA, const byte *B, unsigned LB,
156   const AlnParams &AP, PathData &PD)
157         {
158         if (LA*LB > 100*1000*1000)
159                 Die("ViterbiFast, too long LA=%u, LB=%u", LA, LB);
160
161         AllocBit(LA, LB);
162         AllocSave(LA, LB);
163         
164         StartTimer(ViterbiFast);
165
166         const float * const *Mx = AP.SubstMx;
167         float OpenA = AP.LOpenA;
168         float ExtA = AP.LExtA;
169
170         byte **TB = g_TBBit;
171         float *Mrow = g_DPRow1;
172         float *Drow = g_DPRow2;
173
174 // Use Mrow[-1], so...
175         Mrow[-1] = MINUS_INFINITY;
176         for (unsigned j = 0; j <= LB; ++j)
177                 {
178                 Mrow[j] = MINUS_INFINITY;
179                 SAVE_DPM(0, j, MINUS_INFINITY);
180                 SAVE_TBM(0, j, '?');
181
182                 Drow[j] = MINUS_INFINITY;
183                 SAVE_DPD(0, j, MINUS_INFINITY);
184                 SAVE_TBD(0, j, '?');
185                 }
186         
187 // Main loop
188         float M0 = float (0);
189         SAVE_DPM(0, 0, 0);
190         for (unsigned i = 0; i < LA; ++i)
191                 {
192                 byte a = A[i];
193                 const float *MxRow = Mx[a];
194                 float OpenB = AP.LOpenB;
195                 float ExtB = AP.LExtB;
196                 float I0 = MINUS_INFINITY;
197
198                 SAVE_TBM(i, 0, '?');
199
200                 SAVE_DPI(i, 0, MINUS_INFINITY);
201                 SAVE_DPI(i, 1, MINUS_INFINITY);
202
203                 SAVE_TBI(i, 0, '?');
204                 SAVE_TBI(i, 1, '?');
205                 
206                 byte *TBrow = TB[i];
207                 for (unsigned j = 0; j < LB; ++j)
208                         {
209                         byte b = B[j];
210                         byte TraceBits = 0;
211                         float SavedM0 = M0;
212
213                 // MATCH
214                         {
215                 // M0 = DPM[i][j]
216                 // I0 = DPI[i][j]
217                 // Drow[j] = DPD[i][j]
218                         cmpm(i, j, M0);
219                         cmpd(i, j, Drow[j]);
220                         cmpi(i, j, I0);
221
222                         float xM = M0;
223                         SAVE_TBM(i+1, j+1, 'M');
224                         if (Drow[j] > xM)
225                                 {
226                                 xM = Drow[j];
227                                 TraceBits = TRACEBITS_DM;
228                                 SAVE_TBM(i+1, j+1, 'D');
229                                 }
230                         if (I0 > xM)
231                                 {
232                                 xM = I0;
233                                 TraceBits = TRACEBITS_IM;
234                                 SAVE_TBM(i+1, j+1, 'I');
235                                 }
236                         M0 = Mrow[j];
237                         cmpm(i, j+1, M0);
238
239                         Mrow[j] = xM + MxRow[b];
240                 // Mrow[j] = DPM[i+1][j+1])
241                         SAVE_DPM(i+1, j+1, Mrow[j]);
242                         }
243                         
244                 // DELETE
245                         {
246                 // SavedM0 = DPM[i][j]
247                 // Drow[j] = DPD[i][j]
248                         cmpm(i, j, SavedM0);
249                         cmpd(i, j, Drow[j]);
250
251                         float md = SavedM0 + OpenB;
252                         Drow[j] += ExtB;
253                         SAVE_TBD(i+1, j, 'D');
254                         if (md >= Drow[j])
255                                 {
256                                 Drow[j] = md;
257                                 TraceBits |= TRACEBITS_MD;
258                                 SAVE_TBD(i+1, j, 'M');
259                                 }
260                 // Drow[j] = DPD[i+1][j]
261                         SAVE_DPD(i+1, j, Drow[j]);
262                         }
263                         
264                 // INSERT
265                         {
266                 // SavedM0 = DPM[i][j]
267                 // I0 = DPI[i][j]
268                         cmpm(i, j, SavedM0);
269                         cmpi(i, j, I0);
270                         
271                         float mi = SavedM0 + OpenA;
272                         I0 += ExtA;
273                         SAVE_TBI(i, j+1, 'I');
274                         if (mi >= I0)
275                                 {
276                                 I0 = mi;
277                                 TraceBits |= TRACEBITS_MI;
278                                 SAVE_TBI(i, j+1, 'M');
279                                 }
280                 // I0 = DPI[i][j+1]
281                         SAVE_DPI(i, j+1, I0);
282                         }
283                         
284                         OpenB = AP.OpenB;
285                         ExtB = AP.ExtB;
286                         
287                         TBrow[j] = TraceBits;
288                         }
289                 
290         // Special case for end of Drow[]
291                 {
292         // M0 = DPM[i][LB]
293         // Drow[LB] = DPD[i][LB]
294                 
295                 TBrow[LB] = 0;
296                 float md = M0 + AP.ROpenB;
297                 Drow[LB] += AP.RExtB;
298                 SAVE_TBD(i+1, LB, 'D');
299                 if (md >= Drow[LB])
300                         {
301                         Drow[LB] = md;
302                         TBrow[LB] = TRACEBITS_MD;
303                         SAVE_TBD(i+1, LB, 'M');
304                         }
305         // Drow[LB] = DPD[i+1][LB]
306                 SAVE_DPD(i+1, LB, Drow[LB]);
307                 }
308                 
309                 SAVE_DPM(i+1, 0, MINUS_INFINITY);
310                 M0 = MINUS_INFINITY;
311
312                 OpenA = AP.OpenA;
313                 ExtA = AP.ExtA;
314                 }
315         
316         SAVE_TBM(LA, 0, '?');
317
318 // Special case for last row of DPI
319         byte *TBrow = TB[LA];
320         float I1 = MINUS_INFINITY;
321
322         SAVE_DPI(LA, 0, MINUS_INFINITY);
323         SAVE_TBI(LA, 0, '?');
324
325         SAVE_DPI(LA, 1, MINUS_INFINITY);
326         SAVE_TBI(LA, 1, '?');
327
328         for (unsigned j = 1; j < LB; ++j)
329                 {
330         // Mrow[j-1] = DPM[LA][j]
331         // I1 = DPI[LA][j]
332                 
333                 TBrow[j] = 0;
334                 float mi = Mrow[int(j)-1] + AP.ROpenA;
335                 I1 += AP.RExtA;
336                 SAVE_TBI(LA, j+1, 'I');
337                 if (mi > I1)
338                         {
339                         I1 = mi;
340                         TBrow[j] = TRACEBITS_MI;
341                         SAVE_TBI(LA, j+1, 'M');
342                         }
343                 SAVE_DPI(LA, j+1, I1);
344                 }
345         
346         float FinalM = Mrow[LB-1];
347         float FinalD = Drow[LB];
348         float FinalI = I1;
349 // FinalM = DPM[LA][LB]
350 // FinalD = DPD[LA][LB]
351 // FinalI = DPI[LA][LB]
352         
353         float Score = FinalM;
354         byte State = 'M';
355         if (FinalD > Score)
356                 {
357                 Score = FinalD;
358                 State = 'D';
359                 }
360         if (FinalI > Score)
361                 {
362                 Score = FinalI;
363                 State = 'I';
364                 }
365
366         EndTimer(ViterbiFast);
367         TraceBackBit(LA, LB, State, PD);
368
369 #if     SAVE_FAST
370         g_MxDPM.LogMe();
371         g_MxDPD.LogMe();
372         g_MxDPI.LogMe();
373
374         g_MxTBM.LogMe();
375         g_MxTBD.LogMe();
376         g_MxTBI.LogMe();
377 #endif
378
379         return Score;
380         }