{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from analysis_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Method labels for the baseline comparison, in plotting order. They are passed\n",
    "# to get_mean_stds_matrix and the plot helpers (presumably as lookup keys), so\n",
    "# keep the spelling exactly as is.\n",
    "names = [\n",
    "    'Prior',\n",
    "    'SBI',\n",
    "    'NPE',\n",
    "    'OT-only(single sample)',\n",
    "    'OT-only(full test)',\n",
    "    'finetune-only',\n",
    "    'RoPE(single sample)',\n",
    "    'RoPE (full test)',\n",
    "    'ours',\n",
    "]\n",
    "\n",
    "# x-axis for the first section: calibration-set sizes.\n",
    "x = np.array([10, 50, 200, 1000])\n",
    "x_label = 'num of calibration samples'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance across different sizes of calibration set (vs. baselines)\n",
    "\n",
    "Replace the placeholder wandb runs in `list_of_runs` with the paths of your own runs. Each path should have the form \"wandb_user/project_name/run_id\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One wandb run per calibration-set size, presumably in the same order as `x`.\n",
    "# NOTE: every entry needs a trailing comma -- without one, Python silently\n",
    "# concatenates adjacent string literals into a single (nonexistent) run path.\n",
    "list_of_runs = [\n",
    "    'run_w_10_samples',\n",
    "    'run_w_50_samples',\n",
    "    'run_w_200_samples',\n",
    "    'run_w_1000_samples',\n",
    "]\n",
    "assert len(list_of_runs) == len(x), 'expected one run per calibration-set size'\n",
    "results = import_results(list_of_runs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prior LPP values were precomputed per dataset: pendulum -2.88,\n",
    "# light tunnel -16.635, wind tunnel -3.8. The pendulum value is used here.\n",
    "prior_lpp = -2.88\n",
    "metric_name = 'lpp'\n",
    "means, stds = get_mean_stds_matrix(names, metric_name, x, results, prior_lpp)\n",
    "# NOTE(review): prior_value=-3 differs from prior_lpp; confirm which the plot should show.\n",
    "plot_metric_vs_x_piecewise(\n",
    "    means, stds, names, x, 'LPP', x_label,\n",
    "    prior_value=-3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prior AC-AUC reference value (0 where the prior matches).\n",
    "prior_acauc = 0\n",
    "means, stds = get_mean_stds_matrix(names, 'acauc', x, results, prior_acauc)\n",
    "plot_metric_vs_x_linear_y(\n",
    "    means, stds, names, x, 'AC-AUC', x_label,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Noise experiment\n",
    "\n",
    "Performance as a function of the noise rate (0, 1 and 10), analysed separately for each size of calibration set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Noise section: the x-axis is now the noise rate, so the axis label must change\n",
    "# too; otherwise the plots below keep the calibration-size label from above.\n",
    "x = np.array([0, 1, 10])\n",
    "x_label = 'noise rate'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One wandb run per noise rate, in the same order as `x` (passed along below).\n",
    "list_of_runs = [\n",
    "    'noise_rate_0',\n",
    "    'noise_rate_1',\n",
    "    'noise_rate_10',\n",
    "]\n",
    "results = import_results(list_of_runs, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LPP vs. noise rate. is_log=False presumably keeps the x-axis linear\n",
    "# (a log scale could not show the zero noise rate) -- confirm in analysis_utils.\n",
    "means, stds = get_mean_stds_matrix(names, 'lpp', x, results, -2.88)\n",
    "plot_metric_vs_x_piecewise(\n",
    "    means, stds, names, x, 'LPP', x_label,\n",
    "    prior_value=-5, is_log=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AC-AUC vs. noise rate; the prior reference value is 0.\n",
    "means, stds = get_mean_stds_matrix(names, 'acauc', x, results, 0)\n",
    "plot_metric_vs_x_linear_y(\n",
    "    means, stds, names, x, 'AC-AUC', x_label,\n",
    "    is_log=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ablation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean_stds_matrix_ablation(names, metric_name, x, results, prior):\n",
    "    \"\"\"Build matrices of metric means and stds for each ablation variant.\n",
    "\n",
    "    Args:\n",
    "        names: ablation variant labels. The three known labels are mapped to\n",
    "            their result keys below; any other label is used as the key as-is.\n",
    "        metric_name: metric to read, e.g. 'lpp' or 'acauc'.\n",
    "        x: array of x-axis values (e.g. calibration-set sizes); each value is\n",
    "            looked up in `results` as str(value).\n",
    "        results: indexed as results[str(value)][metric_name]['mean'|'std'][key].\n",
    "        prior: unused; kept so the signature mirrors get_mean_stds_matrix.\n",
    "\n",
    "    Returns:\n",
    "        (means, stds): two float arrays of shape (len(names), x.shape[0]).\n",
    "    \"\"\"\n",
    "    # Result-key template for each ablation variant, formatted with metric_name.\n",
    "    key_templates = {\n",
    "        'amortised solution only': 'test_NF_align_{}',\n",
    "        'joint training only': 'test_WassOT_NF_align_{}_wass',\n",
    "        'full pipeline': 'test_WassOT_NF_align_{}',\n",
    "    }\n",
    "    means = np.zeros([len(names), x.shape[0]])\n",
    "    stds = np.zeros([len(names), x.shape[0]])\n",
    "    for i, name in enumerate(names):\n",
    "        # Resolve the key once per variant instead of rebinding the loop variable.\n",
    "        template = key_templates.get(name)\n",
    "        key = template.format(metric_name) if template is not None else name\n",
    "        for j, num_samples in enumerate(x):\n",
    "            per_size = results[str(num_samples)][metric_name]\n",
    "            means[i, j] = per_size['mean'][key]\n",
    "            stds[i, j] = per_size['std'][key]\n",
    "    return means, stds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation variants; get_mean_stds_matrix_ablation maps these labels to result keys.\n",
    "names = ['amortised solution only', 'joint training only', 'full pipeline']\n",
    "x = np.array([10, 50, 200, 1000])\n",
    "x_label = 'num of calibration samples'\n",
    "\n",
    "# One wandb run per calibration-set size, presumably in the same order as `x`.\n",
    "# Trailing commas matter: adjacent string literals without one are silently\n",
    "# concatenated into a single bogus run path.\n",
    "list_of_runs = [\n",
    "    'run_w_10_samples',\n",
    "    'run_w_50_samples',\n",
    "    'run_w_200_samples',\n",
    "    'run_w_1000_samples',\n",
    "]\n",
    "assert len(list_of_runs) == len(x), 'expected one run per calibration-set size'\n",
    "\n",
    "results = import_results(list_of_runs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LPP for each ablation variant vs. calibration-set size.\n",
    "means, stds = get_mean_stds_matrix_ablation(names, 'lpp', x, results, -2.88)\n",
    "plot_metric_vs_x_piecewise(\n",
    "    means, stds, names, x, 'LPP', x_label,\n",
    "    prior_value=-5, is_log=True, is_legend=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AC-AUC for each ablation variant vs. calibration-set size.\n",
    "means, stds = get_mean_stds_matrix_ablation(names, 'acauc', x, results, 0)\n",
    "plot_metric_vs_x_linear_y(\n",
    "    means, stds, names, x, 'AC-AUC', x_label,\n",
    "    is_log=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
