{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "443daa08",
   "metadata": {},
   "source": [
    "# Stepwise Estimation of TV Distances for Burgers, Heat, and Wave Data"
   ]
  },
  {
   "cell_type": "code",
   "id": "3de76749",
   "metadata": {},
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "sys.path.append(str(Path().resolve().parent))\n",
    "import numpy as np\n",
    "import torch\n",
    "from src.utils import wandb_utils as wu\n",
    "from tqdm import trange\n",
    "from sklearn.covariance import LedoitWolf"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "33f4b1f9",
   "metadata": {},
   "source": [
    "# Both folders live next to the notebook directory, i.e. under the parent of the\n",
    "# current working directory.\n",
    "project_root = Path().resolve().parent\n",
    "\n",
    "data_path = project_root / 'data'\n",
    "results_path = project_root / 'results'\n",
    "results_path.mkdir(exist_ok=True, parents=True)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "f379411a",
   "metadata": {},
   "source": [
    "## Train Base Models"
   ]
  },
  {
   "cell_type": "code",
   "id": "57d46288",
   "metadata": {},
   "source": [
    "# Run Base Prediction Model\n",
    "# !train_base_models.sh"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Checkpoint identifiers of the best base models; each one is passed to\n",
    "# wu.load_model_from_checkpoint further down. Replace every \"xxx\" placeholder\n",
    "# with the chosen model version (it should be in the output/ folder).\n",
    "best_antidiff_model = \"xxx\"\n",
    "best_bwheat_model = \"xxx\"\n",
    "best_reaction_model = \"xxx\""
   ],
   "id": "26a04e2c577e44e2",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "16628780",
   "metadata": {},
   "source": [
    "## Generate Residual Datasets for LSCI"
   ]
  },
  {
   "cell_type": "code",
   "id": "258ade1b",
   "metadata": {},
   "source": [
    "# Exclusive upper bound on the time index: passed as max_steps to the one-step\n",
    "# predictions and used as the end of every evaluation loop below.\n",
    "TIME_FRAME = 3\n",
    "# Stride between consecutive time indices (passed as time_delta_index).\n",
    "TIME_STEP_SIZE = 1\n",
    "# Level handed to the conformal calibrators and to LSCI as `alpha`.\n",
    "ALPHA = 0.1\n",
    "# Pick the best available backend instead of hard-coding \"mps\", which fails on\n",
    "# machines without Apple silicon. Assign \"cpu\", \"mps\" or \"cuda\" here to override.\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = \"cuda\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    DEVICE = \"mps\"\n",
    "else:\n",
    "    DEVICE = \"cpu\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "0cf8618f",
   "metadata": {},
   "source": [
    "# Restore the three trained base models on the selected device.\n",
    "anti_diff_model, bwheat_model, reaction_model = [\n",
    "    wu.load_model_from_checkpoint(model_version, device=DEVICE)\n",
    "    for model_version in (best_antidiff_model, best_bwheat_model, best_reaction_model)\n",
    "]\n",
    "\n",
    "# Open the calibration and test archives of every dataset.\n",
    "antidiff_dir = data_path / 'antidiff'\n",
    "antidiff_data_calib = np.load(antidiff_dir / 'antidiff_calib.npz')\n",
    "antidiff_data_test = np.load(antidiff_dir / 'antidiff_test.npz')\n",
    "\n",
    "bwheat_dir = data_path / 'bwheat'\n",
    "bwheat_data_calib = np.load(bwheat_dir / 'bwheat_calib.npz')\n",
    "bwheat_data_test = np.load(bwheat_dir / 'bwheat_test.npz')\n",
    "\n",
    "reaction_dir = data_path / 'reaction'\n",
    "reaction_data_calib = np.load(reaction_dir / 'reaction_calib.npz')\n",
    "reaction_data_test = np.load(reaction_dir / 'reaction_test.npz')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "5d9be1d8",
   "metadata": {},
   "source": [
    "@torch.no_grad()\n",
    "def predict_over_time(model, data_full, time_delta_index=1, batch_size=16, max_steps=32):\n",
    "    \"\"\"\n",
    "    Run one-step-ahead predictions along the time axis and collect the residuals.\n",
    "\n",
    "    Starting from time index 0, each frame is fed to the model to predict the frame\n",
    "    `time_delta_index` indices later. The next input is always the true frame at that\n",
    "    index, not the model's own prediction.\n",
    "\n",
    "    Args:\n",
    "        model: Callable taking a dict with the keys 'x', 'y', 'input_grid',\n",
    "            'latent_grid' and 'output_grid' and returning the prediction for 'y'.\n",
    "        data_full: Mapping with a 'data' entry indexed as [sample, time, ...] and an\n",
    "            'x' entry holding the grid shared by all samples.\n",
    "        time_delta_index: Stride between the input and the target time index.\n",
    "        batch_size: Number of samples sent through the model per call.\n",
    "        max_steps: Exclusive upper bound on the time index, capped at data.shape[1].\n",
    "    Returns:\n",
    "        residuals: CPU tensor shaped like data_full['data'] with truth minus prediction.\n",
    "        predictions: CPU tensor shaped like data_full['data'] with the model outputs.\n",
    "        Time indices that are never predicted (index 0 included) stay zero in both.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the model is called as-is; confirm the checkpoint loader already\n",
    "    # puts it in eval mode.\n",
    "    data = torch.tensor(data_full['data'])\n",
    "    n_samples = data.shape[0]\n",
    "\n",
    "    # Read the grid once: when data_full is an .npz archive every lookup re-reads the\n",
    "    # array from disk. The three grid inputs carry identical values, so the stacked\n",
    "    # grid is built a single time and copied for the other two.\n",
    "    grid = torch.tensor(data_full['x'])\n",
    "    input_grid = torch.stack([grid] * n_samples, dim=0).unsqueeze(-1)\n",
    "    latent_grid = input_grid.clone()\n",
    "    output_grid = input_grid.clone()\n",
    "\n",
    "    residuals = torch.zeros_like(data, device='cpu')\n",
    "    predictions = torch.zeros_like(data, device='cpu')\n",
    "\n",
    "    max_steps = min(max_steps, data.shape[1])\n",
    "\n",
    "    for start in range(0, n_samples, batch_size):\n",
    "        batch = slice(start, start + batch_size)\n",
    "\n",
    "        # The first input of every batch is the frame at time index 0.\n",
    "        u0 = data[batch, 0].unsqueeze(-1)\n",
    "        n_batch = u0.shape[0]  # the last batch may hold fewer than batch_size samples\n",
    "\n",
    "        for j in trange(time_delta_index, max_steps, time_delta_index):\n",
    "            # Target/label is the frame at the current time index.\n",
    "            y = data[batch, j].unsqueeze(-1)\n",
    "            input_dict = {\n",
    "                \"x\": u0.to(DEVICE),\n",
    "                \"y\": y.to(DEVICE),\n",
    "                \"input_grid\": input_grid[:n_batch].to(DEVICE),\n",
    "                \"latent_grid\": latent_grid[:n_batch].to(DEVICE),\n",
    "                \"output_grid\": output_grid[:n_batch].to(DEVICE),\n",
    "            }\n",
    "\n",
    "            y_pred = model(input_dict)\n",
    "\n",
    "            predictions[batch, j] = y_pred.squeeze(-1).cpu()\n",
    "            residuals[batch, j] = (y - y_pred.cpu()).squeeze(-1)\n",
    "            # Continue from the ground truth rather than from the prediction.\n",
    "            u0 = y\n",
    "    return residuals, predictions\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "2dae00fc",
   "metadata": {},
   "source": [
    "# Every call shares the same stepping and batching configuration.\n",
    "rollout_kwargs = dict(time_delta_index=TIME_STEP_SIZE, batch_size=500, max_steps=TIME_FRAME)\n",
    "\n",
    "antidiff_residuals_calib, antidiff_predictions_calib = predict_over_time(\n",
    "    anti_diff_model, antidiff_data_calib, **rollout_kwargs)\n",
    "antidiff_residuals_test, antidiff_predictions_test = predict_over_time(\n",
    "    anti_diff_model, antidiff_data_test, **rollout_kwargs)\n",
    "\n",
    "bwheat_residuals_calib, bwheat_predictions_calib = predict_over_time(\n",
    "    bwheat_model, bwheat_data_calib, **rollout_kwargs)\n",
    "bwheat_residuals_test, bwheat_predictions_test = predict_over_time(\n",
    "    bwheat_model, bwheat_data_test, **rollout_kwargs)\n",
    "\n",
    "reaction_residuals_calib, reaction_predictions_calib = predict_over_time(\n",
    "    reaction_model, reaction_data_calib, **rollout_kwargs)\n",
    "reaction_residuals_test, reaction_predictions_test = predict_over_time(\n",
    "    reaction_model, reaction_data_test, **rollout_kwargs)\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "f9a9300b",
   "metadata": {},
   "source": [
    "from src.conformal_prediction import covariate_shift as cs\n",
    "\n",
    "# Shared calibrator for the weighted experiments below; it is re-calibrated with\n",
    "# calibrate_with_gaussian_shift for every dataset and time step.\n",
    "# NOTE(review): alpha is presumably the target miscoverage level - confirm in covariate_shift.\n",
    "cp = cs.ConformalFunctionalBand(alpha=ALPHA, scale_mode=\"none\", side=\"two\", eps=1e-8)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "92cb39e4",
   "metadata": {},
   "source": [
    "def get_mean_and_cov(data):\n",
    "    \"\"\"Return the mean of `data` over axis 0 and its Ledoit-Wolf covariance estimate.\"\"\"\n",
    "    shrinkage_estimator = LedoitWolf()\n",
    "    shrinkage_estimator.fit(data)\n",
    "    return np.mean(data, axis=0), shrinkage_estimator.covariance_"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "95172188",
   "metadata": {},
   "source": [
    "coverages_antidiff = []\n",
    "mean_bandwidths_antidiff = []\n",
    "num_infinite_antidiff = []\n",
    "\n",
    "# Read each array from its .npz archive once instead of on every access.\n",
    "antidiff_calib_arr = antidiff_data_calib['data']\n",
    "antidiff_test_arr = antidiff_data_test['data']\n",
    "\n",
    "# The calibration inputs are always the frames at time index 0, so their mean and\n",
    "# covariance do not depend on the loop variable and are estimated a single time.\n",
    "mean_cal, cov_cal = get_mean_and_cov(antidiff_calib_arr[:, 0])\n",
    "\n",
    "for i in range(1, TIME_FRAME):\n",
    "    # Test inputs for target index i are the frames one index earlier.\n",
    "    mean_test, cov_test = get_mean_and_cov(antidiff_test_arr[:, i - 1])\n",
    "\n",
    "    cp.calibrate_with_gaussian_shift(Y_cal=antidiff_calib_arr[:, 1], Yhat_cal=antidiff_predictions_calib[:, 1],\n",
    "                                     X_cal=antidiff_calib_arr[:, 0], mu_p=mean_cal, Sigma_p=cov_cal,\n",
    "                                     mu_q=mean_test, Sigma_q=cov_test)\n",
    "\n",
    "    coverage, covered, bandwidth = cp.get_coverage(antidiff_test_arr[:, i], antidiff_predictions_test[:, i],\n",
    "                                                   antidiff_test_arr[:, i - 1], mu_p=mean_cal, Sigma_p=cov_cal,\n",
    "                                                   mu_q=mean_test, Sigma_q=cov_test)\n",
    "\n",
    "    # Average each row of the band width over axis 1; rows with a non-finite\n",
    "    # average are counted separately and left out of the reported mean.\n",
    "    row_means = bandwidth.mean(axis=1)\n",
    "    finite_rows = np.isfinite(row_means)\n",
    "    finite_bandwidth = bandwidth[finite_rows, :].mean(axis=1)\n",
    "    mean_bandwidth = finite_bandwidth.mean() if finite_bandwidth.size > 0 else float('nan')\n",
    "    num_infinite = np.sum(~finite_rows)\n",
    "\n",
    "    coverages_antidiff.append(coverage)\n",
    "    mean_bandwidths_antidiff.append(mean_bandwidth)\n",
    "    num_infinite_antidiff.append(num_infinite)\n",
    "    print(\n",
    "        f\"Time Step {i}, Coverage: {coverage}, Bandwidth Mean (finite): {mean_bandwidth}, Num Infinite: {num_infinite}\")\n",
    "\n",
    "(results_path / 'antidiff').mkdir(parents=True, exist_ok=True)\n",
    "np.savez(results_path / 'antidiff' / 'antidiff_weighted_results.npz',\n",
    "         coverages=np.array(coverages_antidiff),\n",
    "         mean_bandwidths=np.array(mean_bandwidths_antidiff),\n",
    "         num_infinite=np.array(num_infinite_antidiff))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "966b88c2",
   "metadata": {},
   "source": [
    "coverages_bwheat = []\n",
    "mean_bandwidths_bwheat = []\n",
    "num_infinite_bwheat = []\n",
    "\n",
    "# Read each array from its .npz archive once instead of on every access.\n",
    "bwheat_calib_arr = bwheat_data_calib['data']\n",
    "bwheat_test_arr = bwheat_data_test['data']\n",
    "\n",
    "# The calibration inputs are always the frames at time index 0, so their mean and\n",
    "# covariance do not depend on the loop variable and are estimated a single time.\n",
    "mean_cal, cov_cal = get_mean_and_cov(bwheat_calib_arr[:, 0])\n",
    "\n",
    "for i in range(1, TIME_FRAME):\n",
    "    # Test inputs for target index i are the frames one index earlier.\n",
    "    mean_test, cov_test = get_mean_and_cov(bwheat_test_arr[:, i - 1])\n",
    "\n",
    "    cp.calibrate_with_gaussian_shift(Y_cal=bwheat_calib_arr[:, 1], Yhat_cal=bwheat_predictions_calib[:, 1],\n",
    "                                     X_cal=bwheat_calib_arr[:, 0], mu_p=mean_cal, Sigma_p=cov_cal,\n",
    "                                     mu_q=mean_test, Sigma_q=cov_test)\n",
    "\n",
    "    coverage, covered, bandwidth = cp.get_coverage(bwheat_test_arr[:, i], bwheat_predictions_test[:, i],\n",
    "                                                   bwheat_test_arr[:, i - 1], mu_p=mean_cal, Sigma_p=cov_cal,\n",
    "                                                   mu_q=mean_test, Sigma_q=cov_test)\n",
    "\n",
    "    # Average each row of the band width over axis 1; rows with a non-finite\n",
    "    # average are counted separately and left out of the reported mean.\n",
    "    row_means = bandwidth.mean(axis=1)\n",
    "    finite_rows = np.isfinite(row_means)\n",
    "    finite_bandwidth = bandwidth[finite_rows, :].mean(axis=1)\n",
    "    mean_bandwidth = finite_bandwidth.mean() if finite_bandwidth.size > 0 else float('nan')\n",
    "    num_infinite = np.sum(~finite_rows)\n",
    "\n",
    "    coverages_bwheat.append(coverage)\n",
    "    mean_bandwidths_bwheat.append(mean_bandwidth)\n",
    "    num_infinite_bwheat.append(num_infinite)\n",
    "    print(\n",
    "        f\"Time Step {i}, Coverage: {coverage}, Bandwidth Mean (finite): {mean_bandwidth}, Num Infinite: {num_infinite}\")\n",
    "\n",
    "(results_path / 'bwheat').mkdir(parents=True, exist_ok=True)\n",
    "np.savez(results_path / 'bwheat' / 'bwheat_weighted_results.npz',\n",
    "         coverages=np.array(coverages_bwheat),\n",
    "         mean_bandwidths=np.array(mean_bandwidths_bwheat),\n",
    "         num_infinite=np.array(num_infinite_bwheat))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "4c4157b0",
   "metadata": {},
   "source": [
    "coverages_reaction = []\n",
    "mean_bandwidths_reaction = []\n",
    "num_infinite_reaction = []\n",
    "\n",
    "# Read each array from its .npz archive once instead of on every access.\n",
    "reaction_calib_arr = reaction_data_calib['data']\n",
    "reaction_test_arr = reaction_data_test['data']\n",
    "\n",
    "# The calibration inputs are always the frames at time index 0, so their mean and\n",
    "# covariance do not depend on the loop variable and are estimated a single time.\n",
    "mean_cal, cov_cal = get_mean_and_cov(reaction_calib_arr[:, 0])\n",
    "\n",
    "for i in range(1, TIME_FRAME):\n",
    "    # Test-side statistics and coverage must come from the held-out test split, as\n",
    "    # in the antidiff and bwheat cells; using the calibration split here would\n",
    "    # evaluate the band on the very data it was calibrated on.\n",
    "    mean_test, cov_test = get_mean_and_cov(reaction_test_arr[:, i - 1])\n",
    "\n",
    "    cp.calibrate_with_gaussian_shift(Y_cal=reaction_calib_arr[:, 1], Yhat_cal=reaction_predictions_calib[:, 1],\n",
    "                                     X_cal=reaction_calib_arr[:, 0], mu_p=mean_cal, Sigma_p=cov_cal,\n",
    "                                     mu_q=mean_test, Sigma_q=cov_test)\n",
    "\n",
    "    coverage, covered, bandwidth = cp.get_coverage(reaction_test_arr[:, i], reaction_predictions_test[:, i],\n",
    "                                                   reaction_test_arr[:, i - 1], mu_p=mean_cal,\n",
    "                                                   Sigma_p=cov_cal, mu_q=mean_test, Sigma_q=cov_test)\n",
    "\n",
    "    # Average each row of the band width over axis 1; rows with a non-finite\n",
    "    # average are counted separately and left out of the reported mean.\n",
    "    row_means = bandwidth.mean(axis=1)\n",
    "    finite_rows = np.isfinite(row_means)\n",
    "    finite_bandwidth = bandwidth[finite_rows, :].mean(axis=1)\n",
    "    mean_bandwidth = finite_bandwidth.mean() if finite_bandwidth.size > 0 else float('nan')\n",
    "    num_infinite = np.sum(~finite_rows)\n",
    "\n",
    "    coverages_reaction.append(coverage)\n",
    "    mean_bandwidths_reaction.append(mean_bandwidth)\n",
    "    num_infinite_reaction.append(num_infinite)\n",
    "    print(\n",
    "        f\"Time Step {i}, Coverage: {coverage}, Bandwidth Mean (finite): {mean_bandwidth}, Num Infinite: {num_infinite}\")\n",
    "\n",
    "(results_path / 'reaction').mkdir(parents=True, exist_ok=True)\n",
    "np.savez(results_path / 'reaction' / 'reaction_weighted_results.npz',\n",
    "         coverages=np.array(coverages_reaction),\n",
    "         mean_bandwidths=np.array(mean_bandwidths_reaction),\n",
    "         num_infinite=np.array(num_infinite_reaction))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "b27149f2",
   "metadata": {},
   "source": [
    "# Unweighted CP "
   ]
  },
  {
   "cell_type": "code",
   "id": "ed30dc4f",
   "metadata": {},
   "source": [
    "from src.conformal_prediction import covariate_shift_nw as csnw\n",
    "\n",
    "# Shared calibrator for the unweighted baseline below; it is calibrated from the\n",
    "# calibration targets and predictions only (no input statistics are passed).\n",
    "# NOTE(review): alpha is presumably the target miscoverage level - confirm in covariate_shift_nw.\n",
    "cpnw = csnw.ConformalFunctionalBandNaive(alpha=ALPHA, scale_mode=\"none\", side=\"two\", eps=1e-8)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "fd01a7d1",
   "metadata": {},
   "source": [
    "coverages_antidiff_nw = []\n",
    "mean_bandwidths_antidiff_nw = []\n",
    "num_infinite_antidiff_nw = []\n",
    "\n",
    "for i in range(1, TIME_FRAME):\n",
    "    cpnw.calibrate(Y_cal=antidiff_data_calib['data'][:, 1], Yhat_cal=antidiff_predictions_calib[:, 1])\n",
    "\n",
    "    coverage, covered, bandwidth = cpnw.get_coverage(\n",
    "        antidiff_data_test['data'][:, i],\n",
    "        antidiff_predictions_test[:, i],\n",
    "    )\n",
    "\n",
    "    # Average each row of the band width over axis 1; rows with a non-finite\n",
    "    # average are tallied separately and kept out of the reported mean.\n",
    "    row_means = bandwidth.mean(axis=1)\n",
    "    finite_rows = np.isfinite(row_means)\n",
    "    finite_bandwidth = bandwidth[finite_rows, :].mean(axis=1)\n",
    "    num_infinite = np.sum(~finite_rows)\n",
    "    if finite_bandwidth.size > 0:\n",
    "        mean_bandwidth = finite_bandwidth.mean()\n",
    "    else:\n",
    "        mean_bandwidth = float('nan')\n",
    "\n",
    "    coverages_antidiff_nw.append(coverage)\n",
    "    mean_bandwidths_antidiff_nw.append(mean_bandwidth)\n",
    "    num_infinite_antidiff_nw.append(num_infinite)\n",
    "    print(\n",
    "        f\"Time Step {i}, Coverage: {coverage}, Bandwidth Mean (finite): {mean_bandwidth}, Num Infinite: {num_infinite}\")\n",
    "\n",
    "antidiff_unweighted_summary = {\n",
    "    'coverages': np.array(coverages_antidiff_nw),\n",
    "    'mean_bandwidths': np.array(mean_bandwidths_antidiff_nw),\n",
    "    'num_infinite': np.array(num_infinite_antidiff_nw),\n",
    "}\n",
    "np.savez(results_path / 'antidiff' / 'antidiff_unweighted_results.npz', **antidiff_unweighted_summary)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "174b0cd5",
   "metadata": {},
   "source": [
    "coverages_bwheat_nw = []\n",
    "mean_bandwidths_bwheat_nw = []\n",
    "num_infinite_bwheat_nw = []\n",
    "\n",
    "for i in range(1, TIME_FRAME):\n",
    "    cpnw.calibrate(Y_cal=bwheat_data_calib['data'][:, 1], Yhat_cal=bwheat_predictions_calib[:, 1])\n",
    "\n",
    "    coverage, covered, bandwidth = cpnw.get_coverage(\n",
    "        bwheat_data_test['data'][:, i],\n",
    "        bwheat_predictions_test[:, i],\n",
    "    )\n",
    "\n",
    "    # Average each row of the band width over axis 1; rows with a non-finite\n",
    "    # average are tallied separately and kept out of the reported mean.\n",
    "    row_means = bandwidth.mean(axis=1)\n",
    "    finite_rows = np.isfinite(row_means)\n",
    "    finite_bandwidth = bandwidth[finite_rows, :].mean(axis=1)\n",
    "    num_infinite = np.sum(~finite_rows)\n",
    "    if finite_bandwidth.size > 0:\n",
    "        mean_bandwidth = finite_bandwidth.mean()\n",
    "    else:\n",
    "        mean_bandwidth = float('nan')\n",
    "\n",
    "    coverages_bwheat_nw.append(coverage)\n",
    "    mean_bandwidths_bwheat_nw.append(mean_bandwidth)\n",
    "    num_infinite_bwheat_nw.append(num_infinite)\n",
    "    print(\n",
    "        f\"Time Step {i}, Coverage: {coverage}, Bandwidth Mean (finite): {mean_bandwidth}, Num Infinite: {num_infinite}\")\n",
    "\n",
    "bwheat_unweighted_summary = {\n",
    "    'coverages': np.array(coverages_bwheat_nw),\n",
    "    'mean_bandwidths': np.array(mean_bandwidths_bwheat_nw),\n",
    "    'num_infinite': np.array(num_infinite_bwheat_nw),\n",
    "}\n",
    "np.savez(results_path / 'bwheat' / 'bwheat_unweighted_results.npz', **bwheat_unweighted_summary)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "77ae259c",
   "metadata": {},
   "source": [
    "coverages_reaction_nw = []\n",
    "mean_bandwidths_reaction_nw = []\n",
    "num_infinite_reaction_nw = []\n",
    "\n",
    "for i in range(1, TIME_FRAME):\n",
    "    cpnw.calibrate(Y_cal=reaction_data_calib['data'][:, 1], Yhat_cal=reaction_predictions_calib[:, 1])\n",
    "\n",
    "    # Coverage must be measured on the held-out test split, as in the antidiff and\n",
    "    # bwheat cells; the calibration split would score the band on its own data.\n",
    "    coverage, covered, bandwidth = cpnw.get_coverage(reaction_data_test['data'][:, i],\n",
    "                                                     reaction_predictions_test[:, i])\n",
    "\n",
    "    # Average each row of the band width over axis 1; rows with a non-finite\n",
    "    # average are counted separately and left out of the reported mean.\n",
    "    row_means = bandwidth.mean(axis=1)\n",
    "    finite_rows = np.isfinite(row_means)\n",
    "    finite_bandwidth = bandwidth[finite_rows, :].mean(axis=1)\n",
    "    mean_bandwidth = finite_bandwidth.mean() if finite_bandwidth.size > 0 else float('nan')\n",
    "    num_infinite = np.sum(~finite_rows)\n",
    "\n",
    "    coverages_reaction_nw.append(coverage)\n",
    "    mean_bandwidths_reaction_nw.append(mean_bandwidth)\n",
    "    num_infinite_reaction_nw.append(num_infinite)\n",
    "    print(\n",
    "        f\"Time Step {i}, Coverage: {coverage}, Bandwidth Mean (finite): {mean_bandwidth}, Num Infinite: {num_infinite}\")\n",
    "\n",
    "np.savez(results_path / 'reaction' / 'reaction_unweighted_results.npz',\n",
    "         coverages=np.array(coverages_reaction_nw),\n",
    "         mean_bandwidths=np.array(mean_bandwidths_reaction_nw),\n",
    "         num_infinite=np.array(num_infinite_reaction_nw))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "694db2ae",
   "metadata": {},
   "source": [
    "# LSCI Coverage"
   ]
  },
  {
   "cell_type": "code",
   "id": "c9b8b5ff",
   "metadata": {},
   "source": [
    "# Prepare Data for LSCI\n",
    "\n",
    "def prepare_lsci_data(data_calib, residuals_calib, data_test, residuals_test, save_path, base_time=1, target_time=2,\n",
    "                      time_step_size=None):\n",
    "    \"\"\"\n",
    "    Assemble the calibration and test arrays for one LSCI run and save them.\n",
    "\n",
    "    Calibration pairs use the input at time index 0 and the residual at `base_time`.\n",
    "    Test pairs use the input one stride before `target_time` and the residual at\n",
    "    `target_time`.\n",
    "\n",
    "    Args:\n",
    "        data_calib: Calibration data indexed as [sample, time, grid].\n",
    "        residuals_calib: Calibration residuals with the same indexing.\n",
    "        data_test: Test data indexed as [sample, time, grid].\n",
    "        residuals_test: Test residuals with the same indexing.\n",
    "        save_path: Path or file object handed to np.savez.\n",
    "        base_time: Time index of the calibration residuals.\n",
    "        target_time: Time index of the test residuals.\n",
    "        time_step_size: Stride between a test input and its target; falls back to the\n",
    "            notebook-level TIME_STEP_SIZE when omitted.\n",
    "    Returns:\n",
    "        None. The float64 arrays 'xval', 'rval', 'xtest' and 'rtest' are written to save_path.\n",
    "    Raises:\n",
    "        ValueError: If target_time is smaller than the stride, because the negative\n",
    "            input index would silently wrap around to the end of the time axis.\n",
    "    \"\"\"\n",
    "    if time_step_size is None:\n",
    "        time_step_size = TIME_STEP_SIZE\n",
    "\n",
    "    input_time = target_time - time_step_size\n",
    "    if input_time < 0:\n",
    "        raise ValueError(\n",
    "            f\"target_time ({target_time}) must be at least time_step_size ({time_step_size})\")\n",
    "\n",
    "    # Slice whole columns at once instead of copying sample by sample.\n",
    "    xval = np.asarray(data_calib[:, 0], dtype=np.float64)\n",
    "    rval = np.asarray(residuals_calib[:, base_time], dtype=np.float64)\n",
    "    xtest = np.asarray(data_test[:, input_time], dtype=np.float64)\n",
    "    rtest = np.asarray(residuals_test[:, target_time], dtype=np.float64)\n",
    "\n",
    "    np.savez(save_path, xval=xval, rval=rval, xtest=xtest, rtest=rtest)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "42081341",
   "metadata": {},
   "source": [
    "\n",
    "# (dataset name, calibration archive, calibration residuals, test archive, test residuals)\n",
    "lsci_sources = (\n",
    "    ('antidiff', antidiff_data_calib, antidiff_residuals_calib, antidiff_data_test, antidiff_residuals_test),\n",
    "    ('bwheat', bwheat_data_calib, bwheat_residuals_calib, bwheat_data_test, bwheat_residuals_test),\n",
    "    ('reaction', reaction_data_calib, reaction_residuals_calib, reaction_data_test, reaction_residuals_test),\n",
    ")\n",
    "\n",
    "for i in range(TIME_STEP_SIZE, TIME_FRAME):\n",
    "    for name, calib_npz, calib_res, test_npz, test_res in lsci_sources:\n",
    "        # One archive per dataset and target time index.\n",
    "        lsci_file = results_path / name / 'processed_data' / f'{name}_lsci_{i}.npz'\n",
    "        lsci_file.parent.mkdir(parents=True, exist_ok=True)\n",
    "        prepare_lsci_data(calib_npz['data'], calib_res.numpy(), test_npz['data'], test_res.numpy(),\n",
    "                          lsci_file, base_time=1, target_time=i)\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [],
   "id": "98e8e6a885beb2bf",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "787513ec",
   "metadata": {},
   "source": [
    "from pathlib import Path\n",
    "from src.conformal_prediction.lsci import run_lsci  # adjust import if module path differs\n",
    "\n",
    "# Output folders for the LSCI artefacts of each dataset.\n",
    "lsci_antidiff_outdir = results_path / 'antidiff' / 'processed_data' / 'lsci_experiments'\n",
    "lsci_bwheat_outdir = results_path / 'bwheat' / 'processed_data' / 'lsci_experiments'\n",
    "lsci_reaction_outdir = results_path / 'reaction' / 'processed_data' / 'lsci_experiments'\n",
    "for outdir in (lsci_antidiff_outdir, lsci_bwheat_outdir, lsci_reaction_outdir):\n",
    "    outdir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "base_path = Path().resolve().parent\n",
    "\n",
    "# Datasets are processed in this order for every time index.\n",
    "lsci_jobs = (\n",
    "    ('antidiff', lsci_antidiff_outdir),\n",
    "    ('bwheat', lsci_bwheat_outdir),\n",
    "    ('reaction', lsci_reaction_outdir),\n",
    ")\n",
    "\n",
    "for i in range(TIME_STEP_SIZE, TIME_FRAME):\n",
    "    print(f\"Running LSCI for {i}\")\n",
    "    suffix = f\"_{i}\"\n",
    "\n",
    "    for dataset_name, dataset_outdir in lsci_jobs:\n",
    "        run_lsci(\n",
    "            npz_all=results_path / dataset_name / 'processed_data' / f'{dataset_name}_lsci_{i}.npz',\n",
    "            lam=5,\n",
    "            n_proj=20,\n",
    "            alpha=ALPHA,\n",
    "            out_dir=dataset_outdir,\n",
    "            suffix=suffix,\n",
    "            save_weights=True,\n",
    "            verbose=True,\n",
    "        )\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "c2b03c98",
   "metadata": {},
   "source": "## Get LSCI actual coverage"
  },
  {
   "cell_type": "code",
   "id": "32f3a1e4",
   "metadata": {},
   "source": [
    "import lsci.conformal as lsci\n",
    "from jax import numpy as jnp\n",
    "\n",
    "\n",
    "def get_lsci_coverage(lsci_outdir, filename, time_steps=TIME_FRAME):\n",
    "    \"\"\"\n",
    "    Evaluate the LSCI bands for every time step and every test sample.\n",
    "\n",
    "    For each time step i in 1 .. time_steps - 1 the function loads\n",
    "    '<filename>_<i>.npz' ('rval', 'rtest') from lsci_outdir and\n",
    "    'lsci_experiments/local_weights_<i>.npz' ('local_weights_val'). For each test\n",
    "    sample j it draws an ensemble with lsci.local_sampler, takes the pointwise\n",
    "    min/max envelope, and records whether at least 99% of that sample's residual\n",
    "    lies inside the envelope.\n",
    "\n",
    "    Args:\n",
    "        lsci_outdir: Directory (Path) holding the prepared LSCI archives.\n",
    "        filename: Archive name prefix, e.g. 'antidiff_lsci'.\n",
    "        time_steps: Exclusive upper bound on the evaluated time index.\n",
    "    Returns:\n",
    "        lsci_coverages: One list per time step with a boolean per test sample.\n",
    "        conf_ens_list: One list per time step with an array stacking [lo, hi]\n",
    "            for each test sample.\n",
    "    \"\"\"\n",
    "    lsci_coverages = []\n",
    "    conf_ens_list = []\n",
    "\n",
    "    for i in trange(1, time_steps):\n",
    "        step_coverages = []\n",
    "        step_bands = []\n",
    "\n",
    "        lsci_input = np.load(str(lsci_outdir / f'{filename}_{i}.npz'))\n",
    "        local_weights = np.load(lsci_outdir / 'lsci_experiments' / f'local_weights_{i}.npz')['local_weights_val']\n",
    "\n",
    "        rval = jnp.array(lsci_input['rval'])\n",
    "        rtest = jnp.array(lsci_input['rtest'])\n",
    "        local_weights = jnp.array(local_weights)\n",
    "\n",
    "        for j in range(rtest.shape[0]):\n",
    "            # NOTE(review): 500 and 20 are passed positionally; presumably the number of\n",
    "            # samples and of projections - confirm against lsci.local_sampler.\n",
    "            conf_ens = lsci.local_sampler(rval, local_weights[j], ALPHA, 500, 20)\n",
    "            lo, hi = jnp.min(conf_ens, axis=0), jnp.max(conf_ens, axis=0)\n",
    "\n",
    "            # The envelope was built from the local weights of test sample j, so it\n",
    "            # is checked against that sample's residual only, not the whole test set.\n",
    "            inside = (rtest[j] >= lo) & (rtest[j] <= hi)\n",
    "            coverage = float(jnp.mean(inside))\n",
    "\n",
    "            step_coverages.append(coverage >= 0.99)\n",
    "            step_bands.append(np.array([lo, hi]))\n",
    "\n",
    "        lsci_coverages.append(step_coverages)\n",
    "        conf_ens_list.append(step_bands)\n",
    "\n",
    "    return lsci_coverages, conf_ens_list"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "5db2b483",
   "metadata": {},
   "source": [
    "# Evaluate the antidiff LSCI bands and store the per-sample results.\n",
    "antidiff_processed_dir = results_path / 'antidiff' / 'processed_data'\n",
    "coverages_lsci_antidiff, conf_ens_lsci_antidiff = get_lsci_coverage(\n",
    "    antidiff_processed_dir, 'antidiff_lsci', TIME_FRAME)\n",
    "np.savez(\n",
    "    results_path / 'antidiff' / 'lsci_coverages.npz',\n",
    "    coverages=coverages_lsci_antidiff,\n",
    "    conf_ens=conf_ens_lsci_antidiff,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "d6771848",
   "metadata": {},
   "source": [
    "# Evaluate the bwheat LSCI bands and store the per-sample results.\n",
    "bwheat_processed_dir = results_path / 'bwheat' / 'processed_data'\n",
    "coverages_lsci_bwheat, conf_ens_lsci_bwheat = get_lsci_coverage(\n",
    "    bwheat_processed_dir, 'bwheat_lsci', TIME_FRAME)\n",
    "np.savez(\n",
    "    results_path / 'bwheat' / 'lsci_coverages.npz',\n",
    "    coverages=coverages_lsci_bwheat,\n",
    "    conf_ens=conf_ens_lsci_bwheat,\n",
    ")\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "coverages_lsci_reaction, conf_ens_lsci_reaction = get_lsci_coverage(\n",
    "    (results_path / 'reaction' / 'processed_data'), 'reaction_lsci', TIME_FRAME)\n",
    "# Persist the reaction results in the same layout as the antidiff and bwheat\n",
    "# results; without this the computed coverages are lost when the kernel stops.\n",
    "np.savez(results_path / 'reaction' / 'lsci_coverages.npz', coverages=coverages_lsci_reaction,\n",
    "         conf_ens=conf_ens_lsci_reaction)\n"
   ],
   "id": "ec0526705b448144",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neural_operators (3.13.5)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
