{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "from matplotlib.ticker import NullFormatter\n",
    "%matplotlib inline\n",
    "import seaborn as sns\n",
    "sns.set(style=\"ticks\")\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import geopandas\n",
    "import typing\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "import cvxpy as cp\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import covidcast\n",
    "import datetime\n",
    "import sklearn\n",
    "import sklearn.preprocessing\n",
    "import copy\n",
    "import pynndescent\n",
    "from functools import reduce\n",
    "import os\n",
    "import pickle\n",
    "import pymde\n",
    "from sklearn.linear_model import QuantileRegressor\n",
    "from sklearn.metrics import pairwise_distances\n",
    "from enum import Enum\n",
    "import statsmodels\n",
    "from statsmodels.graphics import gofplots\n",
    "import csv\n",
    "\n",
    "import sys\n",
    "sys.path.insert(1, f\"{os.getcwd()}/backend\")\n",
    "from np_backend.dro_conformal import *\n",
    "import numpy as np\n",
    "np.random.seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy: 1.23.5\n",
      "matplotlib: 3.7.0\n",
      "seaborn: 0.12.2\n",
      "geopandas: 0.14.0\n",
      "cvxpy: 1.4.0\n",
      "pandas: 1.5.3\n",
      "scipy: 1.10.0\n",
      "sklearn: 1.2.1\n",
      "pynndescent: 0.5.10\n",
      "pymde: 0.1.18\n",
      "statsmodels: 0.13.5\n"
     ]
    }
   ],
   "source": [
    "print(f\"numpy: {np.__version__}\")\n",
    "print(f\"matplotlib: {matplotlib.__version__}\")\n",
    "print(f\"seaborn: {sns.__version__}\")\n",
    "print(f\"geopandas: {geopandas.__version__}\")\n",
    "print(f\"cvxpy: {cp.__version__}\")\n",
    "print(f\"pandas: {pd.__version__}\")\n",
    "print(f\"scipy: {sp.__version__}\")\n",
    "print(f\"sklearn: {sklearn.__version__}\")\n",
    "print(f\"pynndescent: {pynndescent.__version__}\")\n",
    "print(f\"pymde: {pymde.__version__}\")\n",
    "print(f\"statsmodels: {statsmodels.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "B_func = lambda n: n // 4\n",
    "\n",
    "lam_base = 2\n",
    "lam_exps = np.arange(-10, 10, dtype=float)\n",
    "lams = lam_base**lam_exps\n",
    "lams_orig = np.copy(lams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EnsembleType = Enum(\"EnsembleType\", \n",
    "                    [\"Bagged\", \"Stacked\", \"Multitask\", \"PureLocal\"])\n",
    "\n",
    "ensemble_type = EnsembleType.Bagged\n",
    "# ensemble_type = EnsembleType.Stacked\n",
    "# ensemble_type = EnsembleType.Multitask\n",
    "# ensemble_type = EnsembleType.PureLocal\n",
    "\n",
    "if ensemble_type != EnsembleType.Multitask:\n",
    "    lams = None\n",
    "else:\n",
    "    lams = np.copy(lams_orig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "want_robust_intervals = True\n",
    "alpha = 0.1\n",
    "kl = lambda z : -cp.entr(z)\n",
    "adjust_alpha = lambda my_alpha, my_n_val: np.maximum(1. - (1. + 1./my_n_val)*(1. - my_alpha), 0.)"
   ]
  },
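  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Illustrative note (not used by the analysis): `adjust_alpha` matches the standard finite-sample split-conformal correction, replacing the target miscoverage `alpha` with `1 - (1 + 1/n_val)*(1 - alpha)` (clipped at 0), so the score quantile below is taken at level `(1 + 1/n_val)*(1 - alpha)` rather than `1 - alpha`. A minimal sketch with arbitrary example values of `n_val`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: how the adjusted miscoverage level behaves as the number of\n",
    "# validation (calibration) points grows. These n_val values are arbitrary examples.\n",
    "for demo_n_val in [50, 99, 1000]:\n",
    "    adj = adjust_alpha(alpha, demo_n_val)\n",
    "    print(f\"n_val = {demo_n_val:5d}: adjusted alpha = {adj:.4f}, quantile level = {1. - adj:.4f}\")"
   ]
  },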
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "ff= \"../data/jasa_10_07_2023_data/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_date = datetime.date(2021,1,21)\n",
    "end_date = datetime.date(2021,9,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.set_option('display.max_rows', 100)\n",
    "pd.set_option('display.max_columns', 100)\n",
    "pd.options.display.max_info_columns = 100\n",
    "pd.options.display.max_seq_items = 100\n",
    "\n",
    "datetime_str_f = \"%m-%d-%Y\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_pickle(ff + \"df_\" + str(\"2021-10-05\") + \".pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_col = \"indicator-combination_confirmed_7dav_incidence_prop_0_value_nyt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(ff+'county_data_col_names.csv', 'r') as fp:\n",
    "    csv_reader = csv.reader(fp, delimiter=\",\")\n",
    "    for line in csv_reader:\n",
    "        X_counties_col_names = line\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_cols_prefix = [y_col] + line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_wks = ((end_date - start_date).days // 7)\n",
    "t = start_date\n",
    "one_wk = datetime.timedelta(weeks=1)\n",
    "\n",
    "Xy_tot = pd.read_pickle(ff + \"Xy_tot_\" + str(\"2021-10-05\") + \".pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_logit(y, a=1e-2):\n",
    "    y_out = np.copy(y)\n",
    "    for idx in range(y_out.shape[0]):\n",
    "        x = y_out[idx]\n",
    "        y_out[idx] = np.log((x+a)/(1-x+a))\n",
    "    return y_out\n",
    "\n",
    "def apply_inverse_logit(y, a=1e-2):\n",
    "    y_out = np.copy(y)\n",
    "    for idx in range(y_out.shape[0]):\n",
    "        x = y_out[idx]\n",
    "        # y_out[idx] = 1/(1+np.exp(-x)) - a\n",
    "        y_out[idx] = (np.exp(x) * (1 + a) - a) / (1 + np.exp(x))\n",
    "    return y_out\n",
    "\n",
    "def compute_scores(yhat, y, yprev=None,\n",
    "                   want_apply_inverse_logit=True, a=0): # a=1e-2\n",
    "    num = np.abs(yhat - y)        \n",
    "    # den = np.abs(yprev - y)\n",
    "    den = 1 if yprev is None else np.maximum(np.abs(yprev - y), a)\n",
    "    return num/den\n",
    "\n",
    "def pick_out_complement(v, all_idxes, some_idxes):\n",
    "    some_idxes_complement = [idx for idx in all_idxes if idx not in some_idxes]\n",
    "    \n",
    "    if len(v.shape) == 1:\n",
    "        return v[some_idxes_complement]\n",
    "    else:\n",
    "        return v[some_idxes_complement, :]\n",
    "    \n",
    "add_suffix_to_X_cols_prefix = lambda suffix : [X_col_prefix + suffix for X_col_prefix in X_cols_prefix]\n",
    "\n",
    "def get_train_val_test_data(t, t_idx, one_wk, num_wks, datetime_str_f,\n",
    "                            Xy_tot, X_cols_prefix, y_col):\n",
    "    \n",
    "    t_str = t.strftime(datetime_str_f)\n",
    "    \n",
    "    tplus1_str = (t+one_wk).strftime(datetime_str_f)\n",
    "    tplus2_str = (t+2*one_wk).strftime(datetime_str_f)\n",
    "    tplus3_str = (t+3*one_wk).strftime(datetime_str_f)\n",
    "\n",
    "    print(\"Processing (t_idx = \" + str(t_idx) + \") ...\")\n",
    "    print(\"\\tTrain set range: [\" + t_str + \", \" + tplus1_str + \")\")\n",
    "    print(\"\\tVal. set range: [\" + tplus1_str + \", \" + tplus2_str + \")\")\n",
    "    print(\"\\tTest set range: [\" + tplus2_str + \", \" + tplus3_str + \")\")\n",
    "    print()\n",
    "\n",
    "    if t_idx == 0:\n",
    "        X_train = Xy_tot[X_cols_prefix].to_numpy()\n",
    "        X_train[:,0] = apply_logit(X_train[:,0])\n",
    "    else:\n",
    "        X_train = Xy_tot[add_suffix_to_X_cols_prefix(\"_\" + t_str)].to_numpy()\n",
    "        X_train[:,0] = apply_logit(X_train[:,0])\n",
    "    y_train = Xy_tot[y_col + \"_\" + tplus1_str].to_numpy()\n",
    "\n",
    "    X_val = Xy_tot[add_suffix_to_X_cols_prefix(\"_\" + tplus1_str)].to_numpy()\n",
    "    X_val[:,0] = apply_logit(X_val[:,0])\n",
    "    y_val = Xy_tot[y_col + \"_\" + tplus2_str].to_numpy()\n",
    "    n_val = y_val.shape[0]\n",
    "\n",
    "    X_test = Xy_tot[add_suffix_to_X_cols_prefix(\"_\" + tplus2_str)].to_numpy()\n",
    "    X_test[:,0] = apply_logit(X_test[:,0])\n",
    "    y_test = Xy_tot[y_col + \"_\" + tplus3_str].to_numpy()\n",
    "    n_test = y_test.shape[0]\n",
    "\n",
    "    y_train_h = apply_logit(y_train)\n",
    "    y_val_h = apply_logit(y_val)\n",
    "    y_test_h = apply_logit(y_test)\n",
    "\n",
    "    return X_train, y_train, y_train_h, \\\n",
    "            X_val, y_val, y_val_h, n_val, \\\n",
    "            X_test, y_test, y_test_h, n_test\n",
    "\n",
    "def compute_len_cvg(hi, lo, y_test, n_test):\n",
    "    my_len = np.mean(hi-lo)\n",
    "    my_cvg = np.array([1. if ((lo[idx] <= y_test[idx]) & (hi[idx] >= y_test[idx])) else 0. \n",
    "                       for idx in range(n_test)])\n",
    "    return my_len, my_cvg\n",
    "\n",
    "def gaussian_kernel_matrix(X,Y, sigma):\n",
    "    if X.ndim == 1:\n",
    "        X = X[:,None]\n",
    "    if Y.ndim == 1:\n",
    "        Y = Y[:,None]\n",
    "    pw_dist = - 2.0 * X.dot(\n",
    "        Y.T) + (X**2).sum(\n",
    "        axis=1, keepdims=True\n",
    "    ) + (Y**2).sum(\n",
    "        axis=1, keepdims=True\n",
    "    ).T\n",
    "    return np.exp(- pw_dist / sigma)\n",
    "\n",
    "def kl_M1_estimator(K_YY, K_YX, lambda_n):\n",
    "    n, p= K_YX.shape\n",
    "    alpha = cp.Variable(p)\n",
    "\n",
    "    kl_obj = cp.Minimize(\n",
    "        1.0 / (2 * lambda_n) * (\n",
    "            cp.quad_form(alpha, K_YY) - 2 * cp.sum(cp.multiply(alpha,K_YX.dot(np.ones(n)/n)))\n",
    "        ) - 1.0 / p * cp.sum(cp.log(p * alpha))\n",
    "    )\n",
    "    kl_prob = cp.Problem(kl_obj, constraints=[alpha>=0])\n",
    "    kl_prob.solve(solver=cp.SCS)\n",
    "\n",
    "    return -1.0 / p * np.sum(np.log(p*alpha.value)), alpha.value\n",
    "\n",
    "def kl_bregman_div(x,y):\n",
    "    return x*np.log(x/y) - x + y\n",
    "\n",
    "def kl_M2_estimator(K_XX, K_XY, lambda_n):\n",
    "    n, p= K_XY.shape\n",
    "    alpha = cp.Variable(n)\n",
    "\n",
    "    kl_obj = cp.Minimize(\n",
    "        1.0 / (2 * lambda_n) * (\n",
    "            cp.quad_form(alpha, K_XX) - 2 * cp.sum(cp.multiply(alpha,K_XY.dot(np.ones(p)/p)))\n",
    "        ) + cp.sum(cp.kl_div(alpha, np.ones(n)/n))\n",
    "    )\n",
    "    kl_prob = cp.Problem(kl_obj, constraints=[alpha>=0])\n",
    "    kl_prob.solve(solver=cp.SCS)\n",
    "\n",
    "    return kl_bregman_div(alpha.value, np.ones(n)/n).sum(), alpha.value\n",
    "\n",
    "def run_wainwright_to_est_shift(X1, X2, lambda_n, gaus_sigma=1.0):\n",
    "    K_XX = gaussian_kernel_matrix(X1,X1, sigma=gaus_sigma)\n",
    "    K_YY = gaussian_kernel_matrix(X2,X2, sigma=gaus_sigma)\n",
    "    K_XY = gaussian_kernel_matrix(X1,X2, sigma=gaus_sigma)\n",
    "\n",
    "    kl1, alpha_1 = kl_M1_estimator(K_YY,K_XY.T, lambda_n)\n",
    "    kl2, alpha_2 = kl_M2_estimator(K_XX, K_XY, lambda_n)\n",
    "\n",
    "    return kl1, kl2\n",
    "\n",
    "def fit_and_eval_ensemble_model_on_test_data(X_train, y_train_h, hard_idxes,\n",
    "                                             X_val, y_val_h,\n",
    "                                             X_test, y_test_h,\n",
    "                                             datetime_str_f, t_idx_start, start_date, one_wk, num_wks, Xy_tot, y_col,\n",
    "                                             ensemble_type, lams=None):\n",
    "    if hard_idxes:\n",
    "        if ensemble_type == EnsembleType.Bagged:\n",
    "            mixed_model_obj = BaggedOrStackedModel()\n",
    "        elif ensemble_type == EnsembleType.Stacked:\n",
    "            mixed_model_obj = BaggedOrStackedModel()\n",
    "            mixed_model_obj.ensemble_type = EnsembleType.Stacked\n",
    "        elif ensemble_type == EnsembleType.Multitask:\n",
    "            mixed_model_obj = MultitaskModel(regularization=lams)\n",
    "            mixed_model_obj.ensemble_type = EnsembleType.Multitask\n",
    "        elif ensemble_type == EnsembleType.PureLocal:\n",
    "            mixed_model_obj = PureLocalStrategy()\n",
    "            mixed_model_obj.X_cols_idxes_for_dist_calc = [X_cols_prefix.index(\"pclat10\"),\n",
    "                                                          X_cols_prefix.index(\"pclon10\")]\n",
    "            mixed_model_obj.hard_idxes = hard_idxes\n",
    "            \n",
    "        mixed_model_obj.fit(y_train_h, X_train, idxes=hard_idxes, y_val_h=y_val_h, X_val=X_val)\n",
    "        mixture_weights = mixed_model_obj.coeffs_hat\n",
    "\n",
    "        yhat_test = mixed_model_obj.predict(X_test)\n",
    "        AE = compute_scores(apply_inverse_logit(yhat_test), apply_inverse_logit(y_test_h),\n",
    "                            apply_inverse_logit(y_val_h), True)\n",
    "        \n",
    "    else:\n",
    "        model_obj = Model()\n",
    "        model_obj.fit(y_train_h, X_train)\n",
    "        mixture_weights = None\n",
    "        \n",
    "        yhat_test = model_obj.predict(X_test)\n",
    "        AE = compute_scores(apply_inverse_logit(yhat_test), apply_inverse_logit(y_test_h),\n",
    "                            apply_inverse_logit(y_val_h), True)\n",
    "\n",
    "    AEs = [None]*num_wks\n",
    "    \n",
    "    return None, None, \\\n",
    "           AE, AEs, \\\n",
    "           mixture_weights\n",
    "\n",
    "def compute_robust_pred_ints_stats(hist, alg_idx, num_wks):\n",
    "    my_cvg = np.mean(np.concatenate([hist[t_idx].cvgs[alg_idx,:] for t_idx in range(num_wks-3)]))\n",
    "    my_len = np.mean(np.concatenate([hist[t_idx].lens[alg_idx,:] for t_idx in range(num_wks-3)]))\n",
    "    \n",
    "    my_his = [hist[t_idx].his[alg_idx,:] for t_idx in range(num_wks-3)]\n",
    "    my_los = [np.maximum(hist[t_idx].los[alg_idx,:], 0) for t_idx in range(num_wks-3)]\n",
    "    \n",
    "    true = [hist[t_idx].true[alg_idx,:] for t_idx in range(num_wks-3)]\n",
    "    \n",
    "    return my_cvg, my_len, my_his, my_los, true\n",
    "\n",
    "def compute_robust_pred_ints(scores_val, alpha, n_val, yhat_test, y_train_h, y_val_h, y_test_h):\n",
    "    y_train_h_sig = apply_inverse_logit(y_train_h)\n",
    "    n_train = y_train_h_sig.shape[0]\n",
    "    \n",
    "    y_val_h_sig = apply_inverse_logit(y_val_h)\n",
    "    n_val = y_val_h_sig.shape[0]\n",
    "    \n",
    "    yhat_test_sig = apply_inverse_logit(yhat_test)\n",
    "    y_test_h_sig = apply_inverse_logit(y_test_h)\n",
    "    n_test = y_test_h_sig.shape[0]\n",
    "    \n",
    "    his = np.inf*np.ones((5, n_test))\n",
    "    los = np.inf*np.ones((5, n_test))\n",
    "    true = np.inf*np.ones((5, n_test))\n",
    "    cvgs = np.inf*np.ones((5, n_test))\n",
    "    lens = np.inf*np.ones((5, n_test))\n",
    "    preds = np.inf*np.ones((5, n_test))\n",
    "    \n",
    "    # Run standard conformal:\n",
    "    q_std = np.quantile(scores_val, 1. - adjust_alpha(alpha, n_val))\n",
    "    hi_std = yhat_test_sig + q_std\n",
    "    lo_std = yhat_test_sig - q_std\n",
    "    my_len_std, my_cvg_std = compute_len_cvg(hi_std, lo_std, y_test_h_sig, n_test)\n",
    "    \n",
    "    conformal_alg_idx = 0\n",
    "    his[conformal_alg_idx, :] = hi_std\n",
    "    los[conformal_alg_idx, :] = lo_std\n",
    "    true[conformal_alg_idx, :] = y_test_h_sig\n",
    "    cvgs[conformal_alg_idx, :] = my_cvg_std\n",
    "    lens[conformal_alg_idx, :] = my_len_std\n",
    "    preds[conformal_alg_idx, :] = yhat_test_sig\n",
    "    \n",
    "    # Just cut alpha by half:\n",
    "    q_half = np.quantile(scores_val, 1. - adjust_alpha(alpha/2., n_val))\n",
    "    hi_half = yhat_test_sig + q_half\n",
    "    lo_half = yhat_test_sig - q_half\n",
    "    my_len_half, my_cvg_half = compute_len_cvg(hi_half, lo_half, y_test_h_sig, n_test)\n",
    "    \n",
    "    conformal_alg_idx = 1\n",
    "    his[conformal_alg_idx, :] = hi_half\n",
    "    los[conformal_alg_idx, :] = lo_half\n",
    "    true[conformal_alg_idx, :] = y_test_h_sig\n",
    "    cvgs[conformal_alg_idx, :] = my_cvg_half\n",
    "    lens[conformal_alg_idx, :] = my_len_half\n",
    "    preds[conformal_alg_idx, :] = yhat_test_sig\n",
    "    \n",
    "    # Run Wainwright's stuff w/ k11:\n",
    "    rho_k11, rho_k12 = run_wainwright_to_est_shift(y_train_h_sig.reshape(-1,1),\n",
    "                                                   y_val_h_sig.reshape(-1,1),\n",
    "                                                   lambda_n=1.0/np.minimum(n_train, n_val))\n",
    "    q_wain_k11, _ = dro_conformal_quantile_procedure_cvx(scores_val, kl, adjust_alpha(\n",
    "        alpha, n_val), rho_k11, want_bisection=True, verbose=False, solver=cp.SCS)\n",
    "    hi_wain_k11 = yhat_test_sig + q_wain_k11\n",
    "    lo_wain_k11 = yhat_test_sig - q_wain_k11\n",
    "    my_len_wain_k11, my_cvg_wain_k11 = compute_len_cvg(hi_wain_k11, lo_wain_k11, y_test_h_sig, n_test)\n",
    "    \n",
    "    conformal_alg_idx = 2\n",
    "    his[conformal_alg_idx, :] = hi_wain_k11\n",
    "    los[conformal_alg_idx, :] = lo_wain_k11\n",
    "    true[conformal_alg_idx, :] = y_test_h_sig\n",
    "    cvgs[conformal_alg_idx, :] = my_cvg_wain_k11\n",
    "    lens[conformal_alg_idx, :] = my_len_wain_k11\n",
    "    preds[conformal_alg_idx, :] = yhat_test_sig\n",
    "    \n",
    "    # Run Wainwright's stuff w/ k12:\n",
    "    q_wain_k12, _ = dro_conformal_quantile_procedure_cvx(scores_val, kl, adjust_alpha(\n",
    "        alpha, n_val), rho_k12, want_bisection=True, verbose=False, solver=cp.SCS)\n",
    "    hi_wain_k12 = yhat_test_sig + q_wain_k12\n",
    "    lo_wain_k12 = yhat_test_sig - q_wain_k12\n",
    "    my_len_wain_k12, my_cvg_wain_k12 = compute_len_cvg(hi_wain_k12, lo_wain_k12, y_test_h_sig, n_test)\n",
    "    \n",
    "    conformal_alg_idx = 3\n",
    "    his[conformal_alg_idx, :] = hi_wain_k12\n",
    "    los[conformal_alg_idx, :] = lo_wain_k12\n",
    "    true[conformal_alg_idx, :] = y_test_h_sig\n",
    "    cvgs[conformal_alg_idx, :] = my_cvg_wain_k12\n",
    "    lens[conformal_alg_idx, :] = my_len_wain_k12\n",
    "    preds[conformal_alg_idx, :] = yhat_test_sig\n",
    "    \n",
    "#     # Use Alg. 2 from the robust_cv paper:\n",
    "#     q_alg2 = learnable_direction_quantile(y_val_h.reshape(-1,1), scores_val, np.arange(y_val_h.shape[0]))\n",
    "#     hi_alg2 = yhat_test_sig + q_alg2\n",
    "#     lo_alg2 = yhat_test_sig - q_alg2\n",
    "#     my_len_alg2, my_cvg_alg2 = compute_len_cvg(hi_alg2, lo_alg2, y_test_h_sig, n_test)\n",
    "    \n",
    "#     conformal_alg_idx = 4\n",
    "#     his[conformal_alg_idx, :] = hi_alg2\n",
    "#     los[conformal_alg_idx, :] = lo_alg2\n",
    "#     true[conformal_alg_idx, :] = y_test_h_sig\n",
    "#     cvgs[conformal_alg_idx, :] = my_cvg_alg2\n",
    "#     lens[conformal_alg_idx, :] = my_len_alg2\n",
    "#     preds[conformal_alg_idx, :] = yhat_test_sig\n",
    "    \n",
    "    return his, los, true, cvgs, lens, preds\n",
    "\n",
    "    \n",
    "def make_df(idxes, intensities, fips,\n",
    "            wk_idx, start_date, one_wk):\n",
    "    # intensities = apply_inverse_logit(intensities)\n",
    "    selected = np.copy(intensities)\n",
    "    if idxes is not None:\n",
    "        for idx in range(len(selected)):\n",
    "            if idx not in idxes:\n",
    "                selected[idx] = np.nan # np.nan # 0 # np.nan\n",
    "    \n",
    "    num_rows = fips.shape[0]\n",
    "    time_value = start_date + wk_idx*one_wk\n",
    "    geo_type = \"county\"\n",
    "    data_source = \"indicator-combination\"\n",
    "    signal = \"confirmed_7dav_incidence_prop\"\n",
    "    null = np.nan*np.ones(num_rows)\n",
    "    \n",
    "    df_columns = [\"geo_value\", \"time_value\",\n",
    "                  \"issue\", \"lag\", \"missing_value\", \"missing_stderr\", \"missing_sample_size\",\n",
    "                  \"value\",\n",
    "                  \"stderr\", \"sample_size\",\n",
    "                  \"geo_type\", \"data_source\", \"signal\"]\n",
    "\n",
    "    df_data = {df_columns[0]:fips, df_columns[1]:time_value,\n",
    "               df_columns[2]:null, df_columns[3]:null, df_columns[4]:null, df_columns[5]:null, df_columns[6]:null,\n",
    "               df_columns[7]:selected,\n",
    "               df_columns[8]:null, df_columns[9]:null,\n",
    "               df_columns[10]:geo_type, df_columns[11]:data_source, df_columns[12]:signal}    \n",
    "\n",
    "    df = pd.DataFrame(df_data, columns=df_columns)\n",
    "\n",
    "    for idx in df[\"geo_value\"].index:\n",
    "        cur_fips = int(df.loc[idx, \"geo_value\"])\n",
    "        cur_fips_str = str(cur_fips).zfill(5)\n",
    "        df.loc[idx, \"geo_value\"] = cur_fips_str\n",
    "            \n",
    "    return df, df_data, df_columns\n",
    "\n",
    "def plot(data, time_value=None, combine_megacounties=True, plot_type=\"choropleth\",\n",
    "         vmax_mean=1, vmax_std=0, cbar=True, for_res_stmt=False, my_title=None, **kwargs):\n",
    "    \n",
    "    data_source, signal, geo_type = covidcast.plotting._detect_metadata(data)  # pylint: disable=W0212\n",
    "    meta = covidcast.plotting._signal_metadata(data_source, signal, geo_type)  # pylint: disable=W0212\n",
    "    # use most recent date in data if none provided\n",
    "    day_to_plot = time_value if time_value else max(data.time_value)\n",
    "    day_data = data.loc[data.time_value == pd.to_datetime(day_to_plot), :]\n",
    "    \n",
    "    kwargs[\"vmax\"] = kwargs.get(\"vmax\", vmax_mean + 3*vmax_std)\n",
    "    \n",
    "    kwargs[\"figsize\"] = kwargs.get(\"figsize\", (12, 6))\n",
    "    \n",
    "    fig, ax = covidcast.plotting._plot_background_states(kwargs[\"figsize\"])\n",
    "    if plot_type == \"choropleth\":\n",
    "        if for_res_stmt:\n",
    "            ax.annotate(\"Michigan\", xy=(0.59, 0.68), xytext=(0.65, 0.94), xycoords='figure fraction', textcoords='figure fraction',\n",
    "                        fontsize=res_stmt_font_size, arrowprops=dict(facecolor='black', width=1.))\n",
    "        _plot_choro(ax, day_data, combine_megacounties, \"vertical\", cbar, for_res_stmt, **kwargs)\n",
    "        \n",
    "    return fig, ax\n",
    "\n",
    "def _plot_choro(ax: matplotlib.axes.Axes,\n",
    "                data: geopandas.gpd.GeoDataFrame,\n",
    "                combine_megacounties: bool,\n",
    "                orientation: bool,\n",
    "                cbar: bool,\n",
    "                for_res_stmt: bool,\n",
    "                **kwargs: typing.Any) -> None:\n",
    "    \"\"\"Generate a choropleth map on a given Figure/Axes from a GeoDataFrame.\n",
    "    :param ax: Matplotlib axes to plot on.\n",
    "    :param data: GeoDataFrame with information to plot.\n",
    "    :param kwargs: Optional keyword arguments passed to ``GeoDataFrame.plot()``.\n",
    "    :return: Matplotlib axes with the plot added.\n",
    "    \"\"\"\n",
    "    kwargs[\"vmin\"] = kwargs.get(\"vmin\", 0)\n",
    "    kwargs[\"cmap\"] = kwargs.get(\"cmap\", \"YlOrRd\")\n",
    "    data_w_geo = covidcast.plotting.get_geo_df(data, combine_megacounties=combine_megacounties)\n",
    "    for shape in covidcast.plotting._project_and_transform(data_w_geo):\n",
    "        if not shape.empty:\n",
    "            shape.plot(column=\"value\", ax=ax, **kwargs)\n",
    "    sm = plt.cm.ScalarMappable(cmap=kwargs[\"cmap\"],\n",
    "                               norm=plt.Normalize(vmin=kwargs[\"vmin\"], vmax=kwargs[\"vmax\"]))\n",
    "    # this is to remove the set_array error that occurs on some platforms\n",
    "    sm._A = []  # pylint: disable=W0212\n",
    "\n",
    "    if cbar:\n",
    "        divider = make_axes_locatable(ax)\n",
    "        my_size = \"5.%\" if for_res_stmt else \"3.%\" \n",
    "        cax = divider.append_axes(\"right\", size=my_size, pad=0.1)\n",
    "        my_cbar = plt.colorbar(sm, ticks=np.linspace(kwargs[\"vmin\"], kwargs[\"vmax\"], 8), ax=ax,\n",
    "                     orientation=orientation, anchor=(0.5, 1.8), pad=0.1, format=\"%.2f\",\n",
    "                     cax=cax)\n",
    "        if for_res_stmt:\n",
    "            my_cbar.ax.tick_params(labelsize=res_stmt_font_size)\n",
    "        \n",
    "def make_boxplot(cvgs, lens, algs, alpha, mode, ff, fontsize):    \n",
    "    if mode == \"Coverage\":\n",
    "        obj_to_plot = np.vstack(cvgs).T\n",
    "    else:\n",
    "        obj_to_plot = np.vstack(lens).T\n",
    "        \n",
    "    fig, ax = plt.subplots(figsize=(12,5))\n",
    "    ax.boxplot(obj_to_plot,\n",
    "               labels=algs,\n",
    "               showmeans=True,\n",
    "               showfliers=True)\n",
    "    ax.tick_params(axis='both', labelsize=fontsize)\n",
    "    ax.tick_params(axis=\"x\") # , rotation=45\n",
    "    ax.set_title(mode, fontsize=fontsize+2)\n",
    "    \n",
    "    if mode == \"Coverage\":\n",
    "        ax.axhline((1. - alpha), c='r', linestyle=\"-\", linewidth=1)\n",
    "    fig.savefig(ff + \"covid_rob_pred_ints_boxplot_{}_{}.pdf\".format(mode, datetime.date.today()),\n",
    "                bbox_inches=\"tight\")"
   ]
  },
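  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two small sanity checks of the helpers above on synthetic data (illustrative only; nothing below depends on them): `apply_logit` and `apply_inverse_logit` are exact inverses when they share the same offset `a`, and `gaussian_kernel_matrix(X, X, sigma)` is symmetric with ones on its diagonal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative checks on synthetic data (assumes the shared default offset a = 1e-2).\n",
    "rng = np.random.default_rng(0)\n",
    "y_demo = rng.uniform(0., 1., size=20)\n",
    "assert np.allclose(apply_inverse_logit(apply_logit(y_demo)), y_demo)\n",
    "\n",
    "X_demo = rng.normal(size=(5, 2))\n",
    "K_demo = gaussian_kernel_matrix(X_demo, X_demo, sigma=1.0)\n",
    "assert np.allclose(K_demo, K_demo.T) and np.allclose(np.diag(K_demo), 1.0)\n",
    "print(\"Helper sanity checks passed.\")"
   ]
  },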
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Alg:\n",
    "    def __init__(self):\n",
    "        self.sel_idxes = None\n",
    "        self.all_sel_idxes = None\n",
    "\n",
    "        self.y_train = None\n",
    "        self.y_val = None\n",
    "        self.y_test = None\n",
    "\n",
    "        self.AE = None\n",
    "        self.AEs = None\n",
    "        self.MAE = None\n",
    "        \n",
    "        self.props = {}\n",
    "\n",
    "    def identify_weirdest_points(self, X, weirdness, budget,\n",
    "                                 remove_idxes=None, round_idx=0, delta=0.01, knns=None):\n",
    "        raise Exception(\"Not implemented.\")\n",
    "\n",
    "    def identify_groups_of_weirdest_points(self, X, weirdness, budget,\n",
    "                                           just_identify_the_single_weirdest_group=False, delta=0.01, knns=None):\n",
    "        n = X.shape[0]\n",
    "\n",
    "        sel_idxes = self.identify_weirdest_points(X, weirdness, budget,\n",
    "                                                  delta=delta, knns=knns)\n",
    "        \n",
    "        all_sel_idxes = [copy.deepcopy(sel_idxes)]\n",
    "        if just_identify_the_single_weirdest_group == False:\n",
    "            remove_idxes = copy.deepcopy(sel_idxes)\n",
    "            num_weird_rounds = n//budget\n",
    "            for weird_round_idx in range(1, num_weird_rounds):\n",
    "                sel_idxes_cur = self.identify_weirdest_points(X, weirdness, budget,\n",
    "                                                              remove_idxes=remove_idxes, round_idx=weird_round_idx,\n",
    "                                                              delta=delta, knns=knns)\n",
    "                all_sel_idxes += [copy.deepcopy(sel_idxes_cur)]\n",
    "                remove_idxes += copy.deepcopy(sel_idxes_cur)\n",
    "\n",
    "            if num_weird_rounds*budget < n:\n",
    "                all_sel_idxes += [[idx for idx in range(n) if idx not in remove_idxes]]\n",
    "        return all_sel_idxes\n",
    "    \n",
    "class Balls(Alg):\n",
    "    def __init__(self, use_penalty=True):\n",
    "        super().__init__()\n",
    "        self.use_penalty = use_penalty\n",
    "\n",
    "    def identify_weirdest_points(self, X, weirdness, budget, remove_idxes=None, round_idx=0,\n",
    "                                 delta=0.01, knns=None):\n",
    "        n = X.shape[0]\n",
    "        p = X.shape[1]\n",
    "        X = np.copy(X)\n",
    "\n",
    "        if remove_idxes is not None:\n",
    "            keep_idxes = list(set(range(n)) - set(remove_idxes))\n",
    "            X = X[keep_idxes, :]\n",
    "            n = X.shape[0]\n",
    "\n",
    "        if knns is None:\n",
    "            print(\"\\tEnumerating balls (using pynndescent) ...\")\n",
    "            index_obj = pynndescent.NNDescent(X, n_neighbors=budget)\n",
    "            knns = index_obj.neighbor_graph[0]\n",
    "\n",
    "        my_std = np.std(weirdness)\n",
    "        avg_weirdness = np.inf*np.ones((n, budget))\n",
    "        for cur_budget_idx, cur_budget in enumerate(range(budget)):\n",
    "            rewards = np.sum(norm_ranks[knns][:,0:cur_budget+1], axis=1)*(1./np.sqrt(cur_budget+1))\n",
    "            pen = my_std * np.sqrt(p * np.log(n/(cur_budget+1.) + 1) + np.log(1./delta)) if self.use_penalty \\\n",
    "                else 0.\n",
    "\n",
    "            avg_weirdness[:, cur_budget_idx] = rewards - pen\n",
    "        cstar, rstar = np.unravel_index(np.argmax(avg_weirdness, axis=None), avg_weirdness.shape)\n",
    "        sel_idxes = knns[cstar,0:rstar+1].tolist()\n",
    "\n",
    "        if remove_idxes is not None:\n",
    "            sel_idxes = [keep_idxes[sel_idx] for sel_idx in sel_idxes]\n",
    "        return sel_idxes\n",
    "    \n",
    "class Naive(Alg):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def identify_weirdest_points(self, X, weirdness, budget, remove_idxes=None, round_idx=0,\n",
    "                                 delta=None, knns=None):\n",
    "        if round_idx == 0:       \n",
    "            sel_idxes = np.argsort(weirdness)[-budget:]\n",
    "        else:\n",
    "            sel_idxes = np.argsort(weirdness)[-(round_idx+1)*budget:-(round_idx)*budget]\n",
    "        return list(sel_idxes)\n",
    "    \n",
    "class History:\n",
    "    def __init__(self):\n",
    "        self.fips = None\n",
    "        \n",
    "        self.date = None\n",
    "        \n",
    "        self.test_hardness = None\n",
    "        \n",
    "        self.balls = Balls()\n",
    "        self.balls_no_penalty = Balls(use_penalty=False)\n",
    "        \n",
    "        self.naive = Naive()\n",
    "        self.random = Alg()\n",
    "        self.actual = Alg()\n",
    "        self.all = Alg()\n",
    "        self.half = Alg()\n",
    "\n",
    "        self.lens = None\n",
    "        self.cvgs = None\n",
    "        self.his = None\n",
    "        self.los = None\n",
    "        self.true = None\n",
    "        self.preds = None\n",
    "        \n",
    "class Model:\n",
    "    def __init__(self):\n",
    "        self.coeffs_hat = None\n",
    "        self.training_MAE = None\n",
    "        self.props = {}\n",
    "    \n",
    "    def fit(self, y_h, X, idxes=None, y_val_h=None, X_val=None):\n",
    "        qr_obj = QuantileRegressor(quantile=0.5, solver=\"highs-ds\", alpha=0)\n",
    "        if idxes is None:\n",
    "            qr_obj.fit(X, y_h)            \n",
    "        else:\n",
    "            qr_obj.fit(X[idxes,:], y_h[idxes])\n",
    "        self.coeffs_hat = np.hstack((qr_obj.coef_, qr_obj.intercept_))\n",
    "                    \n",
    "    def predict(self, X, want_props=True):\n",
    "        yhat = X @ self.coeffs_hat[:-1] + self.coeffs_hat[-1]\n",
    "        return yhat\n",
    "    \n",
    "class BaggedOrStackedModel(Model):\n",
    "    def __init__(self, identifier=None, regularization=None):\n",
    "        super().__init__()\n",
    "        self.ensemble_type = EnsembleType.Bagged\n",
    "        self.model_objs = None\n",
    "        self.coeffs_hat = None\n",
    "    \n",
    "    def fit(self, y_h, X, idxes=None, y_val_h=None, X_val=None):\n",
    "        model_preds_val = np.inf*np.ones((X_val.shape[0], len(idxes))) \\\n",
    "            if self.ensemble_type == EnsembleType.Stacked else None\n",
    "        \n",
    "        self.model_objs = []\n",
    "        for hard_region_idx in range(len(idxes)):\n",
    "            print(\"\\tFitting model to hard region \" + str(hard_region_idx) + \" ...\")    \n",
    "            model_obj = Model()\n",
    "            if self.ensemble_type == EnsembleType.Stacked:\n",
    "                model_obj.fit(y_h, X, idxes=idxes[hard_region_idx])\n",
    "                model_preds_val[:, hard_region_idx] = model_obj.predict(X_val)\n",
    "            else:\n",
    "                model_obj.fit(np.hstack((y_h, y_val_h)), np.vstack((X, X_val)),\n",
    "                              idxes=idxes[hard_region_idx])\n",
    "            self.model_objs += [model_obj]\n",
    "                \n",
    "        print(\"\\tFitting ensembled model to validation data ...\") \n",
    "        if self.ensemble_type == EnsembleType.Stacked:\n",
    "            w = cp.Variable(model_preds_val.shape[1])\n",
    "            objf = cp.sum(cp.abs(y_val_h \\\n",
    "                                         - model_preds_val @ w))\n",
    "            prob = cp.Problem(cp.Minimize(objf), [w >= 0, cp.sum(w) == 1])\n",
    "            prob.solve(solver=cp.SCS) # cp.SCS\n",
    "            self.coeffs_hat = w.value            \n",
    "        else:\n",
    "            self.coeffs_hat = (1/len(idxes)) * np.ones(len(idxes))\n",
    "        \n",
    "    def predict(self, X, want_props=True):\n",
    "        yhat = 0\n",
    "        for idx in range(len(self.model_objs)):\n",
    "            yhat_idx = X @ self.model_objs[idx].coeffs_hat[:-1] + self.model_objs[idx].coeffs_hat[-1]\n",
    "            yhat += self.coeffs_hat[idx] * yhat_idx\n",
    "        return yhat        \n",
    "    \n",
    "class MultitaskModel(Model):\n",
    "    def __init__(self, identifier=None, regularization=None):\n",
    "        super().__init__()\n",
    "        self.ensemble_type = EnsembleType.Multitask\n",
    "        self.model_objs = None\n",
    "        self.coeffs_hat = None\n",
    "        self.lams = regularization\n",
    "    \n",
    "    def fit(self, y_h, X, idxes=None, y_val_h=None, X_val=None):\n",
    "        self.model_objs = []\n",
    "        for hard_region_idx in range(len(idxes)):        \n",
    "            print(\"\\tFitting 'child' model to hard region \" + str(hard_region_idx) + \" ...\")\n",
    "            model_obj = Model()\n",
    "            model_obj.fit(y_h, X, idxes=idxes[hard_region_idx])\n",
    "            self.model_objs += [model_obj]\n",
    "    \n",
    "        print(\"\\tFitting 'parent' model to hard region (and tuning regularization strength) ...\")\n",
    "        coeffs = cp.Variable(X.shape[1]+1) # +1 for bias\n",
    "        X_pad = np.hstack([X, np.ones((X.shape[0], 1))])\n",
    "        objf = cp.sum(cp.abs(y_h - X_pad @ coeffs))\n",
    "\n",
    "        reg = 0\n",
    "        for model_obj in self.model_objs:\n",
    "            reg += cp.norm(coeffs - model_obj.coeffs_hat)\n",
    "        lam_param = cp.Parameter(nonneg=True)\n",
    "        prob = cp.Problem(cp.Minimize(objf + lam_param*reg))\n",
    "\n",
    "        coeffs_hats = [None]*len(self.lams)\n",
    "        errs = np.inf*np.ones(len(self.lams))\n",
    "        for lam_idx, lam in enumerate(lams):\n",
    "            lam_param.value = lam\n",
    "            prob.solve(solver=cp.SCS, warm_start=True)\n",
    "            self.coeffs_hat = np.copy(coeffs.value)\n",
    "            coeffs_hats[lam_idx] = np.copy(self.coeffs_hat)\n",
    "            \n",
    "            yhat_val = self.predict(X_val)\n",
    "            errs[lam_idx] = np.median(compute_scores(apply_inverse_logit(yhat_val), apply_inverse_logit(y_val_h),\n",
    "                                                     apply_inverse_logit(y_h), True))\n",
    "        best_lam_idx = np.argmin(errs)\n",
    "        print(\"\\tPicked lam idx \" + str(best_lam_idx) + \" ...\")\n",
    "        self.coeffs_hat = coeffs_hats[best_lam_idx]\n",
    "        \n",
    "class PureLocalStrategy(Model):\n",
    "    def __init__(self, identifier=None, regularization=None):\n",
    "        super().__init__()\n",
    "        self.ensemble_type = EnsembleType.PureLocal\n",
    "        self.model_objs = None\n",
    "        self.coeffs_hat = None\n",
    "        self.hard_idxes = None\n",
    "        self.X_val = None\n",
    "        self.X_cols_idxes_for_dist_calc = None\n",
    "        \n",
    "    def fit(self, y_h, X, idxes=None, y_val_h=None, X_val=None):\n",
    "        self.model_objs = []\n",
    "        for hard_region_idx in range(len(idxes)):        \n",
    "            model_obj = Model()\n",
    "            model_obj.fit(y_h, X, idxes=idxes[hard_region_idx])\n",
    "            self.model_objs += [model_obj]\n",
    "        self.X_val = X_val\n",
    "        \n",
    "    def predict(self, X, want_props=True):\n",
    "        yhat = np.nan*np.ones(X.shape[0])\n",
    "        num_local_models = len(self.model_objs)\n",
    "        test_pt_2_cluster_dists = np.nan*np.ones((X.shape[0], num_local_models))\n",
    "        for model_obj_idx, model_obj in enumerate(self.model_objs):\n",
    "            model_obj_hard_idxes = self.hard_idxes[model_obj_idx]\n",
    "            test_pt_2_cluster_dists[:, model_obj_idx] = np.min(\n",
    "                pairwise_distances(X[:, self.X_cols_idxes_for_dist_calc],\n",
    "                                   self.X_val[:, self.X_cols_idxes_for_dist_calc])[:,model_obj_hard_idxes],\n",
    "                axis=1)\n",
    "        closest_model_obj_idxes = np.argmin(test_pt_2_cluster_dists, axis=1)\n",
    "        \n",
    "        for row_idx in range(X.shape[0]):\n",
    "            x = X[row_idx, :]\n",
    "            yhat[row_idx] = self.model_objs[closest_model_obj_idxes[row_idx]].predict(x)        \n",
    "        return yhat        "
   ]
  },
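  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tiny synthetic illustration of the selection strategies defined above (not part of the pipeline): `Naive.identify_weirdest_points` simply returns the indices of the `budget` largest weirdness scores, whereas `Balls` scores nearest-neighbor balls of candidate points (building a k-NN graph with pynndescent if one is not supplied)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic example: the Naive strategy ignores the features and only looks at the\n",
    "# weirdness scores. With budget = 2 it selects indices 3 and 1 (the two largest scores).\n",
    "demo_weirdness = np.array([0.1, 0.9, 0.3, 0.8, 0.2])\n",
    "demo_X = np.zeros((5, 2))\n",
    "print(Naive().identify_weirdest_points(demo_X, demo_weirdness, budget=2))"
   ]
  },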
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hist = [History() for i in range(num_wks)]\n",
    "for t_idx in range(num_wks):\n",
    "    t = start_date + t_idx*one_wk\n",
    "    t_str = t.strftime(datetime_str_f)\n",
    "    fips_test = Xy_tot[\"geo_value\"].to_numpy()\n",
    "    X_train, y_train, y_train_h, X_val, y_val, y_val_h, n_val, X_test, y_test, y_test_h, n_test \\\n",
    "        = get_train_val_test_data(t, t_idx, one_wk, num_wks, datetime_str_f, Xy_tot, X_cols_prefix, y_col)\n",
    "\n",
    "    n_test_range = range(n_test)\n",
    "    if t_idx == 0:\n",
    "        B = B_func(n_test)\n",
    "        print(\"\\tFixing budget = \" + str(B))\n",
    "\n",
    "        print(\"\\tEnumerating balls (once, using pynndescent) ...\")\n",
    "        index_obj = pynndescent.NNDescent(Xy_tot[[\"pclon10\", \"pclat10\"]].to_numpy(), n_neighbors=B) \n",
    "        knns = index_obj.neighbor_graph[0]            \n",
    "\n",
    "    # Fit model(s).\n",
    "    print(\"\\tFitting model ...\")\n",
    "    model_obj = Model()\n",
    "    model_obj.fit(y_train_h, X_train)\n",
    "\n",
    "    # Compute (normalized) ranks.\n",
    "    print(\"\\tComputing (normalized) ranks ...\")\n",
    "    yhat_val = model_obj.predict(X_val)\n",
    "    scores_val = compute_scores(apply_inverse_logit(yhat_val), apply_inverse_logit(y_val_h),\n",
    "                                apply_inverse_logit(y_train_h), True, a=1e-2)\n",
    "\n",
    "    yhat_test = model_obj.predict(X_test)\n",
    "    scores_test = compute_scores(apply_inverse_logit(yhat_test), apply_inverse_logit(y_test_h),\n",
    "                                 apply_inverse_logit(y_train_h), True, a=1e-2)\n",
    "    if want_robust_intervals:\n",
    "        print(\"\\tComputing robust intervals ...\")\n",
    "        scores_val_unscaled = compute_scores(apply_inverse_logit(yhat_val), apply_inverse_logit(y_val_h),\n",
    "                                             None, True, a=1e-2)\n",
    "        his, los, true, cvgs, lens, preds = compute_robust_pred_ints(\n",
    "            scores_val_unscaled, alpha, n_val, yhat_test, y_train_h, y_val_h, y_test_h)\n",
    "        hist[t_idx].his = his\n",
    "        hist[t_idx].los = los\n",
    "        hist[t_idx].true = true\n",
    "        hist[t_idx].cvgs = cvgs\n",
    "        hist[t_idx].lens = lens\n",
    "        hist[t_idx].preds = preds\n",
    "\n",
    "    norm_ranks = [sp.stats.rankdata(np.append(scores_val, scores_test[idx]))[-1] / (n_val+1) \\\n",
    "                      for idx in range(n_test)]\n",
    "    norm_ranks = np.array(norm_ranks)\n",
    "\n",
    "    # Identify weird (test) points (two different ways):\n",
    "    print(\"\\tIdentifying weird points ...\")\n",
    "    # 1a,b) Use balls (w/ and w/o penalty) to pick weird points.\n",
    "    all_sel_idxes = hist[t_idx].balls.identify_groups_of_weirdest_points(X_test, norm_ranks, B,\n",
    "                                                                         just_identify_the_single_weirdest_group=True,\n",
    "                                                                         knns=knns)\n",
    "\n",
    "    all_sel_idxes_balls_no_penalty = hist[t_idx].balls_no_penalty.identify_groups_of_weirdest_points(X_test, norm_ranks, B,\n",
    "                                                                         just_identify_the_single_weirdest_group=True,\n",
    "                                                                         knns=knns)\n",
    "\n",
    "    # 2) Just pick the weirdest points (i.e., a \"naive strategy\").\n",
    "    all_sel_idxes_baseline = hist[t_idx].naive.identify_groups_of_weirdest_points(X_test, norm_ranks, B,\n",
    "                                                                                  just_identify_the_single_weirdest_group=True)\n",
    "\n",
    "    # Update loop variables.\n",
    "    print(\"\\tSaving epoch ...\\n\")\n",
    "    hist[t_idx].fips = fips_test\n",
    "    \n",
    "    hist[t_idx].date = t\n",
    "    hist[t_idx+1].date = t+one_wk\n",
    "    hist[t_idx+2].date = t+2*one_wk\n",
    "    hist[t_idx+3].date = t+3*one_wk\n",
    "    \n",
    "    hist[t_idx+3].test_hardness = norm_ranks\n",
    "    \n",
    "    if t_idx == 0:\n",
    "        hist[t_idx].actual = Xy_tot[y_col + \"_original\"].to_numpy()\n",
    "    else:\n",
    "        hist[t_idx].actual = Xy_tot[y_col + \"_original\" + \"_\" + t_str].to_numpy()\n",
    "    hist[t_idx+1].actual = y_train\n",
    "    hist[t_idx+2].actual = y_val\n",
    "    hist[t_idx+3].actual = y_test\n",
    "    \n",
    "    hist[t_idx+3].balls.sel_idxes = all_sel_idxes[0]\n",
    "    hist[t_idx+3].balls.all_sel_idxes = all_sel_idxes\n",
    "    \n",
    "    hist[t_idx+3].balls_no_penalty.sel_idxes = all_sel_idxes_balls_no_penalty[0]\n",
    "    hist[t_idx+3].balls_no_penalty.all_sel_idxes = all_sel_idxes_balls_no_penalty\n",
    "    \n",
    "    hist[t_idx+3].naive.sel_idxes = all_sel_idxes_baseline[0]\n",
    "    hist[t_idx+3].naive.all_sel_idxes = all_sel_idxes_baseline\n",
    "    \n",
    "    if (t_idx+1 + 3) >= num_wks:\n",
    "        hist[t_idx+1 + 0].fips = fips_test\n",
    "        hist[t_idx+1 + 1].fips = fips_test\n",
    "        hist[t_idx+1 + 2].fips = fips_test\n",
    "        break\n",
    "        \n",
    "print(\"All done.\")"
   ]
  },
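  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a small synthetic example of the normalized-rank statistic computed in the loop above (illustrative only): each test score is ranked among the pooled validation scores and divided by `n_val + 1`, so values near 1 indicate a test point that is harder than most validation points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic example of the normalized conformal rank: 0.35 ranks 4th (ascending)\n",
    "# among the 5 pooled scores, so its normalized rank is 4 / 5 = 0.8.\n",
    "demo_scores_val = np.array([0.1, 0.2, 0.3, 0.4])\n",
    "demo_score_test = 0.35\n",
    "print(sp.stats.rankdata(np.append(demo_scores_val, demo_score_test))[-1] / (demo_scores_val.shape[0] + 1))"
   ]
  },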
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_intensities = []\n",
    "for wk_idx in range(num_wks-3):\n",
    "    _, _, _, _, true = compute_robust_pred_ints_stats(hist, 0, num_wks)\n",
    "    my_intensities += [true[wk_idx]]\n",
    "    \n",
    "vmax_mean_frac = np.mean(np.concatenate(my_intensities))\n",
    "vmax_std_frac = np.std(np.concatenate(my_intensities))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "special_alg_idx = 0\n",
    "num_algs = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for alg_idx in range(num_algs):\n",
    "    my_cvg, my_len, my_his, my_los, true = compute_robust_pred_ints_stats(hist, alg_idx, num_wks)\n",
    "    print(str(alg_idx) + \":\")\n",
    "    print(my_cvg, my_len)\n",
    "    \n",
    "    if alg_idx == special_alg_idx:\n",
    "        for hi_lo_true in [\"his\", \"los\", \"true\"]:\n",
    "            for wk_idx in range(num_wks-3):\n",
    "                if hi_lo_true == \"his\":\n",
    "                    intensities = my_his[wk_idx]\n",
    "                elif hi_lo_true == \"los\":\n",
    "                    intensities = my_los[wk_idx]\n",
    "                else:\n",
    "                    intensities = true[wk_idx]\n",
    "                all_idxes = list(range(len(hist[wk_idx].actual)))                    \n",
    "                fips = hist[wk_idx].fips\n",
    "                cur_df, _, _ = make_df(all_idxes, intensities, fips,\n",
    "                                 wk_idx, start_date, one_wk)\n",
    "\n",
    "                fn = \"covid_rob_pred_ints_{}_{}_alg_idx_{}\".format(\n",
    "                    (start_date + wk_idx*one_wk).strftime(datetime_str_f), hi_lo_true, alg_idx)\n",
    "                fig, _ = plot(cur_df, vmax_mean=vmax_mean_frac, vmax_std=vmax_std_frac, cbar=False)\n",
    "                fig.savefig(ff + fn + \".pdf\", bbox_inches=\"tight\")\n",
    "\n",
    "                fig, _ = plot(cur_df, vmax_mean=vmax_mean_frac, vmax_std=vmax_std_frac, cbar=True)\n",
    "                fig.savefig(ff + fn + \"_cbar.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "algs = [\"SC\",\n",
    "        r\"SC-$\\alpha/2$\",\n",
    "        \"KL-M2\",\n",
    "        \"KL-R\"]\n",
    "\n",
    "skip_alg_idx = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize = 16\n",
    "linewidth = 2\n",
    "linestyles = [\"-\",\n",
    "              \":\",\n",
    "              \"-\",\n",
    "              \"-\",\n",
    "              \"-\",\n",
    "              \"-\",\n",
    "              \"-\"]\n",
    "colors = [\"mediumvioletred\",\n",
    "          \"mediumvioletred\",\n",
    "          \"blue\",\n",
    "          \"deepskyblue\",\n",
    "          \"chartreuse\",\n",
    "          \"gold\"]\n",
    "algs = [\"Algorithm 2\",\n",
    "        \"Algorithm 2, unpenalized\",\n",
    "        \"Hardest points\",\n",
    "        \"Uniformly at random\",\n",
    "        \"Pure global\"]\n",
    "    \n",
    "vmax_mean = df[y_col + \"_original\"].mean(skipna=True)\n",
    "vmax_std = df[y_col + \"_original\"].std(skipna=True)\n",
    "\n",
    "res_stmt_font_size = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_cvgs = []\n",
    "my_lens = []\n",
    "for alg_idx in range(num_algs):\n",
    "    if alg_idx == skip_alg_idx:\n",
    "        continue\n",
    "    \n",
    "    my_cvgs += [[np.mean(hist[t_idx].cvgs[alg_idx,:]) for t_idx in range(num_wks-3)]]\n",
    "    my_lens += [[np.mean(hist[t_idx].lens[alg_idx,:]) for t_idx in range(num_wks-3)]]\n",
    "\n",
    "make_boxplot(my_cvgs, my_lens, algs, alpha=alpha, mode=\"Coverage\", ff=ff, fontsize=fontsize)\n",
    "make_boxplot(my_cvgs, my_lens, algs, alpha=alpha, mode=\"Length\", ff=ff, fontsize=fontsize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for wk_idx in range(3,num_wks):\n",
    "    all_idxes = list(range(len(hist[wk_idx].actual)))\n",
    "    intensities = hist[wk_idx].actual # test_hardness\n",
    "    fips = hist[wk_idx].fips\n",
    "    cur_df, _, _ = make_df(all_idxes, intensities, fips,\n",
    "                     wk_idx, start_date, one_wk)\n",
    "    \n",
    "    fp = ff + \"covid_regions_true_{}.pdf\".format(\n",
    "        (start_date + wk_idx*one_wk).strftime(datetime_str_f))\n",
    "    fig, _ = plot(cur_df, vmax_mean=vmax_mean, vmax_std=vmax_std, cbar=False)\n",
    "    fig.savefig(fp, bbox_inches=\"tight\")\n",
    "    \n",
    "    fp = ff + \"covid_regions_true_{}_cbar.pdf\".format(\n",
    "        (start_date + wk_idx*one_wk).strftime(datetime_str_f))\n",
    "    fig, _ = plot(cur_df, vmax_mean=vmax_mean, vmax_std=vmax_std, cbar=True)\n",
    "    fig.savefig(fp, bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
