{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Sine-skewed sine (bivariate von Mises) mixture\n\nThis example models the dihedral angles that occur in the backbone of a protein as a mixture of skewed\ndirectional distributions. The backbone angle pairs, called $\\phi$ and $\\psi$, are a canonical\nrepresentation for the fold of a protein. In this model, we fix the third dihedral angle ($\\omega$), as it usually only\ntakes the values 0 and $\\pi$ radians, with the latter being the most common. We model the angle pairs as a distribution on\nthe torus using the sine distribution [1] and break point-wise (toroidal) symmetry using sine-skewing [2].\n\n<img src=\"../_static/img/examples/ssbvm_mixture_torus_top.png\" align=\"center\" scale=\"30%\">\n\n**References:**\n\n1. Singh et al. (2002). Probabilistic model for two dependent circular variables. Biometrika.\n2. Jose Ameijeiras-Alonso and Christophe Ley (2021). Sine-skewed toroidal distributions and their application\n   in protein bioinformatics. Biostatistics.\n\n<img src=\"../_static/img/examples/ssbvm_mixture.png\" align=\"center\" scale=\"125%\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import math\n",
        "from math import pi\n",
        "import sys\n",
        "\n",
        "import matplotlib.colors\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from sklearn.cluster import KMeans\n",
        "\n",
        "from jax import numpy as jnp, random\n",
        "\n",
        "import numpyro\n",
        "from numpyro.distributions import (\n",
        "    Beta,\n",
        "    Categorical,\n",
        "    Dirichlet,\n",
        "    Gamma,\n",
        "    Normal,\n",
        "    SineBivariateVonMises,\n",
        "    SineSkewed,\n",
        "    Uniform,\n",
        "    VonMises,\n",
        ")\n",
        "from numpyro.distributions.transforms import L1BallTransform\n",
        "from numpyro.examples.datasets import NINE_MERS, load_dataset\n",
        "from numpyro.infer import MCMC, NUTS, Predictive, init_to_value\n",
        "from numpyro.infer.reparam import CircularReparam\n",
        "\n",
        "AMINO_ACIDS = [\n",
        "    \"M\",\n",
        "    \"N\",\n",
        "    \"I\",\n",
        "    \"F\",\n",
        "    \"E\",\n",
        "    \"L\",\n",
        "    \"R\",\n",
        "    \"D\",\n",
        "    \"G\",\n",
        "    \"K\",\n",
        "    \"Y\",\n",
        "    \"T\",\n",
        "    \"H\",\n",
        "    \"S\",\n",
        "    \"P\",\n",
        "    \"A\",\n",
        "    \"V\",\n",
        "    \"Q\",\n",
        "    \"W\",\n",
        "    \"C\",\n",
        "]\n",
        "\n",
        "# Full amino-acid names, used to label the columns of the Ramachandran plot.\n",
        "AMINO_ACID_NAMES = {\n",
        "    \"A\": \"Alanine\",\n",
        "    \"C\": \"Cysteine\",\n",
        "    \"D\": \"Aspartic acid\",\n",
        "    \"E\": \"Glutamic acid\",\n",
        "    \"F\": \"Phenylalanine\",\n",
        "    \"G\": \"Glycine\",\n",
        "    \"H\": \"Histidine\",\n",
        "    \"I\": \"Isoleucine\",\n",
        "    \"K\": \"Lysine\",\n",
        "    \"L\": \"Leucine\",\n",
        "    \"M\": \"Methionine\",\n",
        "    \"N\": \"Asparagine\",\n",
        "    \"P\": \"Proline\",\n",
        "    \"Q\": \"Glutamine\",\n",
        "    \"R\": \"Arginine\",\n",
        "    \"S\": \"Serine\",\n",
        "    \"T\": \"Threonine\",\n",
        "    \"V\": \"Valine\",\n",
        "    \"W\": \"Tryptophan\",\n",
        "    \"Y\": \"Tyrosine\",\n",
        "}\n",
        "\n",
        "\n",
        "# The support of the von Mises is [-\u03c0,\u03c0) with a periodic boundary at \u00b1\u03c0. However, the support of\n",
        "# the implemented von Mises distribution is just the interval [-\u03c0,\u03c0) without the periodic boundary. If the\n",
        "# loc is close to one of the boundaries (-\u03c0 or \u03c0), the sampler must traverse the entire interval to cross the\n",
        "# boundary. This produces a bias, especially if the concentration is high. The interval around\n",
        "# zero will have a low probability, making the jump to the other boundary unlikely for the sampler.\n",
        "# Using the `CircularReparam` introduces the periodic boundary by transforming the real line to [-\u03c0,\u03c0).\n",
        "# The sampler can sample from the real line, thus crossing the periodic boundary without having to traverse\n",
        "# the entire interval, which eliminates the bias.\n",
        "@numpyro.handlers.reparam(\n",
        "    config={\"phi_loc\": CircularReparam(), \"psi_loc\": CircularReparam()}\n",
        ")\n",
        "def ss_model(data, num_data, num_mix_comp=2):\n",
        "    \"\"\"Mixture of sine-skewed sine bivariate von Mises distributions over (phi, psi).\n",
        "\n",
        "    :param data: observed angle pairs, or ``None`` to simulate from the model;\n",
        "        presumably an array of shape (num_data, 2) in radians -- TODO confirm.\n",
        "    :param num_data: number of observations, i.e. the size of ``obs_plate``.\n",
        "    :param num_mix_comp: number of mixture components.\n",
        "    :return: the value of the ``phi_psi`` sample site.\n",
        "    \"\"\"\n",
        "    # Mixture prior\n",
        "    mix_weights = numpyro.sample(\"mix_weights\", Dirichlet(jnp.ones((num_mix_comp,))))\n",
        "\n",
        "    # Hyperpriors for the Beta priors on the BvM concentrations, parameterized by a\n",
        "    # mean in (0, 1) and a pseudo-count: alpha = mean * count, beta = count - alpha.\n",
        "    # Bayesian Inference and Decision Theory by Kathryn Blackmond Laskey\n",
        "    beta_mean_phi = numpyro.sample(\"beta_mean_phi\", Uniform(0.0, 1.0))\n",
        "    beta_count_phi = numpyro.sample(\n",
        "        \"beta_count_phi\", Gamma(1.0, 1.0 / num_mix_comp)\n",
        "    )  # shape, rate\n",
        "    halpha_phi = beta_mean_phi * beta_count_phi\n",
        "    beta_mean_psi = numpyro.sample(\"beta_mean_psi\", Uniform(0, 1.0))\n",
        "    beta_count_psi = numpyro.sample(\n",
        "        \"beta_count_psi\", Gamma(1.0, 1.0 / num_mix_comp)\n",
        "    )  # shape, rate\n",
        "    halpha_psi = beta_mean_psi * beta_count_psi\n",
        "\n",
        "    with numpyro.plate(\"mixture\", num_mix_comp):\n",
        "        # BvM priors\n",
        "\n",
        "        # Place gap in forbidden region of the Ramachandran plot (protein backbone dihedral angle pairs)\n",
        "        phi_loc = numpyro.sample(\"phi_loc\", VonMises(pi, 2.0))\n",
        "        psi_loc = numpyro.sample(\"psi_loc\", VonMises(0.0, 0.1))\n",
        "\n",
        "        phi_conc = numpyro.sample(\n",
        "            \"phi_conc\", Beta(halpha_phi, beta_count_phi - halpha_phi)\n",
        "        )\n",
        "        psi_conc = numpyro.sample(\n",
        "            \"psi_conc\", Beta(halpha_psi, beta_count_psi - halpha_psi)\n",
        "        )\n",
        "        corr_scale = numpyro.sample(\"corr_scale\", Beta(2.0, 10.0))\n",
        "\n",
        "        # Skewness prior, mapped into the L1 ball (SineSkewed's skewness constraint).\n",
        "        ball_transform = L1BallTransform()\n",
        "        skewness = numpyro.sample(\"skewness\", Normal(0, 0.5).expand((2,)).to_event(1))\n",
        "        skewness = ball_transform(skewness)\n",
        "\n",
        "    with numpyro.plate(\"obs_plate\", num_data, dim=-1):\n",
        "        assign = numpyro.sample(\n",
        "            \"mix_comp\", Categorical(mix_weights), infer={\"enumerate\": \"parallel\"}\n",
        "        )\n",
        "        sine = SineBivariateVonMises(\n",
        "            phi_loc=phi_loc[assign],\n",
        "            psi_loc=psi_loc[assign],\n",
        "            # These concentrations are an order of magnitude lower than expected (550-1000)!\n",
        "            phi_concentration=70 * phi_conc[assign],\n",
        "            psi_concentration=70 * psi_conc[assign],\n",
        "            weighted_correlation=corr_scale[assign],\n",
        "        )\n",
        "        return numpyro.sample(\"phi_psi\", SineSkewed(sine, skewness[assign]), obs=data)\n",
        "\n",
        "\n",
        "def run_hmc(rng_key, model, data, num_mix_comp, args, bvm_init_locs):\n",
        "    \"\"\"Run NUTS on ``model`` for ``data`` and return the posterior samples.\n",
        "\n",
        "    ``bvm_init_locs`` maps sample-site names to initial values (via ``init_to_value``);\n",
        "    ``args`` supplies ``num_samples`` and ``num_warmup``.\n",
        "    \"\"\"\n",
        "    kernel = NUTS(\n",
        "        model, init_strategy=init_to_value(values=bvm_init_locs), max_tree_depth=7\n",
        "    )\n",
        "    mcmc = MCMC(kernel, num_samples=args.num_samples, num_warmup=args.num_warmup)\n",
        "    mcmc.run(rng_key, data, len(data), num_mix_comp)\n",
        "    mcmc.print_summary()\n",
        "    return mcmc.get_samples()\n",
        "\n",
        "\n",
        "def fetch_aa_dihedrals(aa):\n",
        "    \"\"\"Load the NINE_MERS dihedral-angle split for amino acid ``aa`` as one array.\"\"\"\n",
        "    _, fetch = load_dataset(NINE_MERS, split=aa)\n",
        "    return jnp.stack(fetch())\n",
        "\n",
        "\n",
        "def num_mix_comps(amino_acid):\n",
        "    \"\"\"Number of mixture components: 10 for G, 7 for P and 9 for any other amino acid.\"\"\"\n",
        "    num_mix = {\"G\": 10, \"P\": 7}\n",
        "    return num_mix.get(amino_acid, 9)\n",
        "\n",
        "\n",
        "def _set_pi_ticks(axis):\n",
        "    \"\"\"Tick ``axis`` at multiples of pi/2 (major, labelled) and pi/12 (minor).\"\"\"\n",
        "    axis.set_major_locator(plt.MultipleLocator(np.pi / 2))\n",
        "    axis.set_minor_locator(plt.MultipleLocator(np.pi / 12))\n",
        "    axis.set_major_formatter(plt.FuncFormatter(multiple_formatter()))\n",
        "\n",
        "\n",
        "def ramachandran_plot(data, pred_data, aas, file_name=\"ssbvm_mixture.pdf\"):\n",
        "    \"\"\"Hexbin Ramachandran plots: observed data (top row) vs. simulations (bottom row).\n",
        "\n",
        "    :param data: dict mapping an amino-acid letter to its observed (phi, psi) pairs.\n",
        "    :param pred_data: dict mapping an amino-acid letter to simulated (phi, psi) pairs.\n",
        "    :param aas: amino-acid letters to plot, one column each.\n",
        "    :param file_name: where to save the figure; pass a falsy value to skip saving.\n",
        "    \"\"\"\n",
        "    # squeeze=False keeps ``axss`` two-dimensional even for a single amino acid.\n",
        "    fig, axss = plt.subplots(2, len(aas), squeeze=False)\n",
        "    for i, cdata in enumerate((data, pred_data)):\n",
        "        for ax, aa in zip(axss[i], aas):\n",
        "            aa_data = cdata[aa]\n",
        "            # NOTE(review): with an integer ``bins`` hexbin colours cells by count-bin\n",
        "            # index rather than raw count; confirm this is intended alongside LogNorm.\n",
        "            nbins = 50\n",
        "            ax.hexbin(\n",
        "                aa_data[..., 0].reshape(-1),\n",
        "                aa_data[..., 1].reshape(-1),\n",
        "                norm=matplotlib.colors.LogNorm(),\n",
        "                bins=nbins,\n",
        "                gridsize=100,\n",
        "                cmap=\"Blues\",\n",
        "            )\n",
        "\n",
        "            # Square panel covering [-pi, pi] in both angles, ticked in fractions of pi.\n",
        "            ax.set_aspect(\"equal\", \"box\")\n",
        "            ax.set_xlim([-math.pi, math.pi])\n",
        "            ax.set_ylim([-math.pi, math.pi])\n",
        "            _set_pi_ticks(ax.xaxis)\n",
        "            _set_pi_ticks(ax.yaxis)\n",
        "            if i == 0:\n",
        "                # The top row names each column's amino acid on a secondary x-axis.\n",
        "                axtop = ax.secondary_xaxis(\"top\")\n",
        "                axtop.set_xlabel(AMINO_ACID_NAMES.get(aa, aa))\n",
        "                _set_pi_ticks(axtop.xaxis)\n",
        "            else:\n",
        "                ax.set_xlabel(r\"$\\phi$\")\n",
        "\n",
        "    for i, row_label in enumerate((\"data\", \"simulation\")):\n",
        "        axss[i, 0].set_ylabel(r\"$\\psi$\")\n",
        "        axright = axss[i, -1].secondary_yaxis(\"right\")\n",
        "        axright.set_ylabel(row_label)\n",
        "        _set_pi_ticks(axright.yaxis)\n",
        "\n",
        "    # Only the left column keeps y tick labels and only the bottom row keeps x tick labels.\n",
        "    for ax in axss[:, 1:].reshape(-1):\n",
        "        ax.tick_params(labelleft=False)\n",
        "    for ax in axss[0, :].reshape(-1):\n",
        "        ax.tick_params(labelbottom=False)\n",
        "\n",
        "    if file_name:\n",
        "        fig.tight_layout()\n",
        "        fig.savefig(file_name, bbox_inches=\"tight\")\n",
        "\n",
        "\n",
        "def multiple_formatter(denominator=2, number=np.pi, latex=r\"\\pi\"):\n",
        "    \"\"\"Return a tick formatter that labels values as reduced fractions of ``number``.\n",
        "\n",
        "    Each tick value is rounded to the nearest multiple of ``number / denominator`` and\n",
        "    rendered as LaTeX, using ``latex`` as the symbol for ``number``.\n",
        "    \"\"\"\n",
        "\n",
        "    def _multiple_formatter(x, pos):\n",
        "        num = int(np.rint(denominator * x / number))\n",
        "        # Reduce num/denominator; math.gcd is non-negative, so the sign stays on num.\n",
        "        com = math.gcd(num, denominator)\n",
        "        num, den = num // com, denominator // com\n",
        "        if den == 1:\n",
        "            if num == 0:\n",
        "                return r\"$0$\"\n",
        "            if num == 1:\n",
        "                return r\"$%s$\" % latex\n",
        "            if num == -1:\n",
        "                return r\"$-%s$\" % latex\n",
        "            return r\"$%s%s$\" % (num, latex)\n",
        "        if num == 1:\n",
        "            return r\"$\\frac{%s}{%s}$\" % (latex, den)\n",
        "        if num == -1:\n",
        "            return r\"$\\frac{-%s}{%s}$\" % (latex, den)\n",
        "        return r\"$\\frac{%s%s}{%s}$\" % (num, latex, den)\n",
        "\n",
        "    return _multiple_formatter\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Fit the mixture per amino acid with NUTS, simulate from it and plot both.\"\"\"\n",
        "    data = {}\n",
        "    pred_datas = {}\n",
        "    rng_key = random.PRNGKey(args.rng_seed)\n",
        "    for aa in args.amino_acids:\n",
        "        rng_key, inf_key, pred_key = random.split(rng_key, 3)\n",
        "        data[aa] = fetch_aa_dihedrals(aa)\n",
        "        num_mix_comp = num_mix_comps(aa)\n",
        "\n",
        "        # Use kmeans to initialize the chain location; seeding it keeps runs reproducible.\n",
        "        kmeans = KMeans(num_mix_comp, n_init=10, random_state=args.rng_seed)\n",
        "        kmeans.fit(np.asarray(data[aa]))\n",
        "        means = {\n",
        "            \"phi_loc\": jnp.asarray(kmeans.cluster_centers_[:, 0]),\n",
        "            \"psi_loc\": jnp.asarray(kmeans.cluster_centers_[:, 1]),\n",
        "        }\n",
        "        # CircularReparam samples ``<name>_unwrapped`` in place of ``<name>`` (which becomes\n",
        "        # deterministic), so the initial values must also be keyed by the unwrapped names.\n",
        "        init_locs = dict(means)\n",
        "        for name, loc in means.items():\n",
        "            init_locs[f\"{name}_unwrapped\"] = loc\n",
        "\n",
        "        posterior_samples = run_hmc(\n",
        "            inf_key, ss_model, data[aa], num_mix_comp, args, init_locs\n",
        "        )\n",
        "        predictive = Predictive(ss_model, posterior_samples, parallel=True)\n",
        "\n",
        "        pred_datas[aa] = predictive(pred_key, None, 1, num_mix_comp)[\"phi_psi\"].reshape(\n",
        "            -1, 2\n",
        "        )\n",
        "\n",
        "    ramachandran_plot(data, pred_datas, args.amino_acids)\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    parser = argparse.ArgumentParser(\n",
        "        description=\"Sine-skewed sine (bivariate von mises) mixture model example\"\n",
        "    )\n",
        "    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n",
        "    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=500, type=int)\n",
        "    parser.add_argument(\"--amino-acids\", nargs=\"+\", default=[\"S\", \"P\", \"G\"])\n",
        "    parser.add_argument(\"--rng_seed\", type=int, default=123)\n",
        "    # NOTE(review): --device is parsed but never applied (no numpyro.set_platform call).\n",
        "    parser.add_argument(\"--device\", default=\"gpu\", type=str, help='use \"cpu\" or \"gpu\".')\n",
        "\n",
        "    # Inside a Jupyter kernel sys.argv holds the kernel's own flags (e.g. ``-f <file>``),\n",
        "    # which this parser would reject, so fall back to the defaults there.\n",
        "    args = parser.parse_args([] if \"ipykernel\" in sys.modules else None)\n",
        "    invalid = [aa for aa in args.amino_acids if aa not in AMINO_ACIDS]\n",
        "    if invalid:\n",
        "        # parser.error rather than assert, so validation also runs under ``python -O``.\n",
        "        parser.error(f\"{invalid} are not amino acids.\")\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}