{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32b5295f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e07ace",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "print(f\"Old working dir {os.getcwd()}\")\n",
    "os.chdir('../')\n",
    "print(f\"New working dir {os.getcwd()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e58cbf3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import norm, multivariate_normal\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "706e6ef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pushforward_operators import NeuralQuantileRegression, AmortizedNeuralQuantileRegression\n",
    "from conformal.real_datasets.reproducible_split import get_dataset_split\n",
    "from conformal.classes.method_desc import ConformalMethodDescription\n",
    "from conformal.wrappers.cvq_regressor import CVQRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64a1c683",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pushforward_operators.protocol import PushForwardOperator\n",
    "from utils.quantile import get_quantile_level_analytically\n",
    "\n",
    "\n",
    "def plot_quantile_levels_from_model(\n",
    "    ax: matplotlib.axes.Axes,\n",
    "    model: PushForwardOperator,\n",
    "    conditional_value: torch.Tensor,\n",
    "    number_of_quantile_levels: int,\n",
    "    tensor_parameters: dict = {},\n",
    "):\n",
    "\n",
    "    quantile_levels = torch.linspace(0.05, 0.95, number_of_quantile_levels)\n",
    "    radii = get_quantile_level_analytically(\n",
    "        quantile_levels, distribution=\"gaussian\", dimension=2\n",
    "    )\n",
    "\n",
    "    X_batch = conditional_value.repeat(1000, 1).to(**tensor_parameters)\n",
    "    list_of_approximated_Y_quantiles = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for _, contour_radius in enumerate(radii):\n",
    "            pi = torch.linspace(-torch.pi, torch.pi, 1000)\n",
    "\n",
    "            ground_truth_U_quantiles = (\n",
    "                torch.stack(\n",
    "                    [\n",
    "                        contour_radius * torch.cos(pi),\n",
    "                        contour_radius * torch.sin(pi),\n",
    "                    ]\n",
    "                ).T\n",
    "            ).to(**tensor_parameters)\n",
    "\n",
    "\n",
    "            try:\n",
    "                approximated_Y_quantiels = model.push_u_given_x(\n",
    "                    u=ground_truth_U_quantiles, x=X_batch\n",
    "                )\n",
    "            except NotImplementedError:\n",
    "                approximated_Y_quantiels = None\n",
    "\n",
    "            list_of_approximated_Y_quantiles.append(approximated_Y_quantiels)\n",
    "\n",
    "    color_map = matplotlib.colormaps['viridis']\n",
    "\n",
    "    for i, approximated_Y_quantiels in enumerate(list_of_approximated_Y_quantiles):\n",
    "        color = color_map(i / number_of_quantile_levels)\n",
    "        label = f'Quantile level {quantile_levels[i]:.2f}'\n",
    "\n",
    "        if approximated_Y_quantiels is not None:\n",
    "            ax.plot(\n",
    "                approximated_Y_quantiels[:, 0],\n",
    "                approximated_Y_quantiels[:, 1],\n",
    "                color=color,\n",
    "                linewidth=2.5,\n",
    "                label=label\n",
    "            )\n",
    "\n",
    "    ax.legend(bbox_to_anchor=(1.1, 1.05))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "443e7af3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from conformal.experiment import _tuned_configs\n",
    "_params = _tuned_configs[\"bio\"]\n",
    "_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79994479",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = get_dataset_split(name=\"bio\", seed=0)\n",
    "reg = CVQRegressor(feature_dimension=ds.n_features, response_dimension=ds.n_outputs, **_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99d88cce",
   "metadata": {},
   "outputs": [],
   "source": [
    "reg.model = AmortizedNeuralQuantileRegression.load_class(\"./conformal_results/bio/0/model_cvqr.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbdf6440",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred_test = reg.predict_mean(ds.X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d551586",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_grid = 1000\n",
    "Y0_grig, Y1_grid = np.meshgrid(np.linspace(-4, 4, n_grid), np.linspace(-4, 4, n_grid))\n",
    "print(f\"{Y0_grig.shape=}, {Y1_grid.shape=}\")\n",
    "\n",
    "#Y0_grig = Y0_grig.reshape(-1, 1)\n",
    "#Y1_grid = Y1_grid.reshape(-1, 1)\n",
    "\n",
    "Y_grid = np.concatenate((Y0_grig.reshape(-1, 1), Y1_grid.reshape(-1, 1)), axis=1)\n",
    "Y_grid.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06acf03c",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_test = 911\n",
    "X0 = ds.X_test[[idx_test]]\n",
    "X0_grid = np.repeat(X0, Y_grid.shape[0], axis=0)\n",
    "X0_grid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a238dd5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = reg.calculate_scores(X0_grid, Y_grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14b5ba0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "u_sample = np.random.normal(size=(Y_grid.shape[0], 2), loc=0, scale=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c798f88a",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=(16, 6), sharex=True, sharey=True)\n",
    "plt.suptitle(f\"Density at X_test[{idx_test}]\")\n",
    "c0 = ax[0].contourf(Y0_grig, Y1_grid, np.exp(scores[\"Log Density\"]).reshape(n_grid, n_grid), cmap=\"coolwarm\")\n",
    "plt.colorbar(c0)\n",
    "ax[0].set_title('Density estimate using Hessian')\n",
    "ax[0].set_xlabel('Y0')\n",
    "ax[0].set_ylabel('Y1')\n",
    "\n",
    "ax[0].set_xlim(-4, 4)\n",
    "ax[0].set_ylim(-4, 4)\n",
    "ax[0].set_aspect(\"equal\", adjustable='box')\n",
    "\n",
    "plot_quantile_levels_from_model(\n",
    "    ax[1],\n",
    "    reg.model,\n",
    "    torch.tensor(ds.X_test[idx_test], dtype=torch.float32),\n",
    "    number_of_quantile_levels=10\n",
    ")\n",
    "ax[1].set_title('Quantile levels of U pushed to Y')\n",
    "ax[1].set_xlabel('Y0')\n",
    "ax[1].set_ylabel('Y1')\n",
    "\n",
    "ax[1].set_xlim(-4, 4)\n",
    "ax[1].set_ylim(-4, 4)\n",
    "ax[1].set_aspect(\"equal\", adjustable='box')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e9de02",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.collections import LineCollection\n",
    "\n",
    "\n",
    "def plot_grid(x,y, ax=None, **kwargs):\n",
    "    ax = ax or plt.gca()\n",
    "    segs1 = np.stack((x,y), axis=2)\n",
    "    segs2 = segs1.transpose(1,0,2)\n",
    "    ax.add_collection(LineCollection(segs1, **kwargs))\n",
    "    ax.add_collection(LineCollection(segs2, **kwargs))\n",
    "    ax.autoscale()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conditional_quantile_function",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
