]> git.donarmstrong.com Git - samtools.git/blobdiff - bcftools/prob1.c
added the haploid mode
[samtools.git] / bcftools / prob1.c
index 193c4a0c535126496b254869a6aee193ac4ebbf6..4804e6e24c3c6787f2ca3a18fb6e83687f379ed4 100644 (file)
@@ -3,6 +3,7 @@
 #include <string.h>
 #include <stdio.h>
 #include <errno.h>
+#include <assert.h>
 #include "prob1.h"
 
 #include "kseq.h"
@@ -33,6 +34,7 @@ unsigned char seq_nt4_table[256] = {
 
 struct __bcf_p1aux_t {
        int n, M, n1, is_indel;
+       uint8_t *ploidy; // haploid or diploid ONLY
        double *q2p, *pdg; // pdg -> P(D|g)
        double *phi, *phi_indel;
        double *z, *zswap; // aux for afs
@@ -123,25 +125,34 @@ int bcf_p1_read_prior(bcf_p1aux_t *ma, const char *fn)
        return 0;
 }
 
-bcf_p1aux_t *bcf_p1_init(int n)
+bcf_p1aux_t *bcf_p1_init(int n, uint8_t *ploidy)
 {
        bcf_p1aux_t *ma;
        int i;
        ma = calloc(1, sizeof(bcf_p1aux_t));
        ma->n1 = -1;
        ma->n = n; ma->M = 2 * n;
+       if (ploidy) {
+               ma->ploidy = malloc(n);
+               memcpy(ma->ploidy, ploidy, n);
+               for (i = 0, ma->M = 0; i < n; ++i) ma->M += ploidy[i];
+               if (ma->M == 2 * n) {
+                       free(ma->ploidy);
+                       ma->ploidy = 0;
+               }
+       }
        ma->q2p = calloc(256, sizeof(double));
        ma->pdg = calloc(3 * ma->n, sizeof(double));
        ma->phi = calloc(ma->M + 1, sizeof(double));
        ma->phi_indel = calloc(ma->M + 1, sizeof(double));
        ma->phi1 = calloc(ma->M + 1, sizeof(double));
        ma->phi2 = calloc(ma->M + 1, sizeof(double));
-       ma->z = calloc(2 * ma->n + 1, sizeof(double));
-       ma->zswap = calloc(2 * ma->n + 1, sizeof(double));
+       ma->z = calloc(ma->M + 1, sizeof(double));
+       ma->zswap = calloc(ma->M + 1, sizeof(double));
        ma->z1 = calloc(ma->M + 1, sizeof(double)); // actually we do not need this large
        ma->z2 = calloc(ma->M + 1, sizeof(double));
-       ma->afs = calloc(2 * ma->n + 1, sizeof(double));
-       ma->afs1 = calloc(2 * ma->n + 1, sizeof(double));
+       ma->afs = calloc(ma->M + 1, sizeof(double));
+       ma->afs1 = calloc(ma->M + 1, sizeof(double));
        for (i = 0; i < 256; ++i)
                ma->q2p[i] = pow(10., -i / 10.);
        bcf_p1_init_prior(ma, MC_PTYPE_FULL, 1e-3); // the simplest prior
@@ -151,6 +162,10 @@ bcf_p1aux_t *bcf_p1_init(int n)
 int bcf_p1_set_n1(bcf_p1aux_t *b, int n1)
 {
        if (n1 == 0 || n1 >= b->n) return -1;
+       if (b->M != b->n * 2) {
+               fprintf(stderr, "[%s] unable to set `n1' when there are haploid samples.\n", __func__);
+               return -1;
+       }
        b->n1 = n1;
        return 0;
 }
@@ -158,7 +173,7 @@ int bcf_p1_set_n1(bcf_p1aux_t *b, int n1)
 void bcf_p1_destroy(bcf_p1aux_t *ma)
 {
        if (ma) {
-               free(ma->q2p); free(ma->pdg);
+               free(ma->ploidy); free(ma->q2p); free(ma->pdg);
                free(ma->phi); free(ma->phi_indel); free(ma->phi1); free(ma->phi2);
                free(ma->z); free(ma->zswap); free(ma->z1); free(ma->z2);
                free(ma->afs); free(ma->afs1);
@@ -207,11 +222,16 @@ int bcf_p1_call_gt(const bcf_p1aux_t *ma, double f0, int k)
 {
        double sum, g[3];
        double max, f3[3], *pdg = ma->pdg + k * 3;
-       int q, i, max_i;
-       f3[0] = (1.-f0)*(1.-f0); f3[1] = 2.*f0*(1.-f0); f3[2] = f0*f0;
+       int q, i, max_i, ploidy;
+       ploidy = ma->ploidy? ma->ploidy[k] : 2;
+       if (ploidy == 2) {
+               f3[0] = (1.-f0)*(1.-f0); f3[1] = 2.*f0*(1.-f0); f3[2] = f0*f0;
+       } else {
+               f3[0] = 1. - f0; f3[1] = 0; f3[2] = f0;
+       }
        for (i = 0, sum = 0.; i < 3; ++i)
                sum += (g[i] = pdg[i] * f3[i]);
-       for (i = 0, max = -1., max_i = 0; i < 3; ++i) {
+       for (i = 0, max = -1., max_i = 0; i <= ploidy; ++i) {
                g[i] /= sum;
                if (g[i] > max) max = g[i], max_i = i;
        }
@@ -228,6 +248,7 @@ static void mc_cal_y_core(bcf_p1aux_t *ma, int beg)
 {
        double *z[2], *tmp, *pdg;
        int _j, last_min, last_max;
+       assert(beg == 0 || ma->M == ma->n*2);
        z[0] = ma->z;
        z[1] = ma->zswap;
        pdg = ma->pdg;
@@ -236,41 +257,81 @@ static void mc_cal_y_core(bcf_p1aux_t *ma, int beg)
        z[0][0] = 1.;
        last_min = last_max = 0;
        ma->t = 0.;
-       for (_j = beg; _j < ma->n; ++_j) {
-               int k, j = _j - beg, _min = last_min, _max = last_max;
-               double p[3], sum;
-               pdg = ma->pdg + _j * 3;
-               p[0] = pdg[0]; p[1] = 2. * pdg[1]; p[2] = pdg[2];
-               for (; _min < _max && z[0][_min] < TINY; ++_min) z[0][_min] = z[1][_min] = 0.;
-               for (; _max > _min && z[0][_max] < TINY; --_max) z[0][_max] = z[1][_max] = 0.;
-               _max += 2;
-               if (_min == 0) 
-                       k = 0, z[1][k] = (2*j+2-k)*(2*j-k+1) * p[0] * z[0][k];
-               if (_min <= 1)
-                       k = 1, z[1][k] = (2*j+2-k)*(2*j-k+1) * p[0] * z[0][k] + k*(2*j+2-k) * p[1] * z[0][k-1];
-               for (k = _min < 2? 2 : _min; k <= _max; ++k)
-                       z[1][k] = (2*j+2-k)*(2*j-k+1) * p[0] * z[0][k]
-                               + k*(2*j+2-k) * p[1] * z[0][k-1]
-                               + k*(k-1)* p[2] * z[0][k-2];
-               for (k = _min, sum = 0.; k <= _max; ++k) sum += z[1][k];
-               ma->t += log(sum / ((2. * j + 2) * (2. * j + 1)));
-               for (k = _min; k <= _max; ++k) z[1][k] /= sum;
-               if (_min >= 1) z[1][_min-1] = 0.;
-               if (_min >= 2) z[1][_min-2] = 0.;
-               if (j < ma->n - 1) z[1][_max+1] = z[1][_max+2] = 0.;
-               if (_j == ma->n1 - 1) { // set pop1
-                       ma->t1 = ma->t;
-                       memcpy(ma->z1, z[1], sizeof(double) * (ma->n1 * 2 + 1));
+       if (ma->M == ma->n * 2) {
+               for (_j = beg; _j < ma->n; ++_j) {
+                       int k, j = _j - beg, _min = last_min, _max = last_max;
+                       double p[3], sum;
+                       pdg = ma->pdg + _j * 3;
+                       p[0] = pdg[0]; p[1] = 2. * pdg[1]; p[2] = pdg[2];
+                       for (; _min < _max && z[0][_min] < TINY; ++_min) z[0][_min] = z[1][_min] = 0.;
+                       for (; _max > _min && z[0][_max] < TINY; --_max) z[0][_max] = z[1][_max] = 0.;
+                       _max += 2;
+                       if (_min == 0) 
+                               k = 0, z[1][k] = (2*j+2-k)*(2*j-k+1) * p[0] * z[0][k];
+                       if (_min <= 1)
+                               k = 1, z[1][k] = (2*j+2-k)*(2*j-k+1) * p[0] * z[0][k] + k*(2*j+2-k) * p[1] * z[0][k-1];
+                       for (k = _min < 2? 2 : _min; k <= _max; ++k)
+                               z[1][k] = (2*j+2-k)*(2*j-k+1) * p[0] * z[0][k]
+                                       + k*(2*j+2-k) * p[1] * z[0][k-1]
+                                       + k*(k-1)* p[2] * z[0][k-2];
+                       for (k = _min, sum = 0.; k <= _max; ++k) sum += z[1][k];
+                       ma->t += log(sum / ((2. * j + 2) * (2. * j + 1)));
+                       for (k = _min; k <= _max; ++k) z[1][k] /= sum;
+                       if (_min >= 1) z[1][_min-1] = 0.;
+                       if (_min >= 2) z[1][_min-2] = 0.;
+                       if (j < ma->n - 1) z[1][_max+1] = z[1][_max+2] = 0.;
+                       if (_j == ma->n1 - 1) { // set pop1; ma->n1==-1 when unset
+                               ma->t1 = ma->t;
+                               memcpy(ma->z1, z[1], sizeof(double) * (ma->n1 * 2 + 1));
+                       }
+                       tmp = z[0]; z[0] = z[1]; z[1] = tmp;
+                       last_min = _min; last_max = _max;
+               }
+       } else { // this block is very similar to the block above; these two might be merged in future
+               int j, M = 0;
+               for (j = 0; j < ma->n; ++j) {
+                       int k, M0, _min = last_min, _max = last_max;
+                       double p[3], sum;
+                       pdg = ma->pdg + j * 3;
+                       for (; _min < _max && z[0][_min] < TINY; ++_min) z[0][_min] = z[1][_min] = 0.;
+                       for (; _max > _min && z[0][_max] < TINY; --_max) z[0][_max] = z[1][_max] = 0.;
+                       M0 = M;
+                       M += ma->ploidy[j];
+                       if (ma->ploidy[j] == 1) {
+                               p[0] = pdg[0]; p[1] = pdg[2];
+                               _max++;
+                               if (_min == 0) k = 0, z[1][k] = (M0+1-k) * p[0] * z[0][k];
+                               for (k = _min < 1? 1 : _min; k <= _max; ++k)
+                                       z[1][k] = (M0+1-k) * p[0] * z[0][k] + k * p[1] * z[0][k-1];
+                               for (k = _min, sum = 0.; k <= _max; ++k) sum += z[1][k];
+                               ma->t += log(sum / M);
+                               for (k = _min; k <= _max; ++k) z[1][k] /= sum;
+                               if (_min >= 1) z[1][_min-1] = 0.;
+                               if (j < ma->n - 1) z[1][_max+1] = 0.;
+                       } else if (ma->ploidy[j] == 2) {
+                               p[0] = pdg[0]; p[1] = 2 * pdg[1]; p[2] = pdg[2];
+                               _max += 2;
+                               if (_min == 0) k = 0, z[1][k] = (M0-k+1) * (M0-k+2) * p[0] * z[0][k];
+                               if (_min <= 1) k = 1, z[1][k] = (M0-k+1) * (M0-k+2) * p[0] * z[0][k] + k*(M0-k+2) * p[1] * z[0][k-1];
+                               for (k = _min < 2? 2 : _min; k <= _max; ++k)
+                                       z[1][k] = (M0-k+1)*(M0-k+2) * p[0] * z[0][k] + k*(M0-k+2) * p[1] * z[0][k-1] + k*(k-1)* p[2] * z[0][k-2];
+                               for (k = _min, sum = 0.; k <= _max; ++k) sum += z[1][k];
+                               ma->t += log(sum / (M * (M - 1.)));
+                               for (k = _min; k <= _max; ++k) z[1][k] /= sum;
+                               if (_min >= 1) z[1][_min-1] = 0.;
+                               if (_min >= 2) z[1][_min-2] = 0.;
+                               if (j < ma->n - 1) z[1][_max+1] = z[1][_max+2] = 0.;
+                       }
+                       tmp = z[0]; z[0] = z[1]; z[1] = tmp;
+                       last_min = _min; last_max = _max;
                }
-               tmp = z[0]; z[0] = z[1]; z[1] = tmp;
-               last_min = _min; last_max = _max;
        }
        if (z[0] != ma->z) memcpy(ma->z, z[0], sizeof(double) * (ma->M + 1));
 }
 
 static void mc_cal_y(bcf_p1aux_t *ma)
 {
-       if (ma->n1 > 0 && ma->n1 < ma->n) {
+       if (ma->n1 > 0 && ma->n1 < ma->n && ma->M == ma->n * 2) { // NB: ma->n1 is ineffective when there are haploid samples
                int k;
                long double x;
                memset(ma->z1, 0, sizeof(double) * (2 * ma->n1 + 1));
@@ -337,32 +398,6 @@ static double mc_cal_afs(bcf_p1aux_t *ma, double *p_ref_folded, double *p_var_fo
        return sum / ma->M;
 }
 
-long double bcf_p1_cal_g3(bcf_p1aux_t *p1a, double g[3])
-{
-       long double pd = 0., g2[3];
-       int i, k;
-       memset(g2, 0, sizeof(long double) * 3);
-       for (k = 0; k < p1a->M; ++k) {
-               double f = (double)k / p1a->M, f3[3], g1[3];
-               long double z = 1.;
-               g1[0] = g1[1] = g1[2] = 0.;
-               f3[0] = (1. - f) * (1. - f); f3[1] = 2. * f * (1. - f); f3[2] = f * f;
-               for (i = 0; i < p1a->n; ++i) {
-                       double *pdg = p1a->pdg + i * 3;
-                       double x = pdg[0] * f3[0] + pdg[1] * f3[1] + pdg[2] * f3[2];
-                       z *= x;
-                       g1[0] += pdg[0] * f3[0] / x;
-                       g1[1] += pdg[1] * f3[1] / x;
-                       g1[2] += pdg[2] * f3[2] / x;
-               }
-               pd += p1a->phi[k] * z;
-               for (i = 0; i < 3; ++i)
-                       g2[i] += p1a->phi[k] * z * g1[i];
-       }
-       for (i = 0; i < 3; ++i) g[i] = g2[i] / pd;
-       return pd;
-}
-
 int bcf_p1_cal(const bcf1_t *b, bcf_p1aux_t *ma, bcf_p1rst_t *rst)
 {
        int i, k;