{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd75247f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b74836c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move one directory up from the notebook's location so the relative paths used below\n",
    "# (the cached model file, `./conformal_results_u/...`) resolve from there.\n",
    "# NOTE(review): not idempotent — running this cell twice moves up twice.\n",
    "import os\n",
    "print(f\"Old working dir {os.getcwd()}\")\n",
    "os.chdir('../')\n",
    "print(f\"New working dir {os.getcwd()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f2e86d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import norm, multivariate_normal\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "%matplotlib inline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9978485",
   "metadata": {},
   "outputs": [],
   "source": [
    "from conformal.real_datasets.reproducible_split import get_dataset_split\n",
    "from conformal.classes.method_desc import ConformalMethodDescription\n",
    "from conformal.score_calculators import CVQRegressor, CVQRegressorRF, CVQRegressorY, CVQRegressorYRF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff883e1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic 1-D regression problem: y = x^1.6 + N(0, scale^2) on an even grid over [0, 1].\n",
    "alpha = 0.3     # miscoverage level; the bands below are central (1 - alpha) intervals\n",
    "scale = 0.1     # standard deviation of the additive Gaussian noise\n",
    "n = 10000       # number of grid points / samples\n",
    "rng = np.random.default_rng(31337)\n",
    "\n",
    "x = np.linspace(0, 1, n)\n",
    "y_true = np.power(x, 1.6)\n",
    "noise = rng.normal(scale=scale, size=n)\n",
    "y = y_true + noise\n",
    "\n",
    "# Central (1 - alpha) interval of the noise, i.e. the exact band offsets around f(x).\n",
    "interval_1a = norm.interval(1 - alpha, loc=0, scale=scale)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab69680",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.plot(x, y, label=r\"Data $y=f(x)+\\epsilon$\", alpha=0.5)\n",
    "# True regression function together with its exact (1 - alpha) noise band.\n",
    "coverage_pct = int((1 - alpha) * 100)\n",
    "ax.plot(x, y_true, \"g\", label=rf'$y=f(x) + {coverage_pct}\\%$')\n",
    "ax.fill_between(x, y_true + interval_1a[0], y_true + interval_1a[1], color=\"g\", alpha=0.3)\n",
    "ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5066fe55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CVQRegressor for the 1-D toy data: one feature (x) and one response (y).\n",
    "reg = CVQRegressor(\n",
    "    feature_dimension=1,\n",
    "    response_dimension=1,\n",
    "    hidden_dimension=8,\n",
    "    number_of_hidden_layers=4,\n",
    "    batch_size=512,\n",
    "    n_epochs=150,\n",
    "    learning_rate=0.01,\n",
    "    dtype=torch.float32,\n",
    "    # NOTE(review): presumably optimizer (Adam-style) beta coefficients — confirm in CVQRegressor.\n",
    "    betas=(0.5, 0.5),\n",
    "    weight_decay=1e-4,\n",
    "    warmup_iterations=5\n",
    ")\n",
    "# Cache file for the trained weights, relative to the current working directory;\n",
    "# the next cell loads it instead of retraining when it exists.\n",
    "fn_model = \"cvqregressor_for_1d_check_hpd3.pth\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f44ee3c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse cached weights when available; otherwise train once and cache them.\n",
    "# Both branches leave the network in eval mode, so later predictions do not depend on\n",
    "# whether the model was just trained or loaded from disk.\n",
    "if os.path.isfile(fn_model):\n",
    "    reg.model.load(fn_model)\n",
    "    reg.model.eval()\n",
    "else:\n",
    "    reg.fit(x.reshape(-1, 1), y.reshape(-1, 1))\n",
    "    reg.model.eval()\n",
    "    reg.model.save(fn_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9761f36c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#reg.model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7693d2eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_col = x.reshape(-1, 1)\n",
    "y_pred = reg.predict_mean(x_col)\n",
    "# Standard-normal endpoints of the central (1 - alpha) interval (70% for alpha = 0.3).\n",
    "# NOTE(review): assumes predict_inverse_quantile takes levels in N(0, 1) reference space,\n",
    "# consistent with the N(0, 1) comparison of predict_quantile output further below.\n",
    "z_interval = norm.interval(1 - alpha, loc=0, scale=1)\n",
    "y_pred_low = reg.predict_inverse_quantile(x_col, np.full((n, 1), z_interval[0]))\n",
    "y_pred_high = reg.predict_inverse_quantile(x_col, np.full((n, 1), z_interval[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2974e20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: both band edges should have one row per grid point.\n",
    "y_pred_low.shape, y_pred_high.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12b6ae11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first few lower band edges.\n",
    "y_pred_low[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5892c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "coverage_pct = int((1 - alpha) * 100)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(x, y, label=r\"Data $y=f(x)+\\epsilon$\", alpha=0.5)\n",
    "ax.plot(x, y_true, \"g\", label=rf'$y=f(x) + {coverage_pct}\\%$')\n",
    "ax.plot(x, y_pred, \"r\", label=rf'$y=\\hat{{f}}(x) + {coverage_pct}\\%$')\n",
    "# Green: exact (1 - alpha) noise band around f(x); red: the model's (1 - alpha) band.\n",
    "ax.fill_between(x, y_true + interval_1a[0], y_true + interval_1a[1], color=\"g\", alpha=0.3)\n",
    "ax.fill_between(x, y_pred_low[:, 0], y_pred_high[:, 0], color=\"r\", alpha=0.3)\n",
    "ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d629d06a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quantiles of the observed (x, y) pairs under the fitted model; the next cell compares\n",
    "# their histogram against the N(0, 1) density.\n",
    "quantiles = reg.predict_quantile(x.reshape(-1, 1), y.reshape(-1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "772e5752",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A well-calibrated model should produce approximately standard-normal quantiles.\n",
    "t = np.linspace(-4, 4, 1000)\n",
    "ax = sns.histplot(quantiles, kde=True, stat=\"density\")\n",
    "ax.plot(t, norm.pdf(t), \"k--\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71ef5aaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimated conditional density of y at the single point x = x_point, evaluated on the\n",
    "# y-grid `t` defined in the histogram cell above.\n",
    "x_point = 0.9\n",
    "# One copy of x_point per grid value, so the batch size always matches len(t).\n",
    "x0 = np.full((t.shape[0], 1), x_point)\n",
    "# (Name kept as-is: the plotting cell below reads `scores_smaples`.)\n",
    "scores_smaples = reg.calculate_scores(x0, t.reshape(-1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bc702fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# exp(log density) is the density itself, so label the curve accordingly.\n",
    "plt.plot(t, np.exp(scores_smaples[\"Log Density\"]), label=\"Density Estimate\")\n",
    "# True conditional density N(x0^1.6, scale^2), matching how y was generated.\n",
    "plt.plot(t, norm.pdf(t, loc=x0[0, 0] ** 1.6, scale=scale), \"k--\", label=\"True Density\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e189fd9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Run from the repository checkout so the relative result paths below resolve.\n",
    "# Built from the real home directory: a hand-filled '/home/{username}/...' string needs an\n",
    "# f-prefix to expand and breaks when the username is left empty.\n",
    "repo_root = os.path.join(os.path.expanduser(\"~\"), \"repos\", \"conditional_quantile_function\")\n",
    "os.chdir(repo_root)\n",
    "\n",
    "from argparse import Namespace\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "from scipy.stats import norm, multivariate_normal\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib inline\n",
    "\n",
    "from pushforward_operators import AmortizedNeuralQuantileRegression\n",
    "from conformal.real_datasets.reproducible_split import get_dataset_split\n",
    "from conformal.score_calculators import CVQRegressorRF, CVQRegressorY, CVQRegressor, CPFlowRegressor\n",
    "from conformal.classes.conformalizers import QuantileEstimatePredictor, SplitConformalPredictor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03a3a110",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility check: load the same split twice with the same seed; the next cell\n",
    "# compares the two. NOTE(review): names say `sgemm` but the dataset loaded is `rf1`.\n",
    "ds_sgemm_a = get_dataset_split(\"rf1\", seed=1239)\n",
    "ds_sgemm_b = get_dataset_split(\"rf1\", seed=1239)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87b41304",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare whole arrays, not a single entry, to confirm the split is deterministic.\n",
    "for split_attr in (\"X_cal\", \"Y_cal\", \"X_test\", \"Y_test\"):\n",
    "    np.testing.assert_array_equal(\n",
    "        getattr(ds_sgemm_a, split_attr), getattr(ds_sgemm_b, split_attr), err_msg=split_attr\n",
    "    )\n",
    "ds_sgemm_a.X_cal[0, 0], ds_sgemm_b.X_cal[0, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e1ca7c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multiple response dimensions: load (or train) the CVQRegressor for one dataset/seed.\n",
    "seed = 0\n",
    "dataset = \"scm20d\"\n",
    "ds = get_dataset_split(dataset, seed=seed)\n",
    "args = Namespace(dataset=dataset, seed=seed, n_cpus=8)\n",
    "results_path = Path(f\"./conformal_results_u/{dataset}/{seed}\")\n",
    "\n",
    "# Disabled alternatives: CPFlowRegressor / CVQRegressorY via\n",
    "# `.create_or_load(path=results_path, args=args, dataset_split=ds)` followed by `.model.eval()`.\n",
    "model_u = CVQRegressor.create_or_load(path=results_path, args=args, dataset_split=ds)\n",
    "model_u.model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ece8626",
   "metadata": {},
   "outputs": [],
   "source": [
    "#model_u.model.init_dict, model_y.model.init_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b13ace1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#scores_cpflow_cal = model_cpflow.calculate_scores(ds.X_cal, ds.Y_cal)\n",
    "#scores_cpflow_test = model_cpflow.calculate_scores(ds.X_test, ds.Y_test)\n",
    "\n",
    "# Scores for the calibration and test splits; later cells index them by the keys\n",
    "# `MK Quantile`, `MK Rank` and `Log Density`.\n",
    "scores_u_cal = model_u.calculate_scores(ds.X_cal, ds.Y_cal)\n",
    "scores_u_test = model_u.calculate_scores(ds.X_test, ds.Y_test)\n",
    "\n",
    "#scores_y_cal = model_y.calculate_scores(ds.X_cal, ds.Y_cal)\n",
    "#scores_y_test = model_y.calculate_scores(ds.X_test, ds.Y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52a84800",
   "metadata": {},
   "outputs": [],
   "source": [
    "#scores_u_cal, scores_y_cal, scores_cpflow_cal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e9f1049",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the test responses back through the fitted map with push_y_given_x.\n",
    "# NOTE(review): `raw_model` was never defined in this notebook (NameError on a fresh\n",
    "# kernel); assuming it is the network wrapped by `model_u` — confirm.\n",
    "raw_model = model_u.model\n",
    "# Cast the inputs to the model's dtype/device instead of `raw_model.to(X_tensor)`, which\n",
    "# would recast the shared model in place (to float64 if ds.X_test is float64) and leave it\n",
    "# mismatched with the float32 tensors passed to get_log_volume in later cells.\n",
    "reference_param = next(raw_model.parameters())\n",
    "X_tensor = torch.tensor(ds.X_test, dtype=reference_param.dtype, device=reference_param.device)\n",
    "Y_tensor = torch.tensor(ds.Y_test, dtype=reference_param.dtype, device=reference_param.device)\n",
    "U_pullback = raw_model.push_y_given_x(x=X_tensor, y=Y_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12764924",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pulled-back test responses should look approximately N(0, 1) (presumably per coordinate).\n",
    "t = np.linspace(-4, 4, 1000)\n",
    "u_pullback_np = U_pullback.numpy(force=True)\n",
    "ax = sns.histplot(u_pullback_np, kde=True, stat=\"density\")\n",
    "ax.plot(t, norm.pdf(t), \"k--\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2409928",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same N(0, 1) check for the MK quantile scores of the test split; with multi-column\n",
    "# scores, common_norm=False normalises each column separately.\n",
    "t = np.linspace(-4, 4, 1000)\n",
    "mk_quantiles_test = scores_u_test[\"MK Quantile\"]\n",
    "ax = sns.histplot(mk_quantiles_test, kde=True, stat=\"density\", common_norm=False)\n",
    "ax.plot(t, norm.pdf(t), \"k--\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4be866bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of the estimated densities exp(log density) over the test points.\n",
    "sns.histplot(np.exp(scores_u_test[\"Log Density\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62f342f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "conformal_alpha = 0.1  # alpha for both predictors (the .fit calls below also pass 0.1)\n",
    "q_method = QuantileEstimatePredictor(d_y=ds.n_outputs, seed=0, alpha=conformal_alpha)\n",
    "pb_method = SplitConformalPredictor(\n",
    "    d_y=ds.n_outputs,\n",
    "    seed=0,\n",
    "    alpha=conformal_alpha,\n",
    "    lower_is_better=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67afc72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibrate on the calibration split: the quantile predictor on MK quantile scores...\n",
    "# NOTE(review): alpha=0.1 duplicates the constructor's alpha — keep them in sync.\n",
    "q_method.fit(\n",
    "    X_cal=ds.X_cal,\n",
    "    scores_cal=scores_u_cal[\"MK Quantile\"],\n",
    "    alpha=0.1,\n",
    ")\n",
    "\n",
    "# ...and the split-conformal predictor on MK rank scores.\n",
    "pb_method.fit(\n",
    "    X_cal=ds.X_cal,\n",
    "    scores_cal=scores_u_cal[\"MK Rank\"],\n",
    "    alpha=0.1,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dba98ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empirical coverage on the test split for each predictor.\n",
    "q_coverage = q_method.is_covered(ds.X_test, scores_u_test[\"MK Quantile\"]).mean()\n",
    "pb_coverage = pb_method.is_covered(ds.X_test, scores_u_test[\"MK Rank\"]).mean()\n",
    "q_coverage, pb_coverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28265664",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the fitted conformal threshold with the range of test MK-rank scores.\n",
    "pb_method.threshold, scores_u_test[\"MK Rank\"].min(), scores_u_test[\"MK Rank\"].max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45cc3478",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the volume of the prediction region at one test point by uniform sampling\n",
    "# inside the bounding box of the training responses.\n",
    "n_samples = 10_000\n",
    "\n",
    "# Dedicated names so this cell does not overwrite `rng`, `scale` or `scores_smaples`\n",
    "# from the 1-D example above (re-running those cells afterwards would silently change).\n",
    "box_rng = np.random.default_rng(args.seed)\n",
    "ymin = ds.Y_train.min(axis=0)\n",
    "ymax = ds.Y_train.max(axis=0)\n",
    "\n",
    "bbox_volume = np.prod(ymax - ymin)\n",
    "print(f\"Bounding box volume: {bbox_volume}\")\n",
    "\n",
    "i = 101  # index of the test point to inspect (also used by the next cell)\n",
    "X_samples = np.repeat(ds.X_test[i:i + 1], repeats=n_samples, axis=0)\n",
    "Y_samples = ymin + box_rng.random((n_samples, ds.n_outputs)) * (ymax - ymin)\n",
    "\n",
    "scores_box_samples = model_u.calculate_scores(X_samples, Y_samples)\n",
    "# Fraction of box samples inside the region, scaled by the box volume; the second\n",
    "# printed value is the log volume divided by the number of outputs.\n",
    "volume_i = np.mean(pb_method.is_covered(X_samples, scores_box_samples[\"MK Rank\"])) * bbox_volume\n",
    "print(f\"Volume estimate (sampling): {volume_i}, {np.log(volume_i) / ds.n_outputs}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be36aeae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import trange\n",
    "\n",
    "# Repeat the Monte-Carlo log-volume estimate for test point `i` to gauge its spread.\n",
    "log_volumes = [\n",
    "    model_u.model.get_log_volume(\n",
    "        torch.tensor(ds.X_test[i], dtype=torch.float32),\n",
    "        pb_method.threshold,\n",
    "        number_of_points_to_estimate_bounding_box=100,\n",
    "        number_of_points_to_estimate_volume=10000,\n",
    "    )\n",
    "    for _ in trange(200, desc=\"log-volume repeats\")\n",
    "]\n",
    "# Summarise once after the loop: recomputing inside it was quadratic, printed 200 lines,\n",
    "# and reported std=nan on the first iteration.\n",
    "log_volumes_tensor = torch.tensor(log_volumes)\n",
    "mean, std = log_volumes_tensor.mean().item(), log_volumes_tensor.std().item()\n",
    "print(f\"{mean=}, {std=}\")\n",
    "\n",
    "# One estimate with get_log_volume's default sampling settings.\n",
    "log_v = model_u.model.get_log_volume(torch.tensor(ds.X_test[i], dtype=torch.float32), pb_method.threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf7e2613",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import trange\n",
    "\n",
    "\n",
    "def estimate_log_volumes(X, model, threshold, desc):\n",
    "    \"\"\"Return `model.get_log_volume(row, threshold)` for every row of `X`.\n",
    "\n",
    "    Each row is passed as a float32 tensor. A progress bar shows the running\n",
    "    mean/std of the estimates collected so far.\n",
    "    \"\"\"\n",
    "    log_volumes = []\n",
    "    progress_bar = trange(X.shape[0], desc=desc)\n",
    "    for row_index in progress_bar:\n",
    "        log_volumes.append(\n",
    "            model.get_log_volume(torch.tensor(X[row_index], dtype=torch.float32), threshold)\n",
    "        )\n",
    "        running = torch.tensor(log_volumes)\n",
    "        progress_bar.set_postfix({\n",
    "            \"index\": row_index,\n",
    "            \"mean\": running.mean().item(),\n",
    "            \"std\": running.std().item(),\n",
    "        })\n",
    "    return log_volumes\n",
    "\n",
    "\n",
    "# One helper for both splits so each pass reads its own features and fills its own list\n",
    "# (and the global `x` from the 1-D example is not overwritten).\n",
    "test_log_volumes = estimate_log_volumes(ds.X_test, model_u.model, pb_method.threshold, desc=\"test\")\n",
    "cal_log_volumes = estimate_log_volumes(ds.X_cal, model_u.model, pb_method.threshold, desc=\"calibration\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdc42e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the calibration features (for large splits `.shape` or `[:5]` is more readable).\n",
    "ds.X_cal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b497636f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conditional_quantile_function",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
