{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sbibm.metrics.c2st import c2st\n",
    "import os\n",
    "import re\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle as pkl\n",
    "import torch\n",
    "from torch import distributions as dist\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import jax.random as random\n",
    "from rsnl.examples.contaminated_normal import (get_prior, assumed_dgp,\n",
    "                                               calculate_summary_statistics,\n",
    "                                               true_dgp, true_posterior)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_reference_samples(x_obs):\n",
    "    prior = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([10.0]))\n",
    "\n",
    "    # Conjugate prior with known variance\n",
    "    n_obs = 100\n",
    "    true_dgp_var = 1.0\n",
    "    prior_var = prior.variance\n",
    "    obs_mean = x_obs[0]\n",
    "\n",
    "    true_post_var = (1/prior_var + n_obs/true_dgp_var) ** -1\n",
    "    true_post_mu = (true_post_var *\n",
    "                    (prior.mean/prior_var + ((obs_mean * n_obs) / true_dgp_var)))\n",
    "    true_post_std = torch.sqrt(true_post_var)\n",
    "    true_posterior =  dist.Normal(true_post_mu, true_post_std)\n",
    "    reference_samples = true_posterior.sample((40000,))\n",
    "    return reference_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_obs = torch.tensor([1.0, 2.0])\n",
    "get_reference_samples(x_obs).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "directory = '../res/contaminated_normal/rsnl/seed_0/'\n",
    "thetas = pkl.load(open(f'{directory}thetas.pkl', 'rb'))\n",
    "thetas.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_x_obs(seed):\n",
    "    rng_key = random.PRNGKey(seed)\n",
    "    rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)\n",
    "    sim_fn = assumed_dgp\n",
    "    sum_fn = calculate_summary_statistics\n",
    "    true_params = jnp.array([1.0])\n",
    "    # true_params = prior.sample(sub_key1)\n",
    "    x_obs_tmp = true_dgp(sub_key2, true_params)\n",
    "    x_obs = jnp.array(calculate_summary_statistics(x_obs_tmp))\n",
    "    # Convert JAX array to numpy array\n",
    "    numpy_x_obs = np.array(x_obs)\n",
    "    numpy_x_obs[1] = 2.0\n",
    "\n",
    "    # Convert numpy array to PyTorch tensor\n",
    "    x_obs = torch.tensor(numpy_x_obs)\n",
    "    return x_obs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_c2st_res(directory=\"\"):\n",
    "    sub_dirs = [x[0] for x in os.walk(directory)]\n",
    "    sub_dirs = sub_dirs[1:]\n",
    "    c2st_res = []\n",
    "    for ii, sub_dir in enumerate(sub_dirs):\n",
    "        print(f\"ii: {ii}, sub_dir: {sub_dir}\")\n",
    "        try:\n",
    "            with open(f'{sub_dir}/thetas.pkl', 'rb') as f:\n",
    "                thetas = pkl.load(f)\n",
    "                thetas = np.array(thetas)\n",
    "                thetas = np.concatenate(thetas, axis=0)\n",
    "                base_name = os.path.basename(sub_dir)\n",
    "                match = re.search(r'seed_(\\d+)', base_name)\n",
    "                seed = int(match.group(1))\n",
    "                print('seed ', seed)\n",
    "                x_obs = get_x_obs(seed=seed)\n",
    "                print('x_obs ', x_obs)\n",
    "                algorithm_samples = torch.as_tensor(thetas).reshape(-1, 1)\n",
    "                reference_samples = get_reference_samples(x_obs).reshape(-1, 1)\n",
    "                c2st_ii = c2st(algorithm_samples, reference_samples)\n",
    "                print(f\"c2st_ii: {c2st_ii}\")\n",
    "                c2st_res.append(c2st_ii)\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            print(f\"Error with {sub_dir}\")\n",
    "\n",
    "    return c2st_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "directory = \"../res/contaminated_normal/rsnl/\"\n",
    "c2st_rsnl = get_c2st_res(directory=directory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c2st_rsnl = np.concatenate(c2st_rsnl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.boxplot(c2st_rsnl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "directory = \"../res/contaminated_normal/snl/\"\n",
    "c2st_snl = get_c2st_res(directory=directory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c2st_snl = np.concatenate(c2st_snl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "df = pd.DataFrame()\n",
    "df['C2ST'] = np.concatenate([c2st_rsnl, c2st_snl])\n",
    "df['Method'] = ['RSNL'] * len(c2st_rsnl) + ['SNL'] * len(c2st_snl)\n",
    "sns.set(font_scale=3.5, font='Times New Roman', style='white')\n",
    "        # xlabel='', ylabel='Log density')\n",
    "plt.ylim(0.5, 1)\n",
    "plt.yticks([0.5, 0.75, 1.0])\n",
    "# fig, ax = plt.subplots()\n",
    "# ax = sns.boxplot(x='method', y='logprob', data=df, showfliers = False, ax=axs[i])\n",
    "\n",
    "sns.boxplot(x='Method', y='C2ST', data=df)\n",
    "plt.xlabel(\"\")\n",
    "plt.savefig('c2st.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
