{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f76428f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96a9ba26",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Move to the project root so the relative result/plot paths used below resolve.\n",
    "# Guarded so that re-running this cell does not climb two more directories each time.\n",
    "# NOTE(review): assumes the project root is the directory holding the `conformal` package - confirm.\n",
    "print(f\"Old working dir {os.getcwd()}\")\n",
    "if not os.path.isdir(\"conformal\"):\n",
    "    os.chdir('../../')\n",
    "print(f\"New working dir {os.getcwd()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c3aa518",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "plots_dir = Path('./conformal_plots/')\n",
    "os.makedirs(plots_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7081a195",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef634b84",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = Path('./conformal_results_u/')\n",
    "#results_dir = Path('./conformal_results/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d191a3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from conformal.real_datasets.process_raw import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da5abda1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of targets per dataset; load_methods_from divides log(volume) by this value.\n",
    "n_targets = {\"rf1\": 8, \"rf2\": 8, \"scm1d\": 16, \"scm20d\": 16, \"sgemm\": 4, \"bio\": 2, \"blog\": 2}\n",
    "# Same mapping as a two-column frame so it can be merged on dataset_name.\n",
    "df_n_targets = pd.DataFrame(list(n_targets.items()), columns=[\"dataset_name\", \"n_targets\"])\n",
    "df_n_targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f6b3fb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "t20c = matplotlib.colormaps[\"tab20c\"]\n",
    "t20c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56840ebe",
   "metadata": {},
   "outputs": [],
   "source": [
    "palette = {\n",
    "    \"OT-CP-Global\": t20c(0),\n",
    "    \"OT-CP-Local\": t20c(1),\n",
    "    \"Ell-Local\": t20c(2),\n",
    "    \"PB\": t20c(4),\n",
    "    \"RPB\": t20c(5),\n",
    "    \"HPD\": t20c(6),\n",
    "    \"Quantile\": t20c(7),\n",
    "    \"PB (CPFlow)\": t20c(8),\n",
    "    \"RPB (CPFlow)\": t20c(9),\n",
    "    \"HPD (CPFlow)\": t20c(10),\n",
    "    \"Quantile (CPFlow)\": t20c(11),\n",
    "    \"PB (Y)\": t20c(12),\n",
    "    \"RPB (Y)\": t20c(13),\n",
    "    \"HPD (Y)\": t20c(14),\n",
    "    \"Quantile (Y)\": t20c(15),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "092fea9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "colormap = matplotlib.colormaps[\"tab20\"]\n",
    "colormap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb5b110d",
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_palette = {\n",
    "    \"OT-CP-Global\": colormap(0),\n",
    "    \"OT-CP-Local\": colormap(1),\n",
    "    \"PB\": colormap(2),\n",
    "    \"RPB\": colormap(3),\n",
    "    \"PB (RF)\": colormap(4),\n",
    "    \"RPB (RF)\": colormap(5),\n",
    "    \n",
    "    \"PB (Y, RF)\": colormap(6),\n",
    "    \"RPB (Y, RF)\": colormap(7),\n",
    "\n",
    "    \"PB (Y)\": colormap(8),\n",
    "    \"RPB (Y)\": colormap(9),\n",
    "    \n",
    "    \"Ell-Local\": colormap(12),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe7c1de9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Where to load each method/metric from?\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04130647",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "\n",
    "\n",
    "def load_methods_from(method_names: List[str], path: str | Path, seeds=range(10), extention: Literal[\"feather\", \"csv\"] = \"feather\") -> pd.DataFrame:\n",
    "    \"\"\"Collect the metric tables stored under ``path`` into one frame.\n",
    "\n",
    "    For every seed in ``seeds`` and every entry of the global ``datasets`` the file\n",
    "    ``<path>/<dataset_name>/<seed>/metrics_all.<extention>`` is read if it exists;\n",
    "    a missing file is reported on stdout and skipped. The stacked rows are then\n",
    "    inner-merged with the global ``df_n_targets`` on ``dataset_name``.\n",
    "\n",
    "    Args:\n",
    "        method_names: Values of ``method_name`` to keep; an empty list keeps all rows.\n",
    "        path: Root directory of one results run.\n",
    "        seeds: Seed sub-directories to scan.\n",
    "        extention: File suffix of ``metrics_all``. ``\"feather\"`` is read with\n",
    "            ``pd.read_feather``, any other value with ``pd.read_csv``.\n",
    "\n",
    "    Returns:\n",
    "        The merged frame. When it has a ``volume`` column, a ``log_vol_d`` column\n",
    "        equal to ``log(volume) / n_targets`` is added.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If none of the expected files exists.\n",
    "    \"\"\"\n",
    "    # Pick the reader once instead of re-checking the suffix for every file.\n",
    "    reader = pd.read_feather if extention == \"feather\" else pd.read_csv\n",
    "    root = Path(path)\n",
    "    dataframes = []\n",
    "    for seed in seeds:\n",
    "        for dataset_name in datasets:\n",
    "            fn = root / dataset_name / str(seed) / f\"metrics_all.{extention}\"\n",
    "            if fn.is_file():\n",
    "                dataframes.append(reader(fn))\n",
    "            else:\n",
    "                print(f\"Error: dataset {dataset_name}, seed {seed} not found.\")\n",
    "    if not dataframes:\n",
    "        # pd.concat([]) would fail with 'No objects to concatenate'; name the real cause instead.\n",
    "        raise ValueError(f\"No metrics_all.{extention} files found under {root}\")\n",
    "    df = pd.concat(dataframes).merge(df_n_targets, on=\"dataset_name\")\n",
    "    if \"volume\" in df.columns:\n",
    "        # Rows with volume == 0 end up as -inf here.\n",
    "        df[\"log_vol_d\"] = np.log(df[\"volume\"]) / df[\"n_targets\"]\n",
    "    if len(method_names) > 0:\n",
    "        df = df[df[\"method_name\"].isin(method_names)]\n",
    "    return df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b422e46",
   "metadata": {},
   "outputs": [],
   "source": [
    "df0 = load_methods_from(method_names=[], path=\"./conformal_results_u/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69d0490f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "793e2d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#df.head(20)\n",
    "df1 = load_methods_from(method_names=[], path=\"./conformal_results_u/\")\n",
    "df2 = load_methods_from(method_names=[], path=\"./conformal_results_250923/\", extention=\"csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c971319f",
   "metadata": {},
   "outputs": [],
   "source": [
    "df1[\"method_name\"].unique(), df2.method_name.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aedf12a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "df1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c7904d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "df2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c0afb87",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.concat([df1, df2], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c28ce5a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#palette = blended_palette(df[\"base_model_name\"], df[\"conformalizer\"], paletteA=\"Set1\", paletteB=\"Set2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af56a95e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pd.DataFrame.from_dict(palette, orient=\"index\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acf18368",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.shape, df[['dataset_name', 'alpha', 'method_name', 'seed']].drop_duplicates().shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d507283",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"method_name\"].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2700f331",
   "metadata": {},
   "outputs": [],
   "source": [
    "(df0[df0['method_name'].str.contains(\"OT-CP\")][\"volume\"] == 0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5efaa0dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "g_cov = sns.catplot(\n",
    "    data=df[df[\"dataset_name\"].isin([\"sgemm\", \"scm20d\"])], x=\"dataset_name\", y=\"marginal_coverage\", col=\"alpha\", hue=\"method_name\", sharey=False,\n",
    ")\n",
    "g_cov.set_axis_labels(\"Dataset\", \"Marginal coverage\")\n",
    "for alpha, ax in g_cov.axes_dict.items():\n",
    "    ax.axhline(1 - alpha, ls=\"--\", c=\"k\", alpha=0.5)\n",
    "#for ax in g_cov.axes.flatten():\n",
    "#    ax.tick_params(labelbottom=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef826810",
   "metadata": {},
   "outputs": [],
   "source": [
    "g_cov = sns.catplot(\n",
    "    data=df0, x=\"dataset_name\", y=\"volume\", col=\"alpha\", hue=\"method_name\", sharey=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14d05816",
   "metadata": {},
   "outputs": [],
   "source": [
    "#g_cov.axes_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b8dd6c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "978c7f26",
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_columns = [\"marginal_coverage\", \"worst_slab_coverage\", \"volume\", \"log_vol_d\"]\n",
    "id_vars = list(df.columns.difference(metrics_columns))\n",
    "df_melted = pd.melt(df, id_vars=id_vars, value_vars=metrics_columns, var_name=\"metric\", value_name=\"value\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a782756c",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(df_melted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce077ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "g_all = sns.catplot(\n",
    "    data=df_melted,#.query(\"dataset_name == 'bio' or dataset_name == 'blog'\"), \n",
    "    kind=\"box\", \n",
    "    x=\"dataset_name\", y=\"value\", col=\"alpha\", row=\"metric\", hue=\"method_name\", #_mathtext\",\n",
    "    #palette=palette,\n",
    "    sharey=\"row\", showfliers=False,\n",
    ")\n",
    "g_all.set_axis_labels(\"\", \"\")\n",
    "for (metric_name, alpha), ax in g_all.axes_dict.items():\n",
    "    if \"coverage\" in metric_name:\n",
    "        ax.axhline(1 - alpha, ls=\"--\", c=\"k\", alpha=0.5)\n",
    "    if \"volume\" in metric_name:\n",
    "        ax.set_yscale(\"log\")\n",
    "for ax in g_all.axes.flatten():\n",
    "    ax.tick_params(labelbottom=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5dde908",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Appendix subset: four datasets, no CPFlow variants, no rows with worst-slab coverage equal to 0.\n",
    "# The comparison is parenthesised because `&` binds tighter than `!=`: without the\n",
    "# parentheses the whole conjunction, not the column, is compared with 0.\n",
    "df_main = df[\n",
    "    df[\"dataset_name\"].isin([\"sgemm\", \"scm20d\", \"bio\", \"blog\"])\n",
    "    & ~df[\"method_name\"].str.contains(\"CPFlow\")\n",
    "    & (df[\"worst_slab_coverage\"] != 0)\n",
    "].copy()\n",
    "def get_hatch(name):\n",
    "    \"\"\"Return the hatch pattern for a method name, or None if no family matches.\n",
    "\n",
    "    The family is detected by substring. \"RPB\" is tested before \"PB\" because every\n",
    "    name containing \"RPB\" also contains \"PB\"; in the opposite order the RPB branch\n",
    "    can never be reached.\n",
    "    \"\"\"\n",
    "    if \"Quantile\" in name:\n",
    "        return \"/\"\n",
    "    elif \"RPB\" in name:\n",
    "        return \"x\"\n",
    "    elif \"PB\" in name:\n",
    "        return \"\\\\\"\n",
    "    elif \"HPD\" in name:\n",
    "        return \"-\"\n",
    "    else:\n",
    "        return None\n",
    "df_main[\"hatch\"] = df_main[\"method_name\"].apply(get_hatch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d9ce069",
   "metadata": {},
   "outputs": [],
   "source": [
    "boxplot_hatches = {\n",
    "    -3: \"\",\n",
    "    -2: \"\",\n",
    "    -1: \"\",\n",
    "    0: \"//\",\n",
    "    1: \"xx\",\n",
    "    2: \"--\",\n",
    "    3: \"o\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f21f6d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_main[\"hatch\"].head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e895766",
   "metadata": {},
   "outputs": [],
   "source": [
    "t20c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6075c724",
   "metadata": {},
   "outputs": [],
   "source": [
    "palette_appendix = {\n",
    "    \"OT-CP-Global\": t20c(0),\n",
    "    \"OT-CP-Local\": t20c(1),\n",
    "    \"Ell-Local\": t20c(2),\n",
    "    \"PB\": t20c(4),\n",
    "    \"RPB\": t20c(5),\n",
    "    \"HPD\": t20c(6),\n",
    "    \"Quantile\": t20c(7),\n",
    "    \"PB (Y)\": t20c(8),\n",
    "    \"RPB (Y)\": t20c(9),\n",
    "    \"HPD (Y)\": t20c(10),\n",
    "    \"Quantile (Y)\": t20c(11),\n",
    "    \"PB (RF)\": t20c(12),\n",
    "    \"RPB (RF)\": t20c(13),\n",
    "    \"HPD (RF)\": t20c(14),\n",
    "    \"Quantile (RF)\": t20c(15),\n",
    "    \"PB (Y, RF)\": t20c(16),\n",
    "    \"RPB (Y, RF)\": t20c(17),\n",
    "    \"HPD (Y, RF)\": t20c(18),\n",
    "    \"Quantile (Y, RF)\": t20c(19),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "484d1bd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#n_facets_to_plot = new_ugly_filter_wsc_df[\"dataset_name\"].nunique()\n",
    "#print(n_facets_to_plot)\n",
    "iclr_width = 5.50107\n",
    "plot_aspect_wide = 16 / 9\n",
    "plot_height = iclr_width / plot_aspect_wide\n",
    "sns.set_style({'axes.grid' : True})\n",
    "g_wsc = sns.catplot(\n",
    "    data=df_main,\n",
    "    kind=\"box\",\n",
    "    y=\"worst_slab_coverage\",\n",
    "    col=\"alpha\",\n",
    "    row=\"dataset_name\",\n",
    "    #col=\"dataset_name\",\n",
    "    hue=\"method_name\", #_mathtext\",\n",
    "    palette=palette_appendix,\n",
    "    sharey=\"row\",\n",
    "    showfliers=False,\n",
    "    #height=plot_height,\n",
    ")\n",
    "g_wsc.set_axis_labels(\"\", \"Worst slab coverage\")\n",
    "g_wsc.set_xticklabels([])\n",
    "g_wsc.despine(bottom=False, top=False, right=False)\n",
    "for (dataset_name, alpha), ax in g_wsc.axes_dict.items():\n",
    "    ax.set_title(rf\"$\\mathtt{{{dataset_name}}}$, $\\alpha={alpha:.1f}$\")\n",
    "    ax.axhline(1 - alpha, xmax=1, ls=\"--\", c=\"k\", alpha=0.9)\n",
    "for ax in g_wsc.axes.flatten():\n",
    "    ax.tick_params(left=False, bottom=False)\n",
    "    for i, patch in enumerate(ax.patches):\n",
    "        # Blue bars first, then green bars\n",
    "        patch.set_hatch(boxplot_hatches[(i - 3) % 4])\n",
    "for j, legend_patch in enumerate(g_wsc.legend.get_patches()):\n",
    "    legend_patch.set_hatch(boxplot_hatches[j % 4])\n",
    "\n",
    "sns.move_legend(g_wsc, \"lower center\", bbox_to_anchor=(0.45, 1), ncol=len(palette) // 2, title=None,\n",
    "                )\n",
    "g_wsc.savefig(plots_dir / \"results_worst_slab_coverage_250925_hatch.pdf\", bbox_inches=\"tight\")\n",
    "g_wsc.savefig(plots_dir / \"results_worst_slab_coverage_250925_hatch.png\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28078b54",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "efd81126",
   "metadata": {},
   "source": [
    "# Selected results for main part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d57cc34",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['dataset_name'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34606cd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d4fcd14",
   "metadata": {},
   "outputs": [],
   "source": [
    "#g_all.axes_dict\n",
    "#df[df['dataset_name'] == 'sgemm']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30a01418",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main-text subset for the worst-slab-coverage figure: alpha == 0.1 only, without the\n",
    "# Quantile / HPD / CPFlow / \"Y\" method variants and without the rf* and scm1d datasets.\n",
    "# str.contains treats its pattern as a regex, so `|` lists the excluded substrings.\n",
    "new_ugly_filter_wsc_df = df.loc[\n",
    "    (df[\"alpha\"] == 0.1)\n",
    "    & ~df[\"method_name\"].str.contains(\"Quantile|HPD|CPFlow|Y\")\n",
    "    & ~df[\"dataset_name\"].str.contains(\"rf|scm1d\")\n",
    "].copy()\n",
    "new_ugly_filter_wsc_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc57a0b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_ugly_filter_wsc_df['worst_slab_coverage_error'] = np.log((new_ugly_filter_wsc_df['worst_slab_coverage'] - (1 - new_ugly_filter_wsc_df['alpha'])).abs())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "767f32c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_main_part_old = [r\"$\\mathtt{OT}$-$\\mathtt{CP}$\", r\"$\\mathtt{OT}$-$\\mathtt{CP}$+\", \n",
    "                        r\"$\\mathrm{ELL}$\",\n",
    "                        r\"$\\mathrm{PB}_{U}$\", r\"$\\mathrm{RPB}_{U}$\",\n",
    "                        r\"$\\mathrm{PB}_{Y}$\", r\"$\\mathrm{RPB}_{Y}$\",\n",
    "                        r\"$\\mathrm{PBS}_{U}$\", r\"$\\mathrm{RPBS}_{U}$\",\n",
    "                        r\"$\\mathrm{PBS}_{Y}$\", r\"$\\mathrm{RPBS}_{Y}$\",]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba05de05",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_main_part = [r\"$\\mathtt{OT}$-$\\mathtt{CP}$\", r\"$\\mathtt{OT}$-$\\mathtt{CP}$+\", \n",
    "                        r\"$\\mathrm{ELL}$\",\n",
    "                        r\"$\\mathrm{PB}$\", r\"$\\mathrm{RPB}$\",\n",
    "                        r\"$\\mathrm{PBS}$\", r\"$\\mathrm{RPBS}$\",]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34212d51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worst-slab-coverage box plots for the main text: one facet per dataset, one box per method.\n",
    "n_facets_to_plot = new_ugly_filter_wsc_df[\"dataset_name\"].nunique()\n",
    "print(n_facets_to_plot)\n",
    "\n",
    "# plot_height is iclr_width divided by a 16:9 aspect ratio; it is passed to catplot as the facet height.\n",
    "# NOTE(review): iclr_width is presumably the ICLR text width in inches - confirm.\n",
    "iclr_width = 5.50107\n",
    "plot_aspect_wide = 16 / 9\n",
    "plot_height = iclr_width / plot_aspect_wide\n",
    "\n",
    "# Miscoverage level of the plotted rows: new_ugly_filter_wsc_df is built with alpha == 0.1.\n",
    "main_alpha = 0.1\n",
    "\n",
    "sns.set_style({'axes.grid' : True})\n",
    "g_wsc = sns.catplot(\n",
    "    data=new_ugly_filter_wsc_df,\n",
    "    kind=\"box\",\n",
    "    y=\"worst_slab_coverage\",\n",
    "    col=\"dataset_name\",\n",
    "    hue=\"method_name\",\n",
    "    palette=selected_palette,\n",
    "    sharey=True,\n",
    "    showfliers=False,\n",
    "    height=plot_height,\n",
    ")\n",
    "# The legend column count follows the labels that are shown, not the size of a colour dict\n",
    "# this figure does not use, so the layout depends only on this cell's own inputs.\n",
    "# NOTE(review): `labels` relabels legend entries by position, so labels_main_part must list\n",
    "# the methods in the same order as the hue levels - confirm.\n",
    "sns.move_legend(g_wsc, \"lower center\", bbox_to_anchor=(0.45, 1), ncol=len(labels_main_part), title=None,\n",
    "                labels=labels_main_part)\n",
    "g_wsc.set_axis_labels(\"\", \"Worst slab coverage\")\n",
    "g_wsc.set_xticklabels([])\n",
    "g_wsc.despine(bottom=False, top=False, right=False)\n",
    "for dataset_name, ax in g_wsc.axes_dict.items():\n",
    "    ax.set_title(rf\"$\\mathtt{{{dataset_name}}}$\")\n",
    "    # Dashed reference line at the nominal coverage 1 - alpha.\n",
    "    ax.axhline(1 - main_alpha, xmax=1, ls=\"--\", c=\"k\", alpha=0.9)\n",
    "    ax.tick_params(left=False, bottom=False)\n",
    "    ax.set_ylim(0.65, 0.95)\n",
    "\n",
    "# Export is disabled for this figure; uncomment to write the files.\n",
    "#g_wsc.savefig(plots_dir / \"selected_results_worst_slab_coverage_250924.pdf\", bbox_inches=\"tight\")\n",
    "#g_wsc.savefig(plots_dir / \"selected_results_worst_slab_coverage_250924.png\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f14f6f2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_height"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "717d0b91",
   "metadata": {},
   "outputs": [],
   "source": [
    "12 / 5 / 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "346632e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main-text subset for the volume figure: alpha == 0.1, the four selected datasets, and\n",
    "# no Quantile / HPD / CPFlow / \"Y\" method variants.\n",
    "# str.contains treats its pattern as a regex, so `|` lists the excluded substrings.\n",
    "new_ugly_filter_volume_df = df.loc[\n",
    "    (df[\"alpha\"] == 0.1)\n",
    "    & df[\"dataset_name\"].isin([\"scm20d\", \"sgemm\", \"bio\", \"blog\"])\n",
    "    & ~df[\"method_name\"].str.contains(\"Quantile|HPD|CPFlow|Y\")\n",
    "].copy()\n",
    "new_ugly_filter_volume_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc1994ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "#new_ugly_filter_volume_df.query(\"dataset_name == 'sgemm' and method_name == 'OT-CP-Local'\")[\"volume\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c27e4c95",
   "metadata": {},
   "outputs": [],
   "source": [
    "#new_ugly_filter_volume_df.query(\"dataset_name == 'sgemm' and method_name == 'OT-CP-Global'\")[\"volume\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6b134a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#new_ugly_filter_volume_df.query(\"dataset_name == 'sgemm' and method_name == 'PB'\")[\"volume\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a32103e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_four_volumes = pd.read_csv(\"four_volumes.csv\").set_index([\"dataset_name\", \"seed\"])\n",
    "df_four_volumes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c701048",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_four_volumes_rf = pd.read_csv(\"four_volumes_rf.csv\").set_index([\"dataset_name\", \"seed\"])\n",
    "df_four_volumes_rf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdbd1ba8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "# Seeds for which replacement values exist, and the table that supplies them per method.\n",
    "corrected_seeds = [0, 1, 3]\n",
    "replacement_tables = {\"PB\": df_four_volumes, \"PB (RF)\": df_four_volumes_rf}\n",
    "\n",
    "# Keep every row of the other methods, but only the corrected seeds for PB and PB (RF).\n",
    "keep_row = (\n",
    "    new_ugly_filter_volume_df[\"seed\"].isin(corrected_seeds)\n",
    "    | ~new_ugly_filter_volume_df[\"method_name\"].isin(list(replacement_tables))\n",
    ")\n",
    "new_ugly_filter_volume_corrected_df = new_ugly_filter_volume_df[keep_row].copy()\n",
    "\n",
    "# Overwrite log_vol_d of each (method, dataset, seed) row with the \"mean\" column of its table.\n",
    "for dataset_name, seed in itertools.product([\"scm20d\", \"sgemm\", \"bio\", \"blog\"], corrected_seeds):\n",
    "    same_run = (\n",
    "        (new_ugly_filter_volume_corrected_df[\"dataset_name\"] == dataset_name)\n",
    "        & (new_ugly_filter_volume_corrected_df[\"seed\"] == seed)\n",
    "    )\n",
    "    for method_name, table in replacement_tables.items():\n",
    "        is_method = new_ugly_filter_volume_corrected_df[\"method_name\"] == method_name\n",
    "        new_ugly_filter_volume_corrected_df.loc[is_method & same_run, \"log_vol_d\"] = table.loc[(dataset_name, seed), \"mean\"]\n",
    "new_ugly_filter_volume_corrected_df.query(\"method_name == 'PB (RF)'\")[[\"dataset_name\", \"seed\", \"log_vol_d\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d403faaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar plot of the median log_vol_d per method, one facet per dataset.\n",
    "g_logvold = sns.catplot(\n",
    "    data=new_ugly_filter_volume_corrected_df,\n",
    "    kind=\"bar\",\n",
    "    y=\"log_vol_d\",\n",
    "    col=\"dataset_name\",\n",
    "    hue=\"method_name\",\n",
    "    estimator=\"median\",\n",
    "    palette=selected_palette,\n",
    "    sharey=False,\n",
    "    facet_kws={\n",
    "        \"despine\": False,\n",
    "    },\n",
    "    height=plot_height,\n",
    "    linewidth=0.9,\n",
    "    edgecolor=\"k\",\n",
    "    # NOTE(review): seaborn documents `dodge` as \"auto\" or a bool - confirm a float is intended.\n",
    "    dodge=2.6,\n",
    "    gap=0.1,\n",
    ")\n",
    "# The legend column count follows the labels that are shown, not the size of a colour dict\n",
    "# this figure does not use, so the layout depends only on this cell's own inputs.\n",
    "# NOTE(review): `labels` relabels legend entries by position, so labels_main_part must list\n",
    "# the methods in the same order as the hue levels - confirm.\n",
    "sns.move_legend(g_logvold, \"lower center\", bbox_to_anchor=(0.45, 1), ncol=len(labels_main_part), title=None,\n",
    "                labels=labels_main_part)\n",
    "g_logvold.set_axis_labels(\"\", r\"$(\\log V) / d$\")\n",
    "g_logvold.set_xticklabels([])\n",
    "for dataset_name, ax in g_logvold.axes_dict.items():\n",
    "    ax.set_title(rf\"$\\mathtt{{{dataset_name}}}$\")\n",
    "    ax.tick_params(bottom=False)\n",
    "    ax.grid(visible=True, which=\"both\", axis=\"y\")\n",
    "    # Keep the grid lines behind the bars.\n",
    "    ax.set_axisbelow(True)\n",
    "\n",
    "g_logvold.savefig(plots_dir / \"selected_results_volume_250925.pdf\", bbox_inches=\"tight\")\n",
    "g_logvold.savefig(plots_dir / \"selected_results_volume_250925.png\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acf4fa08",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.log(6427081) / 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0be4e9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_ugly_filter_volume_df.query(\"method_name == 'Ell-Local'\").volume.min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9573a17e",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3818393",
   "metadata": {},
   "outputs": [],
   "source": [
    "colormap = matplotlib.colormaps[\"tab20\"]\n",
    "colormap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3341393c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each dataset: read its tuning table, record the row with the smallest error,\n",
    "# tag the table with the dataset name, and finally stack all tables into df_tuning.\n",
    "tuned_params = {}\n",
    "dfs = []\n",
    "for dataset_name in (\"rf1\", \"rf2\", \"scm1d\", \"scm20d\"):\n",
    "    per_dataset = pd.read_feather(f\"./conformal_results_slurm/{dataset_name}/53/tuning.feather\")\n",
    "    # Taken before the dataset_name column is added, so the stored dict does not contain it.\n",
    "    best_row = per_dataset.loc[per_dataset['error'].idxmin()]\n",
    "    print(best_row)\n",
    "    tuned_params[dataset_name] = best_row.to_dict()\n",
    "    per_dataset[\"dataset_name\"] = dataset_name\n",
    "    dfs.append(per_dataset)\n",
    "df_tuning = pd.concat(dfs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2a0dd7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tuned_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7575137",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.pointplot(df_tuning, x=\"n_epochs\", y=\"error\", hue=\"dataset_name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d78a0ecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "df3 = load_methods_from(method_names=[], path=\"./conformal_results_sgemm_no_areas/\", seeds=range(10, 15), extention=\"feather\")\n",
    "df4 = load_methods_from(method_names=[], path=\"./conformal_results_sgemm_areas/\", seeds=range(10, 15),extention=\"feather\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9126967b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sgemm = pd.merge(df3, df4,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41bf08fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sgemm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48f33206",
   "metadata": {},
   "outputs": [],
   "source": [
    "g_sgemm_vold = sns.catplot(\n",
    "    data=df_sgemm,\n",
    "    kind=\"bar\",\n",
    "    y=\"volume\",\n",
    "    #col=\"alpha\",\n",
    "    #row=\"dataset_name\",\n",
    "    col=\"dataset_name\",\n",
    "    hue=\"method_name\", #_mathtext\",\n",
    "    estimator=\"median\",\n",
    "    #palette=selected_palette,\n",
    "    sharey=True,\n",
    "    #showfliers=False,\n",
    "    facet_kws={\n",
    "        \"despine\": False,\n",
    "    },\n",
    "    height=plot_height,\n",
    "    linewidth=0.9,\n",
    "    edgecolor=\"k\",\n",
    "    dodge=2.6,\n",
    "    gap=0.1,\n",
    ")\n",
    "plt.ylim(0, 2.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "484f129c",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sgemm.plot(\"volume\", kind=\"hist\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f79e0da7",
   "metadata": {},
   "outputs": [],
   "source": [
    "(df_sgemm[\"volume\"] > 0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d599d0d7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conditional_quantile_function",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
