{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Deep Markov Model inferred using SteinVI\nIn this example we infer a deep Markov model (DMM) using SteinVI for generating music\n(chorales by Johan Sebastian Bach).\n\nThe model DMM based on reference [1][2] and the Pyro DMM example: https://pyro.ai/examples/dmm.html.\n\n**Reference:**\n\n    1. Pathwise Derivatives for Multivariate Distributions Martin Jankowiak and Theofanis Karaletsos (2019)\n    2. Structured Inference Networks for Nonlinear State Space Models [arXiv:1609.09869]\n        Rahul G. Krishnan, Uri Shalit and David Sontag (2016)\n\n<img src=\"file://../_static/img/examples/stein_dmm.png\" align=\"center\">\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n\nimport numpy as np\n\nimport jax\nfrom jax import nn, numpy as jnp, random\nfrom optax import adam, chain\n\nimport numpyro\nfrom numpyro.contrib.einstein import SteinVI\nfrom numpyro.contrib.einstein.mixture_guide_predictive import MixtureGuidePredictive\nfrom numpyro.contrib.einstein.stein_kernels import RBFKernel\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import JSB_CHORALES, load_dataset\nfrom numpyro.optim import optax_to_numpyro\n\n\ndef _reverse_padded(padded, lengths):\n    def _reverse_single(p, length):\n        new = jnp.zeros_like(p)\n        reverse = jnp.roll(p[::-1], length, axis=0)\n        return new.at[:].set(reverse)\n\n    return jax.vmap(_reverse_single)(padded, lengths)\n\n\ndef load_data(split=\"train\"):\n    _, fetch = load_dataset(JSB_CHORALES, split=split)\n    lengths, seqs = fetch(0)\n    return (seqs, _reverse_padded(seqs, lengths), lengths)\n\n\ndef emitter(x, params):\n    \"\"\"Parameterizes the bernoulli observation likelihood `p(x_t | z_t)`\"\"\"\n    l1 = nn.relu(jnp.matmul(x, params[\"l1\"]))\n    l2 = nn.relu(jnp.matmul(l1, params[\"l2\"]))\n    return jnp.matmul(l2, params[\"l3\"])\n\n\ndef transition(x, params):\n    \"\"\"Parameterizes the gaussian latent transition probability `p(z_t | z_{t-1})`\n    See section 5 in [1].\n\n    **Reference:**\n        1. Structured Inference Networks for Nonlinear State Space Models [arXiv:1609.09869]\n        Rahul G. Krishnan, Uri Shalit and David Sontag (2016)\n    \"\"\"\n\n    def _gate(x, params):\n        l1 = nn.relu(jnp.matmul(x, params[\"l1\"]))\n        return nn.sigmoid(jnp.matmul(l1, params[\"l2\"]))\n\n    def _shared(x, params):\n        l1 = nn.relu(jnp.matmul(x, params[\"l1\"]))\n        return jnp.matmul(l1, params[\"l2\"])\n\n    def _mean(x, params):\n        return jnp.matmul(x, params[\"l1\"])\n\n    def _std(x, params):\n        l1 = jnp.matmul(nn.relu(x), params[\"l1\"])\n        return nn.softplus(l1)\n\n    gt = _gate(x, params[\"gate\"])\n    ht = _shared(x, params[\"shared\"])\n    loc = (1 - gt) * _mean(x, params[\"mean\"]) + gt * ht\n    std = _std(ht, params[\"std\"])\n    return loc, std\n\n\ndef combiner(x, params):\n    mean = jnp.matmul(x, params[\"mean\"])\n    std = nn.softplus(jnp.matmul(x, params[\"std\"]))\n    return mean, std\n\n\ndef gru(xs, lengths, init_hidden, params):\n    \"\"\"RNN with GRU. Based on https://github.com/google/jax/pull/2298\"\"\"\n\n    def apply_fun_single(state, inputs):\n        i, x = inputs\n        inp_update = jnp.matmul(x, params[\"update_in\"])\n        hidden_update = jnp.dot(state, params[\"update_weight\"])\n        update_gate = nn.sigmoid(inp_update + hidden_update)\n        reset_gate = nn.sigmoid(\n            jnp.matmul(x, params[\"reset_in\"]) + jnp.dot(state, params[\"reset_weight\"])\n        )\n        output_gate = update_gate * state + (1 - update_gate) * jnp.tanh(\n            jnp.matmul(x, params[\"out_in\"])\n            + jnp.dot(reset_gate * state, params[\"out_weight\"])\n        )\n        hidden = jnp.where((i < lengths)[:, None], output_gate, jnp.zeros_like(state))\n        return hidden, hidden\n\n    init_hidden = jnp.broadcast_to(init_hidden, (xs.shape[1], init_hidden.shape[1]))\n    return jax.lax.scan(apply_fun_single, init_hidden, (jnp.arange(xs.shape[0]), xs))\n\n\ndef _normal_init(*shape):\n    return lambda rng_key: dist.Normal(scale=0.1).sample(rng_key, shape)\n\n\ndef model(\n    seqs,\n    seqs_rev,\n    lengths,\n    *,\n    subsample_size=77,\n    latent_dim=32,\n    emission_dim=100,\n    transition_dim=200,\n    data_dim=88,\n    gru_dim=150,\n    annealing_factor=1.0,\n    predict=False,\n):\n    max_seq_length = seqs.shape[1]\n\n    emitter_params = {\n        \"l1\": numpyro.param(\"emitter_l1\", _normal_init(latent_dim, emission_dim)),\n        \"l2\": numpyro.param(\"emitter_l2\", _normal_init(emission_dim, emission_dim)),\n        \"l3\": numpyro.param(\"emitter_l3\", _normal_init(emission_dim, data_dim)),\n    }\n\n    trans_params = {\n        \"gate\": {\n            \"l1\": numpyro.param(\"gate_l1\", _normal_init(latent_dim, transition_dim)),\n            \"l2\": numpyro.param(\"gate_l2\", _normal_init(transition_dim, latent_dim)),\n        },\n        \"shared\": {\n            \"l1\": numpyro.param(\"shared_l1\", _normal_init(latent_dim, transition_dim)),\n            \"l2\": numpyro.param(\"shared_l2\", _normal_init(transition_dim, latent_dim)),\n        },\n        \"mean\": {\"l1\": numpyro.param(\"mean_l1\", _normal_init(latent_dim, latent_dim))},\n        \"std\": {\"l1\": numpyro.param(\"std_l1\", _normal_init(latent_dim, latent_dim))},\n    }\n\n    z0 = numpyro.param(\n        \"z0\", lambda rng_key: dist.Normal(0, 1.0).sample(rng_key, (latent_dim,))\n    )\n    z0 = jnp.broadcast_to(z0, (subsample_size, 1, latent_dim))\n    with numpyro.plate(\n        \"data\", seqs.shape[0], subsample_size=subsample_size, dim=-1\n    ) as idx:\n        if subsample_size == seqs.shape[0]:\n            seqs_batch = seqs\n            lengths_batch = lengths\n        else:\n            seqs_batch = seqs[idx]\n            lengths_batch = lengths[idx]\n\n        masks = jnp.repeat(\n            jnp.expand_dims(jnp.arange(max_seq_length), axis=0), subsample_size, axis=0\n        ) < jnp.expand_dims(lengths_batch, axis=-1)\n        # NB: Mask is to avoid scoring 'z' using distribution at this point\n        z = numpyro.sample(\n            \"z\",\n            dist.Normal(0.0, jnp.ones((max_seq_length, latent_dim)))\n            .mask(False)\n            .to_event(2),\n        )\n\n        z_shift = jnp.concatenate([z0, z[:, :-1, :]], axis=-2)\n        z_loc, z_scale = transition(z_shift, params=trans_params)\n\n        with numpyro.handlers.scale(scale=annealing_factor):\n            # Actually score 'z'\n            numpyro.sample(\n                \"z_aux\",\n                dist.Normal(z_loc, z_scale)\n                .mask(jnp.expand_dims(masks, axis=-1))\n                .to_event(2),\n                obs=z,\n            )\n\n        emission_probs = emitter(z, params=emitter_params)\n        if predict:\n            tunes = None\n        else:\n            tunes = seqs_batch\n        numpyro.sample(\n            \"tunes\",\n            dist.Bernoulli(logits=emission_probs)\n            .mask(jnp.expand_dims(masks, axis=-1))\n            .to_event(2),\n            obs=tunes,\n        )\n\n\ndef guide(\n    seqs,\n    seqs_rev,\n    lengths,\n    *,\n    subsample_size=77,\n    latent_dim=32,\n    emission_dim=100,\n    transition_dim=200,\n    data_dim=88,\n    gru_dim=150,\n    annealing_factor=1.0,\n    predict=False,\n):\n    max_seq_length = seqs.shape[1]\n    seqs_rev = jnp.transpose(seqs_rev, axes=(1, 0, 2))\n\n    combiner_params = {\n        \"mean\": numpyro.param(\"combiner_mean\", _normal_init(gru_dim, latent_dim)),\n        \"std\": numpyro.param(\"combiner_std\", _normal_init(gru_dim, latent_dim)),\n    }\n\n    gru_params = {\n        \"update_in\": numpyro.param(\"update_in\", _normal_init(data_dim, gru_dim)),\n        \"update_weight\": numpyro.param(\"update_weight\", _normal_init(gru_dim, gru_dim)),\n        \"reset_in\": numpyro.param(\"reset_in\", _normal_init(data_dim, gru_dim)),\n        \"reset_weight\": numpyro.param(\"reset_weight\", _normal_init(gru_dim, gru_dim)),\n        \"out_in\": numpyro.param(\"out_in\", _normal_init(data_dim, gru_dim)),\n        \"out_weight\": numpyro.param(\"out_weight\", _normal_init(gru_dim, gru_dim)),\n    }\n\n    with numpyro.plate(\n        \"data\", seqs.shape[0], subsample_size=subsample_size, dim=-1\n    ) as idx:\n        if subsample_size == seqs.shape[0]:\n            seqs_rev_batch = seqs_rev\n            lengths_batch = lengths\n        else:\n            seqs_rev_batch = seqs_rev[:, idx, :]\n            lengths_batch = lengths[idx]\n\n        masks = jnp.repeat(\n            jnp.expand_dims(jnp.arange(max_seq_length), axis=0), subsample_size, axis=0\n        ) < jnp.expand_dims(lengths_batch, axis=-1)\n\n        h0 = numpyro.param(\n            \"h0\",\n            lambda rng_key: dist.Normal(0.0, 1).sample(rng_key, (1, gru_dim)),\n        )\n        _, hs = gru(seqs_rev_batch, lengths_batch, h0, gru_params)\n        hs = _reverse_padded(jnp.transpose(hs, axes=(1, 0, 2)), lengths_batch)\n        with numpyro.handlers.scale(scale=annealing_factor):\n            numpyro.sample(\n                \"z\",\n                dist.Normal(*combiner(hs, combiner_params))\n                .mask(jnp.expand_dims(masks, axis=-1))\n                .to_event(2),\n            )\n\n\ndef vis_tune(i, tunes, lengths, name=\"stein_dmm.pdf\"):\n    tune = tunes[i, : lengths[i]]\n    try:\n        from music21.chord import Chord\n        from music21.pitch import Pitch\n        from music21.stream import Stream\n\n        stream = Stream()\n        for chord in tune:\n            stream.append(\n                Chord(list(Pitch(pitch) for pitch in (np.arange(88) + 21)[chord > 0]))\n            )\n        plot = stream.plot(doneAction=None)\n        plot.write(name)\n    except ModuleNotFoundError:\n        import matplotlib.pyplot as plt\n\n        plt.imshow(tune.T, cmap=\"Greys\")\n        plt.ylabel(\"Pitch\")\n        plt.xlabel(\"Offset\")\n        plt.savefig(name)\n\n\ndef main(args):\n    inf_key, pred_key = random.split(random.PRNGKey(seed=args.rng_seed), 2)\n\n    steinvi = SteinVI(\n        model,\n        guide,\n        optax_to_numpyro(chain(adam(1e-2))),\n        RBFKernel(),\n        num_elbo_particles=args.num_elbo_particles,\n        num_stein_particles=args.num_stein_particles,\n    )\n\n    seqs, rev_seqs, lengths = load_data()\n    results = steinvi.run(\n        inf_key,\n        args.max_iter,\n        seqs,\n        rev_seqs,\n        lengths,\n        gru_dim=args.gru_dim,\n        subsample_size=args.subsample_size,\n    )\n    pred = MixtureGuidePredictive(\n        model,\n        guide,\n        params=results.params,\n        num_samples=1,\n        guide_sites=steinvi.guide_sites,\n    )\n    seqs, rev_seqs, lengths = load_data(\"valid\")\n    pred_notes = pred(\n        pred_key, seqs, rev_seqs, lengths, subsample_size=seqs.shape[0], predict=True\n    )[\"tunes\"]\n\n    vis_tune(0, pred_notes[0], lengths)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--subsample-size\", type=int, default=10)\n    parser.add_argument(\"--max-iter\", type=int, default=100)\n    parser.add_argument(\"--repulsion\", type=float, default=1.0)\n    parser.add_argument(\"--verbose\", type=bool, default=True)\n    parser.add_argument(\"--num-stein-particles\", type=int, default=5)\n    parser.add_argument(\"--num-elbo-particles\", type=int, default=5)\n    parser.add_argument(\"--progress-bar\", type=bool, default=True)\n    parser.add_argument(\"--gru-dim\", type=int, default=150)\n    parser.add_argument(\"--rng-key\", type=int, default=142)\n    parser.add_argument(\"--device\", default=\"cpu\", choices=[\"gpu\", \"cpu\"])\n    parser.add_argument(\"--rng-seed\", default=142, type=int)\n\n    args = parser.parse_args()\n\n    numpyro.set_platform(args.device)\n\n    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}