{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: PYTHONHASHSEED=0\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import os\n",
    "import sys\n",
    "import sklearn\n",
    "sys.path.insert(1, f\"{os.getcwd()}/backend\")\n",
    "\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "from matplotlib.ticker import NullFormatter\n",
    "import seaborn as sns\n",
    "sns.set(palette=\"bright\",style=\"ticks\",font=\"Arial\")\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib.patches import PathPatch\n",
    "from matplotlib.ticker import FormatStrFormatter\n",
    "\n",
    "from sklearn import linear_model\n",
    "from sklearn.linear_model import HuberRegressor\n",
    "from sklearn import manifold, datasets\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn import preprocessing\n",
    "\n",
    "from functools import partial\n",
    "import cvxpy as cp\n",
    "import pandas as pd\n",
    "import copy\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.metrics.pairwise import rbf_kernel\n",
    "import os.path\n",
    "import pdb\n",
    "import scipy as sp\n",
    "import hashlib\n",
    "import joblib\n",
    "import pickle\n",
    "import pdb\n",
    "\n",
    "from np_backend.dro_conformal import dro_conformal_quantile_procedure_cvx\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%env PYTHONHASHSEED=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"numpy: {np.__version__}\")\n",
    "print(f\"sys (standard library, version not applicable)\")\n",
    "print(f\"matplotlib: {matplotlib.__version__}\")\n",
    "print(f\"seaborn: {sns.__version__}\")\n",
    "print(f\"sklearn: {sklearn.__version__}\")\n",
    "print(f\"cvxpy: {cp.__version__}\")\n",
    "print(f\"pandas: {pd.__version__}\")\n",
    "print(f\"scipy: {sp.__version__}\")\n",
    "print(f\"joblib: {joblib.__version__}\")\n",
    "print(f\"pickle (standard library, version not applicable)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: don't forget to set these v.\n",
    "\n",
    "num_reps = 20\n",
    "dataset = \"bos\"\n",
    "fix_train = True\n",
    "base_models = [\"rf\"] # [\"lm\", \"rf\", \"rob\"]\n",
    "stdize = False\n",
    "\n",
    "alphas = [0.05]\n",
    "# alphas = [0.05, 0.1, 0.2, 0.3, 0.4]\n",
    "alphas.sort()\n",
    "\n",
    "if(dataset == \"gaus\"):\n",
    "    rhos = [0.01]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "\n",
    "elif(dataset == \"airfoil\"):\n",
    "    rhos = [1e-2]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "    \n",
    "elif(dataset == \"abalone\"):\n",
    "    rhos = [1e-2]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "    \n",
    "elif(dataset == \"ca\"):\n",
    "    rhos = [1e-2]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "\n",
    "elif(dataset == \"delta\"):\n",
    "    rhos = [1e-2]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "     \n",
    "elif(dataset == \"ailerons\"):\n",
    "    rhos = [1e-2]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "    \n",
    "elif(dataset == \"bank\"):\n",
    "    rhos = [1e-2]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "    \n",
    "elif(dataset == \"bos\"):\n",
    "    rhos = [1e-2]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "    \n",
    "elif(dataset == \"cpu\"):\n",
    "    rhos = [1e-2]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "    \n",
    "elif(dataset == \"kin\"):\n",
    "    rhos = [1e-2]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "    \n",
    "elif(dataset == \"puma\"):\n",
    "    rhos = [1e-2]\n",
    "    calib_test_shift_pcts = [0, 1]\n",
    "\n",
    "else:\n",
    "    print(\"ERROR: Bad dataset.\")\n",
    "    assert(False)\n",
    "    \n",
    "rhos.sort()\n",
    "\n",
    "kl = lambda z : -cp.entr(z)                     # This is the K-L-ball (it's equivalent to z*cp.log(z),\n",
    "                                                # just written in a way that's DCP-compliant.\n",
    "chisq = lambda z : 0.5 * cp.sum_squares(z-1)    # This is the chi-squared ball.\n",
    "\n",
    "ff= \"../data/jasa_10_07_2023_data/\""
   ]
  },
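  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (not part of the original analysis): both divergence generators defined above should vanish at $z = 1$, i.e. when the candidate test distribution equals the calibration distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check only: kl and chisq are the lambdas from the settings cell.\n",
    "z0 = 1.0\n",
    "print(f\"kl(1) = {kl(z0).value}\")        # expect 0.0: the divergence vanishes with no shift\n",
    "print(f\"chisq(1) = {chisq(z0).value}\")  # expect 0.0"
   ]
  },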
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if(dataset == \"gaus\"):\n",
    "    \n",
    "    my_n = 5000\n",
    "    my_p = 10\n",
    "    \n",
    "    A = np.random.randn(my_p,my_p)\n",
    "    A = A.T @ A\n",
    "    S,Q = np.linalg.eigh(A)\n",
    "    \n",
    "    theta0 = Q[:,0]\n",
    "    theta1 = Q[:,1]\n",
    "    \n",
    "    X = np.random.randn(my_n,my_p)\n",
    "    eps = np.random.randn(my_n)\n",
    "    y = X @ theta0 + eps\n",
    "    \n",
    "    X_shifted = 2 + np.copy(X)\n",
    "    y_shifted = X_shifted @ theta0 + eps\n",
    "    \n",
    "    ts = np.arange(0, 1.1, 0.1)\n",
    "        \n",
    "    np.savetxt(ff+dataset + \"_X.csv\", X, delimiter=\",\")\n",
    "    np.savetxt(ff+dataset + \"_y.csv\", y, delimiter=\",\")        \n",
    "        \n",
    "    np.savetxt(ff+dataset + \"_X_shifted.csv\", X_shifted, delimiter=\",\")\n",
    "    np.savetxt(ff+dataset + \"_y_shifted.csv\", y_shifted, delimiter=\",\")         \n",
    "        \n",
    "elif(dataset == \"airfoil\"):\n",
    "\n",
    "    X = pd.read_csv(ff+dataset + \"_X.csv\", index_col=0).to_numpy()\n",
    "    y = pd.read_csv(ff+dataset + \"_y.csv\", index_col=0).to_numpy()[:,0]\n",
    "    \n",
    "elif((dataset == \"abalone\") or\n",
    "     (dataset == \"ca\") or\n",
    "     (dataset == \"delta\") or\n",
    "     (dataset == \"ailerons\") or\n",
    "     (dataset == \"bank\") or\n",
    "     (dataset == \"bos\") or\n",
    "     (dataset == \"cpu\") or\n",
    "     (dataset == \"kin\") or\n",
    "     (dataset == \"puma\")):\n",
    "    \n",
    "    X = pd.read_csv(ff+dataset + \"_X.csv\").to_numpy()\n",
    "    y = pd.read_csv(ff+dataset + \"_y.csv\").to_numpy()[:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if stdize:\n",
    "    X = preprocessing.scale(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = X.shape[0]\n",
    "p = X.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alg_names_ordered = [\"Standard\", \"K-L\", \"Chi-squared\"]\n",
    "alg_idxes = {\"Standard\":0, \"K-L\":1, \"Chi-squared\":2}\n",
    "\n",
    "shift_taus = [0.5, 0.6, 0.7]\n",
    "\n",
    "shift_fns = [\"I-A\", \"I-B\"]\n",
    "\n",
    "for dim_idx in range(p):\n",
    "    shift_fns += [\"I-C-\" + str(dim_idx)]\n",
    "\n",
    "for shift_tau in shift_taus:\n",
    "    shift_fns += [\"I-D-\" + str(shift_tau)]\n",
    "\n",
    "for dim_idx in range(p):\n",
    "    for shift_tau in shift_taus:\n",
    "        shift_fns += [\"I-E-\" + str(dim_idx) + \"-\" + str(shift_tau)]\n",
    "\n",
    "shift_betas = np.array([0.02, 0.04, 0.08, 0.16, 0.32, 0.64])\n",
    "shift_betas = np.concatenate([shift_betas, -shift_betas])"
   ]
  },
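  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick consistency check (illustrative only): by construction, the shift-function grid above should contain $2 + p + |\\text{shift\\_taus}| + p \\cdot |\\text{shift\\_taus}|$ entries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check (not part of the original analysis).\n",
    "expected_num_shift_fns = 2 + p + len(shift_taus) + p * len(shift_taus)\n",
    "print(len(shift_fns), expected_num_shift_fns)  # the two counts should match"
   ]
  },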
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_algs = len(alg_idxes)\n",
    "num_mods = len(base_models)\n",
    "num_alphas = len(alphas)\n",
    "num_rhos = len(rhos)\n",
    "num_shift_fns = len(shift_fns)\n",
    "num_shift_betas = len(shift_betas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ws = \"  \""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ext_jlib = \"jlib\"\n",
    "ext_pkl = \"pkl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shift(X, calib_idxes, test_idxes, calib_test_shift_pcts, shift_fn, shift_beta, v, base_qtys):\n",
    "    \n",
    "    n_test = len(test_idxes)\n",
    "    X_test_idxes = X[test_idxes,:]\n",
    "    X_test_idxes_demeaned = X_test_idxes - np.mean(X_test_idxes, axis=0)\n",
    "    X_test_idxes_dot_v = base_qtys[\"X_test_idxes_dot_v\"]\n",
    "    X_test_idxes_demeaned_dot_v = base_qtys[\"X_test_idxes_demeaned_dot_v\"]\n",
    "    \n",
    "    X_calib_idxes = X[calib_idxes,:]\n",
    "    X_calib_idxes_dot_v = base_qtys[\"X_calib_idxes_dot_v\"]\n",
    "    \n",
    "    if shift_fn == \"I-A\":\n",
    "        w = np.ones(n_test)\n",
    "    \n",
    "    elif shift_fn == \"I-B\":\n",
    "        w = shift_beta * X_test_idxes_demeaned_dot_v\n",
    "    \n",
    "    elif \"I-C\" in shift_fn:\n",
    "        dim_idx = int(shift_fn.replace(\"I-C-\", \"\"))\n",
    "        w = shift_beta * X_test_idxes_demeaned[:,dim_idx]\n",
    "    \n",
    "    elif \"I-D\" in shift_fn:\n",
    "        shift_tau = float(shift_fn.replace(\"I-D-\", \"\"))\n",
    "        tau = np.quantile(X_calib_idxes_dot_v, shift_tau)\n",
    "        indicators = np.array(X_test_idxes_dot_v >= tau, dtype=\"int\")\n",
    "        w = shift_beta * indicators\n",
    "    \n",
    "    elif \"I-E\" in shift_fn:\n",
    "        dim_idx_shift_tau_str = shift_fn.replace(\"I-E-\", \"\")\n",
    "        dash_idx = dim_idx_shift_tau_str.index(\"-\")\n",
    "\n",
    "        dim_idx_str = dim_idx_shift_tau_str[0:dash_idx]\n",
    "        shift_tau = float(dim_idx_shift_tau_str.replace(dim_idx_str + \"-\", \"\"))\n",
    "        dim_idx = int(dim_idx_str)\n",
    "        \n",
    "        tau = np.quantile(X_calib_idxes[:,dim_idx], shift_tau)\n",
    "        indicators = np.array(X_test_idxes[:,dim_idx] >= tau, dtype=\"int\")\n",
    "        w = shift_beta * indicators\n",
    "        \n",
    "    if \"I-A\" not in shift_fn:\n",
    "        w = np.exp(w - np.max(w))\n",
    "    \n",
    "    w = w/np.sum(w)\n",
    "    \n",
    "    assert(all(w >= 0))\n",
    "    assert(all(w <= 1))\n",
    "    assert(all(np.logical_not(np.isnan(w))))\n",
    "    assert(all(np.logical_not(np.isinf(w))))\n",
    "    \n",
    "    return w"
   ]
  },
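  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A usage sketch for `shift` (illustrative only; every name below is made up for this check): on a tiny synthetic problem, the returned weights should form a valid probability vector over the test points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative toy inputs (not the experiment's data).\n",
    "rng = np.random.default_rng(0)\n",
    "X_toy = rng.standard_normal((20, 3))\n",
    "calib_toy = np.arange(0, 10)\n",
    "test_toy = np.arange(10, 20)\n",
    "v_toy = np.ones(3) / np.sqrt(3.)\n",
    "qtys_toy = {\n",
    "    \"X_test_idxes_dot_v\": X_toy[test_toy, :].dot(v_toy),\n",
    "    \"X_test_idxes_demeaned_dot_v\": (X_toy[test_toy, :] - np.mean(X_toy[test_toy, :], axis=0)).dot(v_toy),\n",
    "    \"X_calib_idxes_dot_v\": X_toy[calib_toy, :].dot(v_toy),\n",
    "}\n",
    "w_toy = shift(X_toy, calib_toy, test_toy, [0, 1], \"I-B\", 0.16, v_toy, qtys_toy)\n",
    "print(np.sum(w_toy))  # expect 1.0 (up to floating point): the tilted weights are normalized"
   ]
  },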
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_coverage_and_length(lo, hi, y, test_idxes, test_idxes_weights):\n",
    "    \n",
    "    n_test = len(test_idxes)\n",
    "    raw_coverages = np.asarray([1. if ((lo[i] <= y[test_idxes[i]]) & (hi[i] >= y[test_idxes[i]]))\n",
    "                                else 0 for i in range(n_test)])\n",
    "    coverages = np.multiply(raw_coverages, test_idxes_weights)\n",
    "\n",
    "    length = np.mean(hi - lo)\n",
    "    \n",
    "    return np.sum(coverages), length, raw_coverages"
   ]
  },
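  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy check for `compute_coverage_and_length` (illustrative only): with uniform weights and intervals that always contain $y$, the weighted coverage should be exactly 1 and the average length should equal the constant interval width."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative toy inputs (not the experiment's data).\n",
    "y_toy = np.array([0., 1., 2., 3.])\n",
    "idx_toy = np.arange(4)\n",
    "w_unif = np.ones(4) / 4.\n",
    "cov_toy, len_toy, _ = compute_coverage_and_length(y_toy - 1., y_toy + 1., y_toy, idx_toy, w_unif)\n",
    "print(cov_toy, len_toy)  # expect 1.0 2.0"
   ]
  },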
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_alpha(alpha_in, n_in):\n",
    "    return np.maximum(1. - (n_in+1.)*(1.-alpha_in)/n_in, 0.0)"
   ]
  },
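  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A numeric check for `update_alpha` (illustrative only): the corrected miscoverage level approaches the nominal $\\alpha$ from below as the calibration set grows, and is clipped at 0 when $n$ is too small for the requested level."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for n_demo in [10, 100, 1000]:\n",
    "    print(n_demo, update_alpha(0.05, n_demo))\n",
    "# e.g. update_alpha(0.05, 100) = 1 - 101*0.95/100 = 0.0405;\n",
    "# at n = 10 the correction clips to 0."
   ]
  },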
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wksp_fp = ff + dataset + \"_wksp.\" + ext_pkl\n",
    "test_idxes_weights_wksp_fp = ff + dataset + \"_test_idxes_weights_wksp.\" + ext_jlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.isfile(wksp_fp):\n",
    "    \n",
    "    wksp = pickle.load(open(wksp_fp, \"rb\")) # joblib.load(wksp_fp)\n",
    "    joblib.load(test_idxes_weights_wksp_fp) # test_idxes_weights_wksp = pickle.load(open(test_idxes_weights_wksp_fp, \"rb\"))\n",
    "\n",
    "    train_idxes_list = wksp[\"train_idxes\"]\n",
    "    base_model_objs_list_dict = wksp[\"base_model_objs\"]\n",
    "\n",
    "    calib_idxes_list = wksp[\"calib_idxes\"]\n",
    "\n",
    "    test_idxes_list = wksp[\"test_idxes\"]\n",
    "    test_idxes_weights_list_dict = test_idxes_weights_wksp[\"test_idxes_weights\"]\n",
    "    v_list = wksp[\"v\"]\n",
    "    \n",
    "    coverages_dict = wksp[\"coverages_dict\"]\n",
    "    raw_coverages_dict = wksp[\"raw_coverages_dict\"]\n",
    "    lengths_dict = wksp[\"lengths_dict\"]\n",
    "    \n",
    "    wksp_fp_exists = True\n",
    "    \n",
    "    if dataset == \"gaus\":\n",
    "        ts = wksp[\"ts\"]\n",
    "    \n",
    "else:\n",
    "    \n",
    "    wksp = {}\n",
    "    test_idxes_weights_wksp = {}\n",
    "    \n",
    "    train_idxes_list = []\n",
    "    base_model_objs_list_dict = []\n",
    "\n",
    "    calib_idxes_list = []\n",
    "\n",
    "    test_idxes_list = []\n",
    "    test_idxes_weights_list_dict = []\n",
    "    v_list = []\n",
    "\n",
    "    if dataset == \"gaus\":\n",
    "        lengths_dict = np.inf*np.ones((num_alphas,num_rhos,num_reps,len(ts),num_algs))\n",
    "        coverages_dict = np.inf*np.ones((num_alphas,num_rhos,num_reps,len(ts),num_algs))\n",
    "\n",
    "    else:        \n",
    "        coverages_dict = {}\n",
    "        raw_coverages_dict = {}\n",
    "        lengths_dict = np.inf*np.ones((num_alphas,num_rhos,num_reps,num_mods,num_algs))\n",
    "        for shift_fn_idx, shift_fn in enumerate(shift_fns):\n",
    "\n",
    "            for shift_beta_idx, shift_beta in enumerate(shift_betas):\n",
    "                if (shift_fn_idx == 0) and (shift_beta_idx > 0):\n",
    "                    continue\n",
    "\n",
    "                coverages_dict[shift_fn, shift_beta] = np.inf*np.ones((num_alphas,num_rhos,num_reps,num_mods,num_algs))\n",
    "    \n",
    "    wksp_fp_exists = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for rep_idx in range(num_reps):\n",
    "    print(\"Trial %d/%d (dataset %s):\" % (rep_idx, num_reps - 1, dataset))\n",
    "    base_qtys = {}\n",
    "    \n",
    "    # 1a) If data already exists (i.e. has already been split up +\n",
    "    # shifted), then load it ... except for the test weights, which\n",
    "    # we will load later on\n",
    "    if wksp_fp_exists:\n",
    "        \n",
    "        train_idxes = train_idxes_list[rep_idx]\n",
    "        n_train = len(train_idxes)        \n",
    "        base_model_objs = base_model_objs_list_dict[rep_idx]\n",
    "        \n",
    "        calib_idxes = calib_idxes_list[rep_idx]\n",
    "        n_calib = len(calib_idxes)\n",
    "        \n",
    "        test_idxes = test_idxes_list[rep_idx]\n",
    "        n_test = len(test_idxes)\n",
    "        test_idxes_weights_dict = test_idxes_weights_list_dict[rep_idx]\n",
    "        v = v_list[rep_idx]\n",
    "        \n",
    "    # 1b) Else we are seeing data for the first time\n",
    "    else:\n",
    "        \n",
    "        # So, split up the data + save it, first\n",
    "\n",
    "        # Training data:\n",
    "        if ((fix_train and (rep_idx == 0)) or not fix_train):\n",
    "            n_train = int(n / 3.)\n",
    "            train_idxes = np.random.choice(range(n), n_train, replace=False)\n",
    "        train_idxes_list += [train_idxes]\n",
    "            \n",
    "        # Calibration data:\n",
    "        n_calib_test = n - n_train\n",
    "        calib_test_idxes = np.setdiff1d(range(n), train_idxes)\n",
    "        n_calib = int(n_calib_test / 2.)\n",
    "        calib_idxes = np.random.choice(calib_test_idxes, n_calib, replace=False)\n",
    "        calib_idxes_list += [calib_idxes]\n",
    "        \n",
    "        # Test data:\n",
    "        test_idxes = np.setdiff1d(calib_test_idxes, calib_idxes)\n",
    "        n_test = len(test_idxes)\n",
    "        test_idxes_list += [test_idxes]\n",
    "        \n",
    "        if dataset == \"gaus\":\n",
    "            base_qtys[\"X_test_idxes_dot_v\"] = None\n",
    "            base_qtys[\"X_test_idxes_demeaned_dot_v\"] = None\n",
    "            base_qtys[\"X_calib_idxes_dot_v\"] = None\n",
    "        \n",
    "        else:\n",
    "            calib_test_idxes = np.concatenate([calib_idxes, test_idxes]).astype(int)\n",
    "            X_calib_test_idxes = X[calib_test_idxes,:]\n",
    "            X_calib_test_idxes -= np.mean(X_calib_test_idxes, axis=0)\n",
    "            _, _, VT = sp.linalg.svd(X_calib_test_idxes)\n",
    "            v = VT[0,:]\n",
    "            v_list += [v]\n",
    "\n",
    "            X_test_idxes = X[test_idxes,:]\n",
    "            X_calib_idxes = X[calib_idxes,:]        \n",
    "            base_qtys[\"X_test_idxes_dot_v\"] = X_test_idxes.dot(v)\n",
    "            base_qtys[\"X_test_idxes_demeaned_dot_v\"] = (X_test_idxes - np.mean(X_test_idxes, axis=0)).dot(v)\n",
    "            base_qtys[\"X_calib_idxes_dot_v\"] = X_calib_idxes.dot(v)\n",
    "\n",
    "            test_idxes_weights_dict = {}\n",
    "            for shift_fn_idx, shift_fn in enumerate(shift_fns):\n",
    "\n",
    "                for shift_beta_idx, shift_beta in enumerate(shift_betas):\n",
    "                    if (shift_fn_idx == 0) and (shift_beta_idx > 0):\n",
    "                        continue\n",
    "                        \n",
    "                    test_idxes_weights = shift(X,\n",
    "                                               calib_idxes,\n",
    "                                               test_idxes,\n",
    "                                               calib_test_shift_pcts,\n",
    "                                               shift_fn,\n",
    "                                               shift_beta,\n",
    "                                               v,\n",
    "                                               base_qtys)\n",
    "                    test_idxes_weights_dict[shift_fn, shift_beta] = test_idxes_weights\n",
    "            test_idxes_weights_list_dict += [test_idxes_weights_dict]\n",
    "            print()\n",
    "\n",
    "        # b) Fit the base estimator(s) to the training data\n",
    "        if ((fix_train and (rep_idx == 0)) or not fix_train):\n",
    "            base_model_objs_dict = {}\n",
    "\n",
    "            for mod_idx, base_model in enumerate(base_models):\n",
    "\n",
    "                if(base_model == \"lm\"):\n",
    "                    model_obj = linear_model.LinearRegression()\n",
    "                elif(base_model == \"rf\"):\n",
    "                    model_obj = RandomForestRegressor(\n",
    "                        random_state=0)\n",
    "                elif(base_model == \"rob\"):\n",
    "                    model_obj = HuberRegressor()\n",
    "                else:\n",
    "                    print(ws + \"ERROR: Bad base_model.\")\n",
    "                    assert(False)\n",
    "                model_obj.fit(X[train_idxes, :], y[train_idxes])        \n",
    "                base_model_objs_dict[base_model] = model_obj        \n",
    "        base_model_objs_list_dict += [base_model_objs_dict]\n",
    "                    \n",
    "    # 1c) Report some stats\n",
    "    print(ws + \"Just FYI, here are the various data set sizes:\")\n",
    "    print(ws*2 + \"Train: %d.\" % n_train)\n",
    "    print(ws*2 + \"Calibration: %d.\" % n_calib)\n",
    "    print(ws*2 + \"Test: %d.\" % n_test)\n",
    "    print()\n",
    "    \n",
    "    # 2) Compute residuals and scores\n",
    "    if dataset == \"gaus\":\n",
    "                \n",
    "        alpha_idx = 0\n",
    "        alpha = alphas[alpha_idx]\n",
    "        \n",
    "        w = np.ones(n_test)\n",
    "        test_idxes_weights = w / np.sum(w)\n",
    "        \n",
    "        for t_idx, t in enumerate(ts):\n",
    "            print(ws*2 + \"Processing t=\" + str(t) + \".\")\n",
    "            thetat = np.sqrt(1-t**2)*theta0 + t*theta1\n",
    "            muhat = X @ thetat\n",
    "            S = (X @ thetat - y)**2\n",
    "\n",
    "            for rho_idx, rho in enumerate(rhos):\n",
    "\n",
    "                for alg in alg_names_ordered:\n",
    "                    print(ws*3 + \"Computing \" + alg + \" intervals.\")\n",
    "\n",
    "                    alg_idx = alg_idxes[alg]\n",
    "                    if alg == \"Standard\":\n",
    "                        q = np.quantile(S, 1.0-update_alpha(alpha, n_calib))\n",
    "\n",
    "                    elif alg == \"K-L\":\n",
    "                        q, _ = dro_conformal_quantile_procedure_cvx(S, kl, update_alpha(\n",
    "                            alpha, n_calib), rho, want_bisection=True, verbose=False)\n",
    "\n",
    "                    elif alg == \"Chi-squared\":\n",
    "                        q, _ = dro_conformal_quantile_procedure_cvx(\n",
    "                            S, chisq, update_alpha(alpha, n_calib), rho, want_bisection=True)\n",
    "\n",
    "                    q = np.sqrt(q)\n",
    "                    hi = np.add(muhat[test_idxes], q)\n",
    "                    lo = np.subtract(muhat[test_idxes], q)\n",
    "\n",
    "                    coverage, length, raw_coverages = compute_coverage_and_length(\n",
    "                        lo, hi, y_shifted, test_idxes, test_idxes_weights)\n",
    "\n",
    "                    coverages_dict[alpha_idx,\n",
    "                                   rho_idx,\n",
    "                                   rep_idx,\n",
    "                                   t_idx,\n",
    "                                   alg_idx] = coverage\n",
    "                                               \n",
    "                    lengths_dict[alpha_idx,\n",
    "                                 rho_idx,\n",
    "                                 rep_idx,\n",
    "                                 t_idx,\n",
    "                                 alg_idx] = length                \n",
    "                        \n",
    "    else:    \n",
    "\n",
    "        for mod_idx, base_model in enumerate(base_models):        \n",
    "\n",
    "            model_obj = base_model_objs_list_dict[rep_idx][base_model]\n",
    "            muhat = model_obj.predict(X)            \n",
    "            S = np.abs(y[calib_idxes] - muhat[calib_idxes])\n",
    "\n",
    "            # 3) Go thru all possible mis-coverage levels, rho's, conformalization procedures\n",
    "            for alpha_idx, alpha in enumerate(alphas):\n",
    "\n",
    "                for rho_idx, rho in enumerate(rhos):\n",
    "\n",
    "                    for alg in alg_names_ordered:\n",
    "                        print(ws*4 + \"Computing \" + alg + \" intervals.\")\n",
    "\n",
    "                        alg_idx = alg_idxes[alg]\n",
    "                        if alg == \"Standard\":\n",
    "                            q = np.quantile(S, 1.0-update_alpha(alpha, n_calib))\n",
    "\n",
    "                        elif alg == \"K-L\":\n",
    "                            q, _ = dro_conformal_quantile_procedure_cvx(S, kl, update_alpha(\n",
    "                                alpha, n_calib), rho, want_bisection=True, verbose=False)\n",
    "\n",
    "                        elif alg == \"Chi-squared\":\n",
    "                            q, _ = dro_conformal_quantile_procedure_cvx(\n",
    "                                S, chisq, update_alpha(alpha, n_calib), rho, want_bisection=True)\n",
    "\n",
    "                        hi = np.add(muhat[test_idxes], q)\n",
    "                        lo = np.subtract(muhat[test_idxes], q)\n",
    "\n",
    "                        # Evaluate (weighted) coverage\n",
    "                        for shift_fn_idx, shift_fn in enumerate(shift_fns):\n",
    "\n",
    "                            for shift_beta_idx, shift_beta in enumerate(shift_betas):\n",
    "                                if (shift_fn_idx == 0) and (shift_beta_idx > 0):\n",
    "                                    continue\n",
    "\n",
    "                                test_idxes_weights = test_idxes_weights_list_dict[rep_idx][shift_fn, shift_beta]\n",
    "                                coverage, length, raw_coverages = compute_coverage_and_length(\n",
    "                                    lo, hi, y, test_idxes, test_idxes_weights)\n",
    "\n",
    "                                coverages_dict[shift_fn, shift_beta][alpha_idx,\n",
    "                                                                     rho_idx,\n",
    "                                                                     rep_idx,\n",
    "                                                                     mod_idx,\n",
    "                                                                     alg_idx] = coverage\n",
    "\n",
    "                                raw_coverages_dict[shift_fn, shift_beta, rep_idx, alg_idx] = raw_coverages\n",
    "\n",
    "                                if shift_fn_idx == 0:                                \n",
    "                                    lengths_dict[alpha_idx,\n",
    "                                                 rho_idx,\n",
    "                                                 rep_idx,\n",
    "                                                 mod_idx,\n",
    "                                                 alg_idx] = length\n",
    "\n",
    "    # 5) Write out coverages and lengths on the last trial\n",
    "    print()\n",
    "    if(rep_idx == (num_reps-1)):\n",
    "        \n",
    "        print(\"Finalizing workspace object ...\")\n",
    "        \n",
    "        if dataset == \"gaus\":\n",
    "            wksp[\"num_reps\"] = num_reps\n",
    "            wksp[\"dataset\"] = dataset\n",
    "            wksp[\"fix_train\"] = fix_train\n",
    "            wksp[\"base_models\"] = None\n",
    "            wksp[\"stdize\"] = stdize\n",
    "            wksp[\"alphas\"] = alphas\n",
    "            wksp[\"rhos\"] = rhos\n",
    "            wksp[\"shift_fns\"] = shift_fns\n",
    "            wksp[\"shift_betas\"] = shift_betas\n",
    "\n",
    "            wksp[\"train_idxes\"] = train_idxes_list\n",
    "            wksp[\"base_model_objs\"] = None\n",
    "\n",
    "            wksp[\"calib_idxes\"] = calib_idxes_list\n",
    "\n",
    "            wksp[\"test_idxes\"] = test_idxes_list\n",
    "            test_idxes_weights_wksp[\"test_idxes_weights\"] = None\n",
    "            wksp[\"v\"] = None\n",
    "\n",
    "            wksp[\"coverages_dict\"] = coverages_dict\n",
    "            wksp[\"raw_coverages_dict\"] = None\n",
    "            wksp[\"lengths_dict\"] = lengths_dict\n",
    "            \n",
    "            wksp[\"ts\"] = ts\n",
    "        \n",
    "        else:        \n",
    "            wksp[\"num_reps\"] = num_reps\n",
    "            wksp[\"dataset\"] = dataset\n",
    "            wksp[\"fix_train\"] = fix_train\n",
    "            wksp[\"base_models\"] = base_models\n",
    "            wksp[\"stdize\"] = stdize\n",
    "            wksp[\"alphas\"] = alphas\n",
    "            wksp[\"rhos\"] = rhos\n",
    "            wksp[\"shift_fns\"] = shift_fns\n",
    "            wksp[\"shift_betas\"] = shift_betas\n",
    "\n",
    "            wksp[\"train_idxes\"] = train_idxes_list\n",
    "            wksp[\"base_model_objs\"] = base_model_objs_list_dict\n",
    "\n",
    "            wksp[\"calib_idxes\"] = calib_idxes_list\n",
    "\n",
    "            wksp[\"test_idxes\"] = test_idxes_list\n",
    "            test_idxes_weights_wksp[\"test_idxes_weights\"] = test_idxes_weights_list_dict\n",
    "            wksp[\"v\"] = v_list\n",
    "\n",
    "            wksp[\"coverages_dict\"] = coverages_dict\n",
    "            wksp[\"raw_coverages_dict\"] = raw_coverages_dict\n",
    "            wksp[\"lengths_dict\"] = lengths_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not wksp_fp_exists:\n",
    "    pickle.dump(wksp, open(wksp_fp, \"wb\"))\n",
    "    joblib.dump(test_idxes_weights_wksp, test_idxes_weights_wksp_fp)\n",
    "    print(\"Saved workspaces to %s and %s.\" % (wksp_fp, test_idxes_weights_wksp_fp))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"All done.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
