{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analysis of Latent Recalibration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import logging\n",
    "from functools import partial\n",
    "import math\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "from moc.analysis.dataframes import (\n",
    "    agg_mean_sem,\n",
    "    format_cell_jupyter,\n",
    "    format_cell_latex,\n",
    "    get_datasets_df,\n",
    "    get_metric_df,\n",
    "    load_config,\n",
    "    load_df,\n",
    "    to_latex,\n",
    ")\n",
    "from moc.analysis.highlighter import Highlighter\n",
    "from moc.analysis.plot_barplots import barplots\n",
    "from moc.analysis.plot_calibration import plot_reliability_diagrams\n",
    "from moc.models.tarflow.tarflow import image_shape\n",
    "from moc.utils import savefig, set_notebook_options\n",
    "\n",
    "set_notebook_options(logging.WARNING)\n",
    "\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        'axes.titlesize': 12,\n",
    "        'axes.labelsize': 12,\n",
    "        'legend.fontsize': 14,\n",
    "    }\n",
    ")\n",
    "name = 'lr'\n",
    "path = Path('results') / name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = load_config(Path('logs') / name)\n",
    "df_raw = load_df(config)\n",
    "df = get_metric_df(config, df_raw).reset_index()\n",
    "df_ds = get_datasets_df(config, reload=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tables creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_table(df, name, metrics, format_cell_kwargs={}):\n",
    "    df = df.copy().reset_index()\n",
    "    df = df.query('metric in @metrics')\n",
    "    df = df.query('posthoc_method != \"HDR\" or metric not in [\"nll\", \"latent_calibration\"]')\n",
    "    df['metric'] = pd.Categorical(df['metric'], categories=metrics)\n",
    "    plot_df = df.reset_index()[['abb', 'metric', 'name', 'value', 'run_id']]\n",
    "    pivot_df = plot_df.pivot_table(\n",
    "        index='abb',\n",
    "        columns=('metric', 'name'),\n",
    "        values='value',\n",
    "        aggfunc=agg_mean_sem,\n",
    "        observed=True,\n",
    "    )\n",
    "    styled_table = pivot_df.style.apply(\n",
    "        Highlighter().highlight_statistically_similar_to_best_per_metric, axis=None\n",
    "    )\n",
    "    to_latex(\n",
    "        styled_table.format(partial(format_cell_latex, **format_cell_kwargs)),\n",
    "        path / 'tables' / f'{name}_lr.tex',\n",
    "    )\n",
    "    return styled_table.format(partial(format_cell_jupyter, add_sem=True, **format_cell_kwargs))\n",
    "\n",
    "\n",
    "def create_tables(df, name):\n",
    "    create_table(df, f'{name}_scoring_rules', metrics=['nll', 'energy_score'])\n",
    "    create_table(df, f'{name}_calibration', metrics=['latent_calibration', 'hdr_calibration'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bar plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_all_barplots(plot_df, dir_name):\n",
    "    barplots(plot_df, ['latent_calibration', 'hdr_calibration'])\n",
    "    savefig(path / 'barplot' / dir_name / 'calibration.pdf')\n",
    "    barplots(plot_df, ['nll', 'energy_score'], width=4.8)\n",
    "    savefig(path / 'barplot' / dir_name / 'scoring_rules.pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reliability diagrams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_all_reliability_diagrams(df, dir_name):\n",
    "    plot_reliability_diagrams(\n",
    "        df.query('posthoc_method != \"HDR\"'), 'latent_distance', config, ncols=5, ncols_legend=5\n",
    "    )\n",
    "    savefig(path / 'reliability_diagrams' / dir_name / 'latent_distance.pdf')\n",
    "    plot_reliability_diagrams(df, 'hpd', config, ncols=5, ncols_legend=3)\n",
    "    savefig(path / 'reliability_diagrams' / dir_name / 'hpd.pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convex potential flows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_df = df.query('model == \"MQF2\"')\n",
    "plot_df = plot_df.query('posthoc_density_estimator.isna() or posthoc_density_estimator == \"kde\"')\n",
    "\n",
    "plot_all_barplots(plot_df, 'MQF2')\n",
    "plot_all_reliability_diagrams(plot_df, 'MQF2')\n",
    "create_tables(plot_df, 'MQF2')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ARFlow results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_df = df.query('model == \"ARFlow\"')\n",
    "plot_df = plot_df.query('transform_type == \"spline-quadratic\" and hidden_size == 64 and num_layers == 2')\n",
    "\n",
    "plot_all_barplots(plot_df, 'ARFlow')\n",
    "create_tables(plot_df, 'ARFlow')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Misspecified convex potential flow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = load_config(Path('logs') / 'lr_misspecified')\n",
    "df_raw = load_df(config)\n",
    "df = get_metric_df(config, df_raw).reset_index()\n",
    "df_ds = get_datasets_df(config, reload=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_df = df.query('model == \"MQF2\"')\n",
    "plot_df = plot_df.query('posthoc_density_estimator.isna() or posthoc_density_estimator == \"kde\"')\n",
    "\n",
    "plot_all_barplots(plot_df, 'MQF2')\n",
    "create_tables(plot_df, 'MQF2')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TarFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BPD metric\n",
    "def nll_to_bpd(nll, k=128):\n",
    "    n_dims = image_shape.numel()\n",
    "    # Scale\n",
    "    bpd = nll + math.log(k) * n_dims\n",
    "    # Bits per dimension\n",
    "    return bpd / (n_dims * math.log(2))\n",
    "\n",
    "\n",
    "def add_bpd(df):\n",
    "    nll_rows = df.query('metric == \"nll\"')\n",
    "    df_bpd = nll_rows.assign(value=lambda x: x['value'].apply(nll_to_bpd), metric='bpd')\n",
    "    return pd.concat([df, df_bpd], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "names = ['lr_tarflow_noisy', 'lr_tarflow_no_noise']\n",
    "for name in names:\n",
    "    config = load_config(Path('logs') / name)\n",
    "    df = get_metric_df(config, load_df(config)).reset_index()\n",
    "    df = add_bpd(df)\n",
    "    display(\n",
    "        create_table(\n",
    "            df,\n",
    "            name,\n",
    "            metrics=['latent_calibration_100', 'bpd'],\n",
    "            format_cell_kwargs={'mean_digits': 4, 'sem_digits': 4},\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
