{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7e004218",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bacba25e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(CVXPY) Apr 18 11:12:55 PM: Encountered unexpected exception importing solver SCS:\n",
      "ImportError(\"dlopen(/Users/cherian/opt/miniconda3/envs/conformal-gan/lib/python3.10/site-packages/_scs_direct.cpython-310-darwin.so, 0x0002): Library not loaded: '@rpath/libmkl_rt.2.dylib'\\n  Referenced from: '/Users/cherian/opt/miniconda3/envs/conformal-gan/lib/python3.10/site-packages/_scs_direct.cpython-310-darwin.so'\\n  Reason: tried: '/Users/cherian/opt/miniconda3/envs/conformal-gan/lib/python3.10/site-packages/../../libmkl_rt.2.dylib' (no such file), '/Users/cherian/opt/miniconda3/envs/conformal-gan/lib/python3.10/site-packages/../../libmkl_rt.2.dylib' (no such file), '/Users/cherian/opt/miniconda3/envs/conformal-gan/bin/../lib/libmkl_rt.2.dylib' (no such file), '/Users/cherian/opt/miniconda3/envs/conformal-gan/bin/../lib/libmkl_rt.2.dylib' (no such file), '/usr/local/lib/libmkl_rt.2.dylib' (no such file), '/usr/lib/libmkl_rt.2.dylib' (no such file)\")\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from scipy.stats import norm\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from tqdm import tqdm\n",
    "\n",
    "from conditionalconformal.synthetic_data import generate_cqr_data, indicator_matrix\n",
    "from conditionalconformal import CondConf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "eb3f174d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate data\n",
    "x_train_final, y_train_final, x_calib, y_calib, x_test, y_test = generate_cqr_data(seed=0, n_calib=250)\n",
    "\n",
    "# fit a fourth order polynomial\n",
    "poly = PolynomialFeatures(4)\n",
    "reg = LinearRegression().fit(poly.fit_transform(x_train_final), y_train_final)\n",
    "\n",
    "# nominal level is 0.9\n",
    "alpha = 0.1\n",
    "\n",
    "# score function is residual\n",
    "score_fn = lambda x, y : y - reg.predict(poly.fit_transform(x))\n",
    "score_inv_fn_ub = lambda s, x : [-np.inf, reg.predict(poly.fit_transform(x)) + s]\n",
    "score_inv_fn_lb = lambda s, x : [reg.predict(poly.fit_transform(x)) + s, np.inf]\n",
    "\n",
    "# coverage on indicators of all sub-intervals with endpoints in [0,0.5,1,..,5]\n",
    "eps = 0.5\n",
    "disc = np.arange(0, 5 + eps, eps)\n",
    "\n",
    "def phi_fn_groups(x):\n",
    "    return indicator_matrix(x, disc)\n",
    "\n",
    "# coverage on Gaussians with mu=loc and sd=scale \n",
    "# scale = 1 for x != [1.5, 3.5]\n",
    "eval_locs = [1.5, 3.5]\n",
    "eval_scale = 0.2\n",
    "\n",
    "other_locs = [0.5, 2.5, 4.5]\n",
    "other_scale = 1\n",
    "\n",
    "def phi_fn_shifts(x):\n",
    "    shifts = [norm.pdf(x, loc=loc, scale=eval_scale).reshape(-1,1)\n",
    "                   for loc in eval_locs]\n",
    "    shifts.extend([norm.pdf(x, loc=loc, scale=other_scale).reshape(-1,1)\n",
    "                   for loc in other_locs])\n",
    "    shifts.append(np.ones((x.shape[0], 1)))\n",
    "    return np.concatenate(shifts, axis=1)\n",
    "\n",
    "# intercept only phi_fn\n",
    "def phi_fn_intercept(x):\n",
    "    return np.ones((x.shape[0], 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "deadbfc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import linprog\n",
    "\n",
    "experiment = 'groups' # valid choices: ['groups', 'shifts', 'agnostic']\n",
    "\n",
    "if experiment == 'groups':\n",
    "    phi_fn = phi_fn_groups\n",
    "    infinite_params = {}\n",
    "elif experiment == 'shifts':\n",
    "    phi_fn = phi_fn_shifts\n",
    "    infinite_params = {}\n",
    "elif experiment == 'agnostic':\n",
    "    phi_fn = phi_fn_intercept\n",
    "    infinite_params = {'kernel': 'rbf', 'gamma': 12.5, 'lambda': 0.005}\n",
    "else:\n",
    "    raise ValueError(f\"Invalid value for experiment: {experiment}.\")\n",
    "\n",
    "\n",
    "def get_regression_betas(Phi, S_calib, S_test = None, quantile=0.95):\n",
    "    if S_test == None:\n",
    "        S = S_calib\n",
    "    else:\n",
    "        S = np.concatenate((S_calib, [[S_test]]))\n",
    "    zeros = np.zeros((Phi.shape[1],))\n",
    "    bounds = [(quantile - 1, quantile)] * len(S)\n",
    "    res = linprog(-1 * S, A_eq=Phi.T, b_eq=zeros, bounds=bounds, method='highs')\n",
    "    primal_vars = -1 * res.eqlin.marginals.reshape(-1,1)\n",
    "    dual_vars = res.x.reshape(-1,1)\n",
    "    return primal_vars, dual_vars[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "39c765fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:05<00:00,  4.59it/s]\n"
     ]
    }
   ],
   "source": [
    "# plot predictions\n",
    "\n",
    "def generate_plot(i, x_aug, Phi, S_calib, x_test_val, y_test_val, S_val):\n",
    "    beta_upper, dual_upper = get_regression_betas(Phi, S_calib, S_val, quantile=0.95)\n",
    "    beta_lower, dual_lower = get_regression_betas(Phi, S_calib, S_val, quantile=0.05)\n",
    "            \n",
    "    lbs = Phi @ beta_lower.reshape(-1,1)\n",
    "    ubs = Phi @ beta_upper.reshape(-1,1)\n",
    "    \n",
    "    cp = sns.color_palette()\n",
    "    sns.set(font=\"DejaVu Sans\")\n",
    "    sns.set_style(\"whitegrid\", {'axes.grid' : False})\n",
    "    fig = plt.figure()\n",
    "    fig.set_size_inches(8, 5)\n",
    "\n",
    "    sort_order = np.argsort(x_aug[:,0])\n",
    "    x_aug_s = x_aug[sort_order]\n",
    "    y_aug_hat = reg.predict(poly.fit_transform(x_aug[sort_order]))\n",
    "    lb = y_aug_hat.flatten() + lbs[sort_order].flatten()\n",
    "    ub = y_aug_hat.flatten() + ubs[sort_order].flatten()\n",
    "    \n",
    "\n",
    "    ax1 = fig.add_subplot(1, 1, 1)\n",
    "    ax1.plot(x_calib, y_calib, '.', alpha=0.2)\n",
    "    if S_val is not None:\n",
    "        ax1.scatter(x_test_val, y_test_val, color='r', s=20, marker='x')\n",
    "        ax1.annotate(f'Dual: {dual_upper}', xy=(x_test_val, y_test_val), \n",
    "                     xytext=(x_test_val + 0.5, y_test_val + 1),\n",
    "                     arrowprops=dict(facecolor='black', shrink=0.05),\n",
    "                    color='red')\n",
    "\n",
    "\n",
    "    ax1.plot(x_aug_s, y_aug_hat, lw=1, color='k')\n",
    "    ax1.plot(x_aug_s, ub, color=cp[0], lw=2)\n",
    "    ax1.plot(x_aug_s, lb, color=cp[0], lw=2)\n",
    "    ax1.fill_between(x_aug_s.flatten(), lb.flatten(), ub.flatten(), \n",
    "                     color=cp[0], alpha=0.2)\n",
    "    ax1.set_ylim(-2,6.5)\n",
    "    ax1.tick_params(axis='both', which='major', labelsize=14)\n",
    "    ax1.set_xlabel(\"$X$\", fontsize=16, labelpad=10)\n",
    "    ax1.set_ylabel(\"$Y$\", fontsize=16, labelpad=10)\n",
    "    ax1.set_title(\"Dual tracking\", fontsize=18, pad=12)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'animation/iteration_dual_{i}.pdf')\n",
    "    plt.close()\n",
    "\n",
    "x_test_val = x_test[1].reshape(1,-1)\n",
    "S_calib = score_fn(x_calib, y_calib).reshape(-1,1)\n",
    "x_aug = np.concatenate((x_calib, x_test_val))\n",
    "Phi = phi_fn(x_aug).astype(float)\n",
    "\n",
    "i = 0\n",
    "for S_val in tqdm(np.linspace(0.5, 1.9, 25)):\n",
    "    y_test_val = reg.predict(poly.fit_transform(x_test_val)) + S_val\n",
    "    generate_plot(i, x_aug, Phi, S_calib, x_test_val, y_test_val, S_val)\n",
    "    i += 1\n",
    "    \n",
    "generate_plot(-1, x_aug[:-1], Phi[:-1], S_calib, None, None, None)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "75484af1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[3.728333]]\n"
     ]
    }
   ],
   "source": [
    "print(x_test_val)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
