{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n",
      "env: PYTHONHASHSEED=0\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import sklearn\n",
    "np.random.seed(0)\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "from matplotlib.ticker import NullFormatter\n",
    "%matplotlib inline\n",
    "import seaborn as sns\n",
    "sns.set(palette=\"bright\",style=\"ticks\",font=\"Arial\")\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib.patches import PathPatch\n",
    "from matplotlib.ticker import FormatStrFormatter\n",
    "\n",
    "from sklearn import linear_model\n",
    "from sklearn.linear_model import HuberRegressor\n",
    "from sklearn import manifold, datasets\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn import preprocessing\n",
    "\n",
    "from functools import partial\n",
    "import cvxpy as cp\n",
    "import pandas as pd\n",
    "import copy\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.metrics.pairwise import rbf_kernel\n",
    "import os.path\n",
    "import pdb\n",
    "import scipy as sp\n",
    "import hashlib\n",
    "import joblib\n",
    "import pickle\n",
    "import pdb\n",
    "import scipy\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%env PYTHONHASHSEED=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy: 1.23.5\n",
      "matplotlib: 3.7.0\n",
      "seaborn: 0.12.2\n",
      "sklearn: 1.2.1\n",
      "cvxpy: 1.4.0\n",
      "pandas: 1.5.3\n",
      "scipy: 1.10.0\n",
      "joblib: 1.1.1\n",
      "hashlib (standard library, version not applicable)\n",
      "pickle (standard library, version not applicable)\n"
     ]
    }
   ],
   "source": [
    "# Third-party packages whose versions are recorded for reproducibility.\n",
    "versioned_packages = [\n",
    "    (\"numpy\", np),\n",
    "    (\"matplotlib\", matplotlib),\n",
    "    (\"seaborn\", sns),\n",
    "    (\"sklearn\", sklearn),\n",
    "    (\"cvxpy\", cp),\n",
    "    (\"pandas\", pd),\n",
    "    (\"scipy\", scipy),\n",
    "    (\"joblib\", joblib),\n",
    "]\n",
    "for package_name, package in versioned_packages:\n",
    "    print(f\"{package_name}: {package.__version__}\")\n",
    "\n",
    "# Standard-library modules have no separate version to report.\n",
    "for stdlib_name in (\"hashlib\", \"pickle\"):\n",
    "    print(f\"{stdlib_name} (standard library, version not applicable)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: don't forget to set these v.\n",
    "\n",
    "ff= \"../data/jasa_10_07_2023_data/\"\n",
    "out_ff = \"../data/jasa_10_07_2023_data/\"\n",
    "\n",
    "datasets = [\"bos\"]\n",
    "# datasets = [\"kin\", \"cpu\", \"bos\", \"puma\", \"delta\", \"ca\", \"abalone\", \"airfoil\"]\n",
    "\n",
    "num_trials = 20 # 50\n",
    "\n",
    "fontsize = 16\n",
    "\n",
    "linewidth = 2\n",
    "\n",
    "shift_betas = np.array([0.02, 0.04, 0.08, 0.16, 0.32, 0.64])\n",
    "shift_betas = np.concatenate([shift_betas, -shift_betas])\n",
    "shift_betas.sort()\n",
    "num_shift_betas = shift_betas.shape[0]\n",
    "\n",
    "shift_taus = [0.5, 0.6, 0.7]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read test weights in"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ext_jlib = \"jlib\"\n",
    "\n",
    "# Collect the \"test_idxes_weights\" entry of every data set's joblib workspace,\n",
    "# in the order of `datasets`; the \"gaus\" data set is skipped.\n",
    "test_idxes_weights_dict_list = []\n",
    "for dataset in datasets:\n",
    "    if dataset != \"gaus\":\n",
    "        print(\"Reading in weights on the test set, for data set %s ...\" % dataset)\n",
    "        test_idxes_weights_wksp_fp = ff + dataset + \"_test_idxes_weights_wksp.\" + ext_jlib\n",
    "        dataset_test_idxes_weights_wksp = joblib.load(test_idxes_weights_wksp_fp)\n",
    "        print(\"Done.\")\n",
    "\n",
    "        test_idxes_weights_dict_list.append(dataset_test_idxes_weights_wksp[\"test_idxes_weights\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_I_E_str(shift_fn):\n",
    "    \n",
    "    dim_idx_shift_tau_str = shift_fn.replace(\"I-E-\", \"\")\n",
    "    dash_idx = dim_idx_shift_tau_str.index(\"-\")\n",
    "\n",
    "    dim_idx_str = dim_idx_shift_tau_str[0:dash_idx]\n",
    "    shift_tau = float(dim_idx_shift_tau_str.replace(dim_idx_str + \"-\", \"\"))\n",
    "    dim_idx = int(dim_idx_str)\n",
    "    shift_tau_str = \", tau=\" + str(shift_tau)\n",
    "    \n",
    "    return shift_tau_str, dim_idx, shift_tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_title(shift_fn, shift_beta, dataset, rev=False):\n",
    "    \n",
    "    if \"I-C\" in shift_fn:\n",
    "        dim_idx = int(shift_fn.replace(\"I-C-\", \"\"))\n",
    "        shift_tau_str = \"\"\n",
    "\n",
    "        shift_fn = \"I-C, dim. \" + str(dim_idx+1)            \n",
    "        \n",
    "    elif \"I-D\" in shift_fn:\n",
    "        shift_tau = float(shift_fn.replace(\"I-D-\", \"\"))\n",
    "        shift_tau_str = \", tau=\" + str(shift_tau)\n",
    "\n",
    "        shift_fn = \"I-D\"\n",
    "        \n",
    "    elif \"I-E\" in shift_fn:\n",
    "        shift_tau_str, _, _ = parse_I_E_str(shift_fn)\n",
    "\n",
    "    else:\n",
    "        shift_tau_str = \"\"\n",
    "\n",
    "    if not rev:\n",
    "        title = \"K-L for \"\n",
    "    else:\n",
    "        title = \"Rev. K-L for \"\n",
    "    title +=shift_fn\n",
    "        \n",
    "    if \"I-A\" not in shift_fn:\n",
    "        title += \", beta=\" + str(shift_beta)            \n",
    "\n",
    "        title += shift_tau_str        \n",
    "        \n",
    "    title += \", \" + dataset # + \" data\"\n",
    "    \n",
    "    return title"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_out_fp(shift_fn, shift_beta, dataset, rev=False):\n",
    "    \"\"\"Return the PDF path for a K-L box plot.\n",
    "\n",
    "    The path is `out_ff` + `dataset` + \"_boxplot_kl_\", then \"rev_\" when `rev`\n",
    "    is true, then `shift_fn`, then \"_beta_<shift_beta>\" unless `shift_fn`\n",
    "    contains \"I-A\", and finally \".pdf\".\n",
    "    \"\"\"\n",
    "    rev_tag = \"rev_\" if rev else \"\"\n",
    "    beta_tag = \"\" if \"I-A\" in shift_fn else \"_beta_\" + str(shift_beta)\n",
    "\n",
    "    return out_ff + dataset + \"_boxplot_kl_\" + rev_tag + shift_fn + beta_tag + \".pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_plot(kls, shift_fn, shift_beta, dataset, rev=False):\n",
    "    \"\"\"Draw a box plot of `kls` and save it as a PDF.\n",
    "\n",
    "    The title comes from `get_title` and the output path from `get_out_fp`,\n",
    "    both called with the same (shift_fn, shift_beta, dataset, rev) arguments.\n",
    "    Means and fliers are shown, axis labels are empty and the x ticks are\n",
    "    removed. The saved path is printed afterwards.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.boxplot(kls, showmeans=True, showfliers=True)\n",
    "\n",
    "    plot_title = get_title(shift_fn, shift_beta, dataset, rev)\n",
    "    ax.set_title(plot_title, fontsize=fontsize+2)\n",
    "\n",
    "    # No axis labels; tick labels use the notebook-wide font size.\n",
    "    for set_label in (ax.set_ylabel, ax.set_xlabel):\n",
    "        set_label('')\n",
    "    ax.tick_params(axis='both', labelsize=fontsize)\n",
    "    ax.set_xticks([])\n",
    "\n",
    "    out_fp = get_out_fp(shift_fn, shift_beta, dataset, rev)\n",
    "    fig.savefig(out_fp, bbox_inches=\"tight\")\n",
    "    print(\"Saved plot to %s.\" % out_fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_dataset_to_friendly_dataset_name(dataset, append_word_data=True):\n",
    "    \n",
    "    if(dataset == \"gaus\"):\n",
    "        dataset_friendly_name = \"Gaussian\"\n",
    "\n",
    "    elif(dataset == \"airfoil\"):\n",
    "        dataset_friendly_name = \"Airfoil\"\n",
    "\n",
    "    elif(dataset == \"abalone\"):\n",
    "        dataset_friendly_name = \"Abalone\"\n",
    "\n",
    "    elif(dataset == \"ca\"):\n",
    "        dataset_friendly_name = \"California housing\"\n",
    "\n",
    "    elif(dataset == \"delta\"):\n",
    "        dataset_friendly_name = \"Delta ailerons\"\n",
    "\n",
    "    elif(dataset == \"ailerons\"):\n",
    "        dataset_friendly_name = \"Ailerons\"\n",
    "\n",
    "    elif(dataset == \"bank\"):\n",
    "        dataset_friendly_name = \"Banking\"\n",
    "\n",
    "    elif(dataset == \"bos\"):\n",
    "        dataset_friendly_name = \"Boston housing\"\n",
    "\n",
    "    elif(dataset == \"cpu\"):\n",
    "        dataset_friendly_name = \"CPU\"\n",
    "\n",
    "    elif(dataset == \"kin\"):\n",
    "        dataset_friendly_name = \"Kinematics\"\n",
    "\n",
    "    elif(dataset == \"puma\"):\n",
    "        dataset_friendly_name = \"Puma\"\n",
    "        \n",
    "    if(append_word_data):\n",
    "        dataset_friendly_name += \" data\"\n",
    "        \n",
    "    return dataset_friendly_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_shift_beta_idx(shift_beta):\n",
    "    \"\"\"Return the positions in the global `shift_betas` equal to `shift_beta`.\n",
    "\n",
    "    The comparison is an exact `==`. The result is an index array, not a\n",
    "    scalar, and it is empty when no entry matches.\n",
    "    \"\"\"\n",
    "    is_match = shift_betas == shift_beta\n",
    "    return np.nonzero(is_match)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ext_pkl = \"pkl\"\n",
    "\n",
    "# Per data set (in the order of `datasets`): the \"coverages_dict\" and\n",
    "# \"raw_coverages_dict\" entries of the pickled workspace.\n",
    "coverages_dict_list = []\n",
    "raw_coverages_dict_list = []\n",
    "for dataset_idx, dataset in enumerate(datasets):\n",
    "    print(\"Reading in coverages for data set %s ...\" % dataset)\n",
    "    dataset_wskp_fp = ff + dataset + \"_wksp.\" + ext_pkl\n",
    "    # NOTE(security): unpickling can execute arbitrary code, so only load\n",
    "    # workspace files that come from a trusted source.\n",
    "    # The context manager closes the file handle even if unpickling fails.\n",
    "    with open(dataset_wskp_fp, \"rb\") as dataset_wksp_file:\n",
    "        dataset_wksp = pickle.load(dataset_wksp_file)\n",
    "    print(\"Done.\")\n",
    "\n",
    "    coverages_dict_list.append(dataset_wksp[\"coverages_dict\"])\n",
    "    raw_coverages_dict_list.append(dataset_wksp[\"raw_coverages_dict\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data sets whose \"<dataset>_X.csv\" file is read with pandas' default index\n",
    "# handling; \"airfoil\" is the one data set read with its first column as the index.\n",
    "datasets_with_default_index = (\n",
    "    \"abalone\", \"ca\", \"delta\", \"ailerons\", \"bank\",\n",
    "    \"bos\", \"cpu\", \"kin\", \"puma\", \"gaus\",\n",
    ")\n",
    "\n",
    "# Number of feature columns per data set, in the order of `datasets`.\n",
    "num_dims_list = []\n",
    "for dataset_idx, dataset in enumerate(datasets):\n",
    "\n",
    "    if dataset == \"airfoil\":\n",
    "        index_col = 0\n",
    "    elif dataset in datasets_with_default_index:\n",
    "        index_col = None\n",
    "    else:\n",
    "        # Fail loudly: without this guard an unrecognised data set would reuse\n",
    "        # the feature matrix of the previous iteration (or hit a NameError on\n",
    "        # the first one) and record a wrong dimension count.\n",
    "        raise ValueError(\"Don't know how to read the features of data set %r.\" % (dataset,))\n",
    "\n",
    "    X = pd.read_csv(ff + dataset + \"_X.csv\", index_col=index_col).to_numpy()\n",
    "\n",
    "    num_dims_list.append(X.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alg_names_ordered = [\"Standard\", \"K-L\", \"Chi-squared\"]\n",
    "alg_idxes = {\"Standard\":0, \"K-L\":1, \"Chi-squared\":2}\n",
    "\n",
    "# NOTE: don't forget to set this v.\n",
    "alg_idx = alg_idxes[\"Standard\"]\n",
    "# alg_idx = alg_idxes[\"Chi-squared\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every data set except \"gaus\": gather the coverages of the selected\n",
    "# algorithm (alg_idx) per shift type, shift beta and dimension, then draw one\n",
    "# figure per shift type (mean, median and quantile bands against the shift\n",
    "# betas) and save it as a PDF under out_ff.\n",
    "for dataset_idx, dataset in enumerate(datasets):\n",
    "    if dataset == \"gaus\":\n",
    "        continue\n",
    "    \n",
    "    num_dims = num_dims_list[dataset_idx]\n",
    "        \n",
    "    # One array per shift type, shaped (num_shift_betas, dims, num_trials) and\n",
    "    # pre-filled with +inf. \"I-B\" and \"I-D-<tau>\" use a single dims slot, \"I-C\"\n",
    "    # and \"I-E-<tau>\" use num_dims slots. Slots that are never assigned below\n",
    "    # keep their +inf value and flow into the statistics further down.\n",
    "    shift2shift_beta_idxes2dim_idxes2cvgs = {}\n",
    "    shift2shift_beta_idxes2dim_idxes2cvgs[\"I-B\"] = np.inf*np.ones((num_shift_betas, 1, num_trials))\n",
    "    shift2shift_beta_idxes2dim_idxes2cvgs[\"I-C\"] = np.inf*np.ones((num_shift_betas, num_dims, num_trials))\n",
    "    for shift_tau in shift_taus:\n",
    "        shift2shift_beta_idxes2dim_idxes2cvgs[\"I-D-\" + str(shift_tau)] = np.inf*np.ones((num_shift_betas, 1, num_trials))\n",
    "        shift2shift_beta_idxes2dim_idxes2cvgs[\"I-E-\" + str(shift_tau)] = np.inf*np.ones((num_shift_betas, num_dims, num_trials))\n",
    "    \n",
    "    # Each key of the coverages dict is indexed as (shift name, shift beta).\n",
    "    shifts = coverages_dict_list[dataset_idx].keys()\n",
    "    for shift in shifts:\n",
    "        shift_fn = shift[0]\n",
    "        shift_beta = shift[1]\n",
    "        # Index array from an exact-equality lookup in shift_betas; it is empty\n",
    "        # when shift_beta is not one of the configured betas, in which case the\n",
    "        # assignment at the end of this loop body writes nothing.\n",
    "        shift_beta_idx = get_shift_beta_idx(shift_beta)   \n",
    "        \n",
    "        # Work out the target array (shift_fn_trim) and the dims slot (dim_idx).\n",
    "        if \"I-B\" in shift_fn:\n",
    "            shift_fn_trim = shift_fn\n",
    "            dim_idx = 0\n",
    "        \n",
    "        elif \"I-C\" in shift_fn:\n",
    "            shift_fn_trim = \"I-C\"\n",
    "            dim_idx = int(shift_fn.replace(\"I-C-\", \"\"))\n",
    "            \n",
    "        elif \"I-D\" in shift_fn:\n",
    "            shift_fn_trim = shift_fn\n",
    "            dim_idx = 0\n",
    "            \n",
    "        elif \"I-E\" in shift_fn:\n",
    "            # Rebuild the \"I-E-<tau>\" key from the parsed tau so that it has\n",
    "            # the same str(float) form as the keys created above.\n",
    "            shift_fn_trim = \"I-E\"\n",
    "            _, dim_idx, shift_tau = parse_I_E_str(shift_fn)\n",
    "            shift_fn_trim += \"-\" + str(shift_tau)\n",
    "            \n",
    "        else:\n",
    "            # Any other shift type is not plotted by this cell.\n",
    "            continue\n",
    "\n",
    "        # Take index 0 on the first, second and fourth axes and alg_idx on the\n",
    "        # last axis of the stored coverage array; the third axis fills the\n",
    "        # num_trials axis of the target.\n",
    "        # NOTE(review): presumably the third axis runs over trials - confirm\n",
    "        # against the code that wrote the workspace.\n",
    "        shift2shift_beta_idxes2dim_idxes2cvgs[shift_fn_trim][shift_beta_idx, dim_idx, :] = \\\n",
    "        coverages_dict_list[dataset_idx][shift_fn, shift_beta][0,0,:,0,alg_idx]\n",
    "    \n",
    "    # Plot the shift types in a fixed order: I-B, I-C, then I-D/I-E per tau.\n",
    "    shift_fn_trims = [\"I-B\", \"I-C\"]\n",
    "    for shift_tau in shift_taus:\n",
    "        shift_fn_trims += [\"I-D-\" + str(shift_tau)]\n",
    "        shift_fn_trims += [\"I-E-\" + str(shift_tau)]\n",
    "    for shift_fn_trim in shift_fn_trims:\n",
    "        \n",
    "        f,a = plt.subplots()\n",
    "\n",
    "        # All statistics pool the dims and trials axes, giving one value per shift beta.\n",
    "        my_means = np.mean(shift2shift_beta_idxes2dim_idxes2cvgs[shift_fn_trim][:, :, :], axis=(1,2))\n",
    "        a.plot(shift_betas, my_means, color=\"lime\", linestyle=\"-\", linewidth=linewidth)\n",
    "        \n",
    "        my_medians = np.median(shift2shift_beta_idxes2dim_idxes2cvgs[shift_fn_trim][:, :, :], axis=(1,2))\n",
    "        a.plot(shift_betas, my_medians, color=\"black\", linestyle=\"-\", linewidth=linewidth)\n",
    "        \n",
    "        # Quantiles at the levels produced by np.arange(0.1, 1, 0.1).\n",
    "        my_qs = []\n",
    "        for q in np.arange(0.1, 1, 0.1):\n",
    "            my_qs += [np.quantile(shift2shift_beta_idxes2dim_idxes2cvgs[shift_fn_trim][:, :, :], q, axis=(1,2))]\n",
    "        \n",
    "        # Shade between the q_idx-th quantile from the bottom and the q_idx-th\n",
    "        # from the top; the translucent bands overlap, so the centre is darkest.\n",
    "        for q_idx, q in enumerate(np.arange(0.1, 0.5, 0.1)):\n",
    "            a.fill_between(shift_betas, my_qs[q_idx], my_qs[len(my_qs)-q_idx-1], color=\"cornflowerblue\", alpha=0.2)\n",
    "\n",
    "        # NOTE(review): the red line at 0.95 is presumably the nominal coverage level - confirm.\n",
    "        a.axhline(0.95, c='r', linestyle=\"-\", linewidth=linewidth)\n",
    "        a.tick_params(axis='both', labelsize=fontsize)\n",
    "\n",
    "        # Keep the automatic lower y limit, cap the upper limit at 1.\n",
    "        a.set_ylim([a.get_ylim()[0],1])\n",
    "        a.set_xlim([min(shift_betas), max(shift_betas)])\n",
    "        # NOTE(review): the x axis shows shift_betas but is labelled \"a\" - confirm this is the intended symbol.\n",
    "        a.set_xlabel(\"a\", fontsize=fontsize)\n",
    "        a.set_title(convert_dataset_to_friendly_dataset_name(dataset), fontsize=fontsize+2)\n",
    "        \n",
    "        # File name: <dataset>_coverage_<algorithm>_<shift type>.pdf, with the dash\n",
    "        # before tau replaced by an underscore (\"I-D-0.5\" becomes \"I-D_0.5\").\n",
    "        out_fp = out_ff + dataset + \"_coverage_\" + alg_names_ordered[alg_idx] + \"_\" + shift_fn_trim.replace(\"D-\", \"D_\").replace(\"E-\", \"E_\") + \".pdf\"\n",
    "        f.savefig(out_fp, bbox_inches=\"tight\")\n",
    "\n",
    "print(\"All done.\")        "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
