]> git.donarmstrong.com Git - mothur.git/blob - uchime_src/viterbifast.cpp
changes while testing
[mothur.git] / uchime_src / viterbifast.cpp
1 #include "dp.h"
2 #include "out.h"
3 #include "evalue.h"
4
// Debug switch. When non-zero (and SAVE_FAST is also enabled), the values
// computed by ViterbiFast() are checked against the reference matrices
// obtained from GetSimpleDPMxs().
#define CMP_SIMPLE	0
6 \r
7 #if     SAVE_FAST
8 static Mx<float> g_MxDPM;
9 static Mx<float> g_MxDPD;
10 static Mx<float> g_MxDPI;
11
12 static Mx<char> g_MxTBM;
13 static Mx<char> g_MxTBD;
14 static Mx<char> g_MxTBI;
15
16 static float **g_DPM;
17 static float **g_DPD;
18 static float **g_DPI;
19
20 static char **g_TBM;
21 static char **g_TBD;
22 static char **g_TBI;
23
24 #if     CMP_SIMPLE
25 static Mx<float> *g_DPMSimpleMx;
26 static Mx<float> *g_DPDSimpleMx;
27 static Mx<float> *g_DPISimpleMx;
28 static float **g_DPMSimple;
29 static float **g_DPDSimple;
30 static float **g_DPISimple;
31
// Debug check: Die() unless x matches the reference cell g_DPMSimple[i][j]
// according to feq(). Written as do { } while (0) so the macro behaves as a
// single statement (safe in an unbraced if/else); x is evaluated exactly once.
#define cmpm(i, j, x)	do { \
							const float cmpm_x_ = (x); \
							if (!feq(cmpm_x_, g_DPMSimple[(i)][(j)])) \
								{ \
								Die("%s:%d %.1f != DPMSimple[%u][%u] = %.1f", \
								  __FILE__, __LINE__, cmpm_x_, (i), (j), g_DPMSimple[(i)][(j)]); \
								} \
							} while (0)
38
// Debug check: Die() unless x matches the reference cell g_DPDSimple[i][j]
// according to feq(). The message names DPDSimple, the matrix actually
// compared. Written as do { } while (0) so the macro behaves as a single
// statement; x is evaluated exactly once.
#define cmpd(i, j, x)	do { \
							const float cmpd_x_ = (x); \
							if (!feq(cmpd_x_, g_DPDSimple[(i)][(j)])) \
								{ \
								Die("%s:%d %.1f != DPDSimple[%u][%u] = %.1f", \
								  __FILE__, __LINE__, cmpd_x_, (i), (j), g_DPDSimple[(i)][(j)]); \
								} \
							} while (0)
45
// Debug check: Die() unless x matches the reference cell g_DPISimple[i][j]
// according to feq(). The message names DPISimple, the matrix actually
// compared. Written as do { } while (0) so the macro behaves as a single
// statement; x is evaluated exactly once.
#define cmpi(i, j, x)	do { \
							const float cmpi_x_ = (x); \
							if (!feq(cmpi_x_, g_DPISimple[(i)][(j)])) \
								{ \
								Die("%s:%d %.1f != DPISimple[%u][%u] = %.1f", \
								  __FILE__, __LINE__, cmpi_x_, (i), (j), g_DPISimple[(i)][(j)]); \
								} \
							} while (0)
52
53 #else
54
// CMP_SIMPLE disabled: the cell checks expand to a no-op and do not
// evaluate their arguments.
#define cmpm(i, j, x)	((void)0)
#define cmpd(i, j, x)	((void)0)
#define cmpi(i, j, x)	((void)0)
58
59 #endif
60
61 static void AllocSave(unsigned LA, unsigned LB)
62         {
63 #if     CMP_SIMPLE
64         GetSimpleDPMxs(&g_DPMSimpleMx, &g_DPDSimpleMx, &g_DPISimpleMx);
65         g_DPMSimple = g_DPMSimpleMx->GetData();
66         g_DPDSimple = g_DPDSimpleMx->GetData();
67         g_DPISimple = g_DPISimpleMx->GetData();
68 #endif
69         g_MxDPM.Alloc("FastM", LA+1, LB+1);\r
70         g_MxDPD.Alloc("FastD", LA+1, LB+1);\r
71         g_MxDPI.Alloc("FastI", LA+1, LB+1);\r
72 \r
73         g_MxTBM.Alloc("FastTBM", LA+1, LB+1);\r
74         g_MxTBD.Alloc("FastTBD", LA+1, LB+1);\r
75         g_MxTBI.Alloc("FastTBI", LA+1, LB+1);\r
76 \r
77         g_DPM = g_MxDPM.GetData();\r
78         g_DPD = g_MxDPD.GetData();\r
79         g_DPI = g_MxDPI.GetData();\r
80 \r
81         g_TBM = g_MxTBM.GetData();\r
82         g_TBD = g_MxTBD.GetData();\r
83         g_TBI = g_MxTBI.GetData();\r
84         }
85
86 static void SAVE_DPM(unsigned i, unsigned j, float x)
87         {
88         g_DPM[i][j] = x;
89 #if     CMP_SIMPLE
90         if (i > 0 && j > 0)
91         asserta(feq(x, g_DPMSimple[i][j]));
92 #endif
93         }
94
95 static void SAVE_DPD(unsigned i, unsigned j, float x)
96         {
97         g_DPD[i][j] = x;
98 #if     CMP_SIMPLE
99         if (i > 0 && j > 0)
100         asserta(feq(x, g_DPDSimple[i][j]));
101 #endif
102         }
103
104 static void SAVE_DPI(unsigned i, unsigned j, float x)
105         {
106         g_DPI[i][j] = x;
107 #if     CMP_SIMPLE
108         if (i > 0 && j > 0)
109         asserta(feq(x, g_DPISimple[i][j]));
110 #endif
111         }
112
113 static void SAVE_TBM(unsigned i, unsigned j, char x)
114         {
115         g_TBM[i][j] = x;
116         }
117
118 static void SAVE_TBD(unsigned i, unsigned j, char x)
119         {
120         g_TBD[i][j] = x;
121         }
122
123 static void SAVE_TBI(unsigned i, unsigned j, char x)
124         {
125         g_TBI[i][j] = x;
126         }
127
128 void GetFastMxs(Mx<float> **M, Mx<float> **D, Mx<float> **I)
129         {
130         *M = &g_MxDPM;
131         *D = &g_MxDPD;
132         *I = &g_MxDPI;
133         }
134
135 #else   // SAVE_FAST
136
// SAVE_FAST disabled: every save/check hook expands to a no-op and does not
// evaluate its arguments.
#define SAVE_DPM(i, j, x)	((void)0)
#define SAVE_DPD(i, j, x)	((void)0)
#define SAVE_DPI(i, j, x)	((void)0)

#define SAVE_TBM(i, j, x)	((void)0)
#define SAVE_TBD(i, j, x)	((void)0)
#define SAVE_TBI(i, j, x)	((void)0)

#define AllocSave(LA, LB)	((void)0)

#define cmpm(i, j, x)	((void)0)
#define cmpd(i, j, x)	((void)0)
#define cmpi(i, j, x)	((void)0)
150
151 #endif  // SAVE_FAST
152
153 float ViterbiFast(const byte *A, unsigned LA, const byte *B, unsigned LB,
154   const AlnParams &AP, PathData &PD)
155         {
156         if (LA*LB > 100*1000*1000)
157                 Die("ViterbiFast, too long LA=%u, LB=%u", LA, LB);
158
159         AllocBit(LA, LB);
160         AllocSave(LA, LB);
161         
162         StartTimer(ViterbiFast);
163
164         const float * const *Mx = AP.SubstMx;
165         float OpenA = AP.LOpenA;
166         float ExtA = AP.LExtA;
167
168         byte **TB = g_TBBit;
169         float *Mrow = g_DPRow1;
170         float *Drow = g_DPRow2;
171
172 // Use Mrow[-1], so...
173         Mrow[-1] = MINUS_INFINITY;
174         for (unsigned j = 0; j <= LB; ++j)
175                 {
176                 Mrow[j] = MINUS_INFINITY;
177                 SAVE_DPM(0, j, MINUS_INFINITY);
178                 SAVE_TBM(0, j, '?');
179
180                 Drow[j] = MINUS_INFINITY;
181                 SAVE_DPD(0, j, MINUS_INFINITY);
182                 SAVE_TBD(0, j, '?');
183                 }
184         
185 // Main loop
186         float M0 = float (0);
187         SAVE_DPM(0, 0, 0);
188         for (unsigned i = 0; i < LA; ++i)
189                 {
190                 byte a = A[i];
191                 const float *MxRow = Mx[a];
192                 float OpenB = AP.LOpenB;
193                 float ExtB = AP.LExtB;
194                 float I0 = MINUS_INFINITY;
195
196                 SAVE_TBM(i, 0, '?');
197
198                 SAVE_DPI(i, 0, MINUS_INFINITY);
199                 SAVE_DPI(i, 1, MINUS_INFINITY);
200
201                 SAVE_TBI(i, 0, '?');
202                 SAVE_TBI(i, 1, '?');
203                 
204                 byte *TBrow = TB[i];
205                 for (unsigned j = 0; j < LB; ++j)
206                         {
207                         byte b = B[j];
208                         byte TraceBits = 0;
209                         float SavedM0 = M0;
210
211                 // MATCH
212                         {
213                 // M0 = DPM[i][j]
214                 // I0 = DPI[i][j]
215                 // Drow[j] = DPD[i][j]
216                         cmpm(i, j, M0);
217                         cmpd(i, j, Drow[j]);
218                         cmpi(i, j, I0);
219
220                         float xM = M0;
221                         SAVE_TBM(i+1, j+1, 'M');
222                         if (Drow[j] > xM)
223                                 {
224                                 xM = Drow[j];
225                                 TraceBits = TRACEBITS_DM;
226                                 SAVE_TBM(i+1, j+1, 'D');
227                                 }
228                         if (I0 > xM)
229                                 {
230                                 xM = I0;
231                                 TraceBits = TRACEBITS_IM;
232                                 SAVE_TBM(i+1, j+1, 'I');
233                                 }
234                         M0 = Mrow[j];
235                         cmpm(i, j+1, M0);
236
237                         Mrow[j] = xM + MxRow[b];
238                 // Mrow[j] = DPM[i+1][j+1])
239                         SAVE_DPM(i+1, j+1, Mrow[j]);
240                         }
241                         
242                 // DELETE
243                         {
244                 // SavedM0 = DPM[i][j]
245                 // Drow[j] = DPD[i][j]
246                         cmpm(i, j, SavedM0);
247                         cmpd(i, j, Drow[j]);
248
249                         float md = SavedM0 + OpenB;
250                         Drow[j] += ExtB;
251                         SAVE_TBD(i+1, j, 'D');
252                         if (md >= Drow[j])
253                                 {
254                                 Drow[j] = md;
255                                 TraceBits |= TRACEBITS_MD;
256                                 SAVE_TBD(i+1, j, 'M');
257                                 }
258                 // Drow[j] = DPD[i+1][j]
259                         SAVE_DPD(i+1, j, Drow[j]);
260                         }
261                         
262                 // INSERT
263                         {
264                 // SavedM0 = DPM[i][j]
265                 // I0 = DPI[i][j]
266                         cmpm(i, j, SavedM0);
267                         cmpi(i, j, I0);
268                         
269                         float mi = SavedM0 + OpenA;
270                         I0 += ExtA;
271                         SAVE_TBI(i, j+1, 'I');
272                         if (mi >= I0)
273                                 {
274                                 I0 = mi;
275                                 TraceBits |= TRACEBITS_MI;
276                                 SAVE_TBI(i, j+1, 'M');
277                                 }
278                 // I0 = DPI[i][j+1]
279                         SAVE_DPI(i, j+1, I0);
280                         }
281                         
282                         OpenB = AP.OpenB;
283                         ExtB = AP.ExtB;
284                         
285                         TBrow[j] = TraceBits;
286                         }
287                 
288         // Special case for end of Drow[]
289                 {
290         // M0 = DPM[i][LB]
291         // Drow[LB] = DPD[i][LB]
292                 
293                 TBrow[LB] = 0;
294                 float md = M0 + AP.ROpenB;
295                 Drow[LB] += AP.RExtB;
296                 SAVE_TBD(i+1, LB, 'D');
297                 if (md >= Drow[LB])
298                         {
299                         Drow[LB] = md;
300                         TBrow[LB] = TRACEBITS_MD;
301                         SAVE_TBD(i+1, LB, 'M');
302                         }
303         // Drow[LB] = DPD[i+1][LB]
304                 SAVE_DPD(i+1, LB, Drow[LB]);
305                 }
306                 
307                 SAVE_DPM(i+1, 0, MINUS_INFINITY);
308                 M0 = MINUS_INFINITY;
309
310                 OpenA = AP.OpenA;
311                 ExtA = AP.ExtA;
312                 }
313         
314         SAVE_TBM(LA, 0, '?');
315
316 // Special case for last row of DPI
317         byte *TBrow = TB[LA];
318         float I1 = MINUS_INFINITY;
319
320         SAVE_DPI(LA, 0, MINUS_INFINITY);
321         SAVE_TBI(LA, 0, '?');
322
323         SAVE_DPI(LA, 1, MINUS_INFINITY);
324         SAVE_TBI(LA, 1, '?');
325
326         for (unsigned j = 1; j < LB; ++j)
327                 {
328         // Mrow[j-1] = DPM[LA][j]
329         // I1 = DPI[LA][j]
330                 
331                 TBrow[j] = 0;
332                 float mi = Mrow[int(j)-1] + AP.ROpenA;
333                 I1 += AP.RExtA;
334                 SAVE_TBI(LA, j+1, 'I');
335                 if (mi > I1)
336                         {
337                         I1 = mi;
338                         TBrow[j] = TRACEBITS_MI;
339                         SAVE_TBI(LA, j+1, 'M');
340                         }
341                 SAVE_DPI(LA, j+1, I1);
342                 }
343         
344         float FinalM = Mrow[LB-1];
345         float FinalD = Drow[LB];
346         float FinalI = I1;
347 // FinalM = DPM[LA][LB]
348 // FinalD = DPD[LA][LB]
349 // FinalI = DPI[LA][LB]
350         
351         float Score = FinalM;
352         byte State = 'M';
353         if (FinalD > Score)
354                 {
355                 Score = FinalD;
356                 State = 'D';
357                 }
358         if (FinalI > Score)
359                 {
360                 Score = FinalI;
361                 State = 'I';
362                 }
363
364         EndTimer(ViterbiFast);
365         TraceBackBit(LA, LB, State, PD);
366
367 #if     SAVE_FAST
368         g_MxDPM.LogMe();
369         g_MxDPD.LogMe();
370         g_MxDPI.LogMe();
371
372         g_MxTBM.LogMe();
373         g_MxTBD.LogMe();
374         g_MxTBI.LogMe();
375 #endif
376
377         return Score;
378         }