{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6a50d078-8c7a-4c40-923b-8da300d8e93f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax, jax.numpy as jnp\n",
    "from jax.tree_util import tree_leaves\n",
    "import jax.tree_util as jtu\n",
    "\n",
    "import argparse\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import math\n",
    "import os\n",
    "import dataclasses\n",
    "from jax import lax\n",
    "plt.rcParams['font.size'] = 20\n",
    "\n",
    "jax.config.update('jax_platform_name', 'cpu') # switching to 'gpu' allows to scale better w.r.t the number of particles\n",
    "jax.config.update('jax_enable_x64', True)\n",
    "#jax.config.update(\"jax_disable_jit\", False)\n",
    "num_fits = 1 # only showing few replicates\n",
    "from src.stats.hmm import get_generative_model,HMM\n",
    "from src.variational import get_variational_model\n",
    "from src.utils.misc import tree_get_idx, tree_dropfirst, tree_droplast\n",
    "from src.training_lg import SVITrainer\n",
    "\n",
    "def set_defaults(args, default_std=0.1, default_emission_std=0.1):\n",
    "    \"\"\"Fill `args` in place with default model hyper-parameters and return it.\n",
    "\n",
    "    Args:\n",
    "        args: attribute container (e.g. an `argparse.Namespace`) to populate.\n",
    "        default_std: value stored as both the prior and the transition base scale.\n",
    "        default_emission_std: value stored as the emission base scale; the\n",
    "            default 0.1 is the value that used to be hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        The same `args` object, mutated in place.\n",
    "    \"\"\"\n",
    "    args.default_prior_mean = 0.0  # default value for the mean of the Gaussian prior\n",
    "    # default value for the diagonal components of the covariance matrix of the prior\n",
    "    args.default_prior_base_scale = default_std\n",
    "    # default value for the diagonal components of the covariance matrix of the transition kernel\n",
    "    args.default_transition_base_scale = default_std\n",
    "    args.default_transition_bias = 0.0\n",
    "    # base scale used for the emission (not the transition kernel)\n",
    "    args.default_emission_base_scale = default_emission_std\n",
    "    args.parametrization = 'cov_chol'\n",
    "    return args\n",
    "\n",
    "\n",
    "def parameter_mse(true_params, guess_params):\n",
    "    \"\"\"Return the unweighted mean of the per-leaf MSEs between two pytrees.\n",
    "\n",
    "    Leaves are paired in flattening order. Each pair contributes one mean\n",
    "    squared error, and those per-leaf values are averaged without weighting\n",
    "    by leaf size.\n",
    "\n",
    "    Args:\n",
    "        true_params: pytree of reference parameters.\n",
    "        guess_params: pytree of estimated parameters.\n",
    "\n",
    "    Returns:\n",
    "        Scalar array holding the averaged MSE.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the two pytrees flatten to different numbers of leaves.\n",
    "    \"\"\"\n",
    "    true_leaves = tree_leaves(true_params)\n",
    "    guess_leaves = tree_leaves(guess_params)\n",
    "    if len(true_leaves) != len(guess_leaves):\n",
    "        # zip() below would silently ignore the unmatched leaves and report an\n",
    "        # MSE computed on only part of the parameters.\n",
    "        raise ValueError(\n",
    "            f\"parameter trees have different numbers of leaves: \"\n",
    "            f\"{len(true_leaves)} vs {len(guess_leaves)}\"\n",
    "        )\n",
    "    # Compute the MSE for each pair of leaves and then average them\n",
    "    mse_list = [jnp.mean((t - g) ** 2) for t, g in zip(true_leaves, guess_leaves)]\n",
    "    return jnp.mean(jnp.array(mse_list))\n",
    "\n",
    "\n",
    "def plot_x_true_against_x_pred(x_true, x_pred, y=None, save=False):\n",
    "    \"\"\"Plot true and predicted trajectories, one subplot per column.\n",
    "\n",
    "    Args:\n",
    "        x_true: 2-D array; its last axis sets the number of subplots.\n",
    "        x_pred: 2-D array drawn over `x_true`, indexed the same way.\n",
    "        y: optional 2-D array of observations, overlaid column by column.\n",
    "        save: when True, the figure is written to 'test.pdf'.\n",
    "    \"\"\"\n",
    "    dims = x_true.shape[-1]\n",
    "    # squeeze=False keeps `axes` a 2-D array even when dims == 1; the default\n",
    "    # returns a bare Axes in that case, which cannot be indexed.\n",
    "    fig, axes = plt.subplots(dims, 1, figsize=(15, 2 * dims), squeeze=False)\n",
    "    for dim, ax in enumerate(axes[:, 0]):\n",
    "        ax.plot(x_true[:, dim], c='red', label='True', alpha=0.7)\n",
    "        ax.plot(x_pred[:, dim], c='green', label='Pred', alpha=0.7)\n",
    "        if y is not None:\n",
    "            ax.plot(y[:, dim], c='black', label='Obs', alpha=0.5)\n",
    "        # Build the legend last so the optional 'Obs' curve is listed too.\n",
    "        ax.legend()\n",
    "    if save:\n",
    "        fig.savefig('test.pdf', format='pdf')\n",
    "    \n",
    "def plot_data(y):\n",
    "    \"\"\"Plot each column of `y` in its own subplot.\n",
    "\n",
    "    Args:\n",
    "        y: 2-D array; its last axis sets the number of subplots.\n",
    "    \"\"\"\n",
    "    dims = y.shape[-1]\n",
    "    # squeeze=False keeps `axes` a 2-D array even when dims == 1; the default\n",
    "    # returns a bare Axes in that case, which cannot be indexed.\n",
    "    _, axes = plt.subplots(dims, 1, figsize=(15, 2 * dims), squeeze=False)\n",
    "    for dim, ax in enumerate(axes[:, 0]):\n",
    "        ax.plot(y[:, dim], label='Data')\n",
    "\n",
    "def compute_rmse_x_true_against_x_pred(x_true, x_pred):\n",
    "    \"\"\"Mean over the leading axes of the RMSE taken along the last axis.\"\"\"\n",
    "    residual = x_true - x_pred\n",
    "    # Root of the mean squared residual along the last axis ...\n",
    "    per_step_rmse = jnp.sqrt(jnp.mean(residual ** 2, axis=-1))\n",
    "    # ... then averaged over everything that is left.\n",
    "    return jnp.mean(per_step_rmse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95a52049-2884-4083-9be8-04cc3214ed55",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num params: 60\n",
      "Monitor ELBO is analytical.\n",
      "['prior', 'transition.noise', 'emission.noise']\n",
      "USING SCORE ELBO.\n",
      "Using full gradients.\n",
      "Streaming on a single sequence only once.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running for 5 iterations: 100%|██████████| 5/5 [00:00<00:00, 243.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Per-run files saved in: results/20260129_111223  (run_XX.npy with columns [F_mse, G_mse, ELBO])\n",
      "Per-run summary: summary.tsv\n",
      "Aggregate files: aggregate_so_far.npz (after each run) + aggregate.npz (final)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "/Users/matmul/anaconda3/envs/ovi_env/lib/python3.10/site-packages/numpy/_core/_methods.py:223: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
      "  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
      "/Users/matmul/anaconda3/envs/ovi_env/lib/python3.10/site-packages/numpy/_core/_methods.py:212: RuntimeWarning: invalid value encountered in divide\n",
      "  ret = um.true_divide(\n"
     ]
    }
   ],
   "source": [
    "import argparse, time\n",
    "from pathlib import Path\n",
    "\n",
    "import jax, jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "# --- your imports ---\n",
    "from src.stats.hmm import get_generative_model\n",
    "from src.variational import get_variational_model\n",
    "from src.training import SVITrainer\n",
    "\n",
    "# ---------- config ----------\n",
    "# Number of replicates and the base seeds they are offset from.\n",
    "NUM_RUNS = 1\n",
    "BASE_SEED_MODEL = 20000\n",
    "BASE_SEED_TRAIN = 30000\n",
    "\n",
    "# Option strings and optimiser settings passed to SVITrainer further down.\n",
    "ELBO_MODE = \"score,resampling,bptt_depth_4\"\n",
    "TRAINING_MODE = \"streaming,1,difference\"\n",
    "NUM_PARTICLES = 10\n",
    "LR = 1e-2\n",
    "LR_MODEL = 1e-3\n",
    "\n",
    "\n",
    "def set_defaults(args, default_std=0.1):\n",
    "    \"\"\"Populate `args` in place with this cell's default hyper-parameters.\n",
    "\n",
    "    `default_std` is stored as the prior and transition base scales; the\n",
    "    emission base scale is fixed at 0.25. Returns the same `args` object.\n",
    "    \"\"\"\n",
    "    defaults = {\n",
    "        \"default_prior_mean\": 0.0,\n",
    "        \"default_prior_base_scale\": default_std,\n",
    "        \"default_transition_base_scale\": default_std,\n",
    "        \"default_transition_bias\": 0.0,\n",
    "        \"default_emission_base_scale\": 0.25,\n",
    "        \"parametrization\": \"cov_chol\",\n",
    "    }\n",
    "    for name, value in defaults.items():\n",
    "        setattr(args, name, value)\n",
    "    return args\n",
    "\n",
    "# Arguments of the generative model: explicit fields first, then the\n",
    "# defaults filled in by set_defaults.\n",
    "p_args = set_defaults(\n",
    "    argparse.Namespace(\n",
    "        state_dim=10,\n",
    "        obs_dim=10,\n",
    "        model=\"linear\",\n",
    "        seq_length=5,\n",
    "        emission_bias=False,\n",
    "        transition_bias=False,\n",
    "        transition_matrix_conditionning=\"diagonal\",\n",
    "        emission_matrix_conditionning=\"diagonal\",\n",
    "        range_transition_map_params=(0.5, 1.0),\n",
    "        range_emission_map_params=(0.5, 1.0),\n",
    "        num_seqs=1,\n",
    "    ),\n",
    "    default_std=0.1,\n",
    ")\n",
    "\n",
    "def make_q_args_from_p(p_args, default_noise=0.1):\n",
    "    \"\"\"Build the arguments for the variational model from `p_args`.\n",
    "\n",
    "    Copies the fields listed below onto a fresh namespace, then lets\n",
    "    set_defaults fill in the rest using `default_noise`.\n",
    "    \"\"\"\n",
    "    copied_fields = (\n",
    "        \"state_dim\",\n",
    "        \"obs_dim\",\n",
    "        \"model\",\n",
    "        \"emission_bias\",\n",
    "        \"transition_bias\",\n",
    "        \"transition_matrix_conditionning\",\n",
    "        \"emission_matrix_conditionning\",\n",
    "        \"range_transition_map_params\",\n",
    "        \"range_emission_map_params\",\n",
    "    )\n",
    "    q_args = argparse.Namespace()\n",
    "    for name in copied_fields:\n",
    "        setattr(q_args, name, getattr(p_args, name))\n",
    "    return set_defaults(q_args, default_noise)\n",
    "\n",
    "# ---------- output dir ----------\n",
    "# One timestamped folder per execution of this cell.\n",
    "outdir = Path(\"results\", time.strftime(\"%Y%m%d_%H%M%S\"))\n",
    "outdir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Start the per-run summary table with its header row.\n",
    "summary_path = outdir / \"summary.tsv\"\n",
    "with summary_path.open(\"w\") as f:\n",
    "    f.write(\"run\\tseed_model\\tseed_train\\tT\\tnum_particles\\tlr\\tlr_model\\tdim\\ttraining_mode\\tfile\\n\")\n",
    "\n",
    "# Run i uses seeds BASE_SEED_MODEL + i and BASE_SEED_TRAIN + i.\n",
    "seeds_model = BASE_SEED_MODEL + np.arange(NUM_RUNS, dtype=int)\n",
    "seeds_train = BASE_SEED_TRAIN + np.arange(NUM_RUNS, dtype=int)\n",
    "\n",
    "# Per-run curves, appended to by the run loop.\n",
    "F_list = []\n",
    "G_list = []\n",
    "E_list = []\n",
    "\n",
    "def write_partial_aggregate(n_done: int):\n",
    "    \"\"\"Save mean/std over the first `n_done` runs to aggregate_so_far.npz.\n",
    "\n",
    "    Reads the module-level F_list / G_list / E_list, seeds and `outdir`.\n",
    "    Does nothing when `n_done` is 0; with a single run the std entries are\n",
    "    zeros.\n",
    "    \"\"\"\n",
    "    if n_done == 0:\n",
    "        return\n",
    "\n",
    "    def _std_over_runs(stacked):\n",
    "        # Sample std needs at least two runs; otherwise store zeros.\n",
    "        if n_done > 1:\n",
    "            return stacked.std(0, ddof=1)\n",
    "        return np.zeros_like(stacked[0])\n",
    "\n",
    "    F, G, E = (\n",
    "        np.stack(per_run[:n_done], axis=0) for per_run in (F_list, G_list, E_list)\n",
    "    )\n",
    "    np.savez(\n",
    "        outdir / \"aggregate_so_far.npz\",\n",
    "        F_mean=F.mean(0),\n",
    "        F_std=_std_over_runs(F),\n",
    "        G_mean=G.mean(0),\n",
    "        G_std=_std_over_runs(G),\n",
    "        ELBO_mean=E.mean(0),\n",
    "        ELBO_std=_std_over_runs(E),\n",
    "        seeds_model=seeds_model[:n_done],\n",
    "        seeds_train=seeds_train[:n_done],\n",
    "        T=np.array(F.shape[1], dtype=int),\n",
    "        R=np.array(n_done, dtype=int),\n",
    "    )\n",
    "\n",
    "for run in range(NUM_RUNS):\n",
    "    seed_m, seed_t = int(seeds_model[run]), int(seeds_train[run])\n",
    "\n",
    "    # -- generative model and sampled sequences, both driven by seed_m --\n",
    "    key = jax.random.PRNGKey(seed_m)\n",
    "    key, key_theta, key_seq = jax.random.split(key, 3)\n",
    "    p, theta_true, theta_init = get_generative_model(p_args, key_theta)\n",
    "    xs, ys = p.sample_multiple_sequences(\n",
    "        key_seq,\n",
    "        theta_true,\n",
    "        num_seqs=p_args.num_seqs,\n",
    "        seq_length=p_args.seq_length,\n",
    "        single_split_seq=False,\n",
    "        load_from='',\n",
    "    )\n",
    "\n",
    "    # -- variational model and trainer --\n",
    "    q_args = make_q_args_from_p(p_args, default_noise=0.1)\n",
    "    q = get_variational_model(q_args)\n",
    "    trainer = SVITrainer(\n",
    "        p=p,\n",
    "        theta_true=theta_true,\n",
    "        theta_star=theta_init,\n",
    "        q=q,\n",
    "        optimizer=\"adam\",\n",
    "        learning_rate=LR,\n",
    "        learning_rate_model=LR_MODEL,\n",
    "        optim_options=\"cst\",\n",
    "        num_epochs=1,\n",
    "        seq_length=p_args.seq_length,\n",
    "        num_samples=NUM_PARTICLES,\n",
    "        frozen_params=[\"prior\", \"transition.noise\", \"emission.noise\"],\n",
    "        num_seqs=p_args.num_seqs,\n",
    "        training_mode=TRAINING_MODE,\n",
    "        elbo_mode=ELBO_MODE,\n",
    "    )\n",
    "\n",
    "    # -- training, driven by seed_t --\n",
    "    key_train = jax.random.PRNGKey(seed_t)\n",
    "    key_init_params, key_montecarlo = jax.random.split(key_train, 2)\n",
    "    fit_outputs = trainer.fit(key_init_params, key_montecarlo, data=ys, args=None)\n",
    "    phi_params, theta_params, elbos, aux, phi_hist, theta_hist, mse_vals = fit_outputs\n",
    "\n",
    "    # -- one column per curve: [F_mse, G_mse, ELBO] --\n",
    "    elbos = elbos.flatten()\n",
    "    E = np.asarray(elbos).ravel()\n",
    "    F_mse, G_mse = (np.asarray(vals.squeeze()) for vals in (mse_vals[0], mse_vals[1]))\n",
    "    run_arr = np.column_stack((F_mse, G_mse, E))\n",
    "\n",
    "    run_path = outdir / f\"run_{run:02d}.npy\"\n",
    "    np.save(run_path, run_arr)\n",
    "\n",
    "    # -- append this run to the summary table --\n",
    "    with open(summary_path, \"a\") as f:\n",
    "        f.write(f\"{run}\\t{seed_m}\\t{seed_t}\\t{run_arr.shape[0]}\\t{NUM_PARTICLES}\\t{LR}\\t{LR_MODEL}\\t{p_args.state_dim}x{p_args.obs_dim}\\t{TRAINING_MODE}\\t{run_path.name}\\n\")\n",
    "\n",
    "    # -- keep the curves and refresh the running aggregate --\n",
    "    for per_run_curves, curve in zip((F_list, G_list, E_list), run_arr.T):\n",
    "        per_run_curves.append(curve)\n",
    "    write_partial_aggregate(run + 1)\n",
    "\n",
    "\n",
    "# ---------- final aggregate over all runs ----------\n",
    "F = np.stack(F_list, axis=0)\n",
    "G = np.stack(G_list, axis=0)\n",
    "E = np.stack(E_list, axis=0)\n",
    "\n",
    "\n",
    "def _std_over_runs(curves):\n",
    "    \"\"\"Sample std across runs (axis 0); zeros when there is only one run.\"\"\"\n",
    "    # With a single run, ddof=1 divides by zero: NumPy emits RuntimeWarnings\n",
    "    # and stores NaN. Zeros match what aggregate_so_far.npz stores in that case.\n",
    "    if curves.shape[0] > 1:\n",
    "        return curves.std(0, ddof=1)\n",
    "    return np.zeros_like(curves[0])\n",
    "\n",
    "\n",
    "np.savez(\n",
    "    outdir / \"aggregate.npz\",\n",
    "    F_mean=F.mean(0), F_std=_std_over_runs(F),\n",
    "    G_mean=G.mean(0), G_std=_std_over_runs(G),\n",
    "    ELBO_mean=E.mean(0), ELBO_std=_std_over_runs(E),\n",
    "    seeds_model=seeds_model, seeds_train=seeds_train,\n",
    "    T=np.array(F.shape[1], dtype=int),\n",
    "    R=np.array(F.shape[0], dtype=int),\n",
    ")\n",
    "\n",
    "print(f\"Per-run files saved in: {outdir}  (run_XX.npy with columns [F_mse, G_mse, ELBO])\")\n",
    "print(f\"Per-run summary: {summary_path.name}\")\n",
    "print(\"Aggregate files: aggregate_so_far.npz (after each run) + aggregate.npz (final)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b88f98ff-e19a-4b34-8abc-d6c0f9cd206a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0240c03-dee0-40f0-b1db-a85ef675ab30",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37398420-a10c-40f7-b8b4-8fadcf8e7fe1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9c0df27-5499-4e22-ac6e-888f1e859c8e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b59599b7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84626cde",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "788a094c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ovi_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
