{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd728536",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "import ast\n",
    "import sys\n",
    "import os\n",
    "from itertools import product\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import torch\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "from scipy.io import loadmat\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
    "from sklearn.model_selection import cross_val_predict\n",
    "from sklearn.metrics import balanced_accuracy_score\n",
    "import matplotlib.lines as mlines\n",
    "\n",
    "sys.path.insert(1, \"/\".join(os.path.abspath(\"\").split(\"/\")[0:-1]))\n",
    "import models\n",
    "\n",
    "from importlib import reload\n",
    "\n",
    "reload(models)\n",
    "import os\n",
    "\n",
    "os.chdir(\"your dir\")\n",
    "print(\"CWD =\", os.getcwd())\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6e9862a",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "def draw_scatter(df, \n",
    "                    ax, \n",
    "                    metric, \n",
    "                    method_list,\n",
    "                    x_vals,\n",
    "                    width,\n",
    "                    palette_dict,\n",
    "                    grouped, \n",
    "                    subgroup, \n",
    "                    ):\n",
    "    for i, method in enumerate(method_list):\n",
    "        sub = df[df[\"method\"] == method]\n",
    "        x_idx = np.array([np.where(x_vals == t)[0][0] for t in sub[subgroup]])\n",
    "        x_jitter = x_idx + (i - (len(method_list)-1)/2) * width\n",
    "        x_jitter = x_jitter + np.random.uniform(-0.04, 0.04, size=len(x_jitter))\n",
    "\n",
    "        ax.scatter(\n",
    "            x_jitter,\n",
    "            sub[metric],\n",
    "            s=64,\n",
    "            facecolors='none',\n",
    "            edgecolors=palette_dict[method],\n",
    "            linewidths=1.5,\n",
    "            # label=method,\n",
    "            label=None,\n",
    "            alpha=0.7,\n",
    "            marker='o',\n",
    "        )\n",
    "\n",
    "    \n",
    "def draw_errorbar(df, \n",
    "                 ax, \n",
    "                 metric, \n",
    "                 method_list,\n",
    "                 x_vals,\n",
    "                 width,\n",
    "                 palette_dict,\n",
    "                 grouped, \n",
    "                ):\n",
    "\n",
    "    for i, method in enumerate(method_list):\n",
    "        means = []\n",
    "        stds = []\n",
    "        for x in x_vals:\n",
    "            vals = grouped.get_group((x, method))[metric] if (x, method) in grouped.groups else []\n",
    "            if len(vals) > 0:\n",
    "                means.append(np.mean(vals))\n",
    "                stds.append(np.std(vals))\n",
    "            else:\n",
    "                means.append(np.nan)\n",
    "                stds.append(np.nan)\n",
    "\n",
    "        x_jitter = np.arange(len(x_vals)) + (i - (len(method_list)-1)/2) * width\n",
    "\n",
    "        ax.errorbar(\n",
    "            x_jitter, means, yerr=stds,\n",
    "            fmt='o', \n",
    "            # alpha=0.9,\n",
    "            alpha=1.0,\n",
    "            markersize=14,                # mean point size\n",
    "            markerfacecolor=palette_dict[method],      # mean point face color\n",
    "            markeredgecolor='black',  # mean point edge color\n",
    "            markeredgewidth=1.6,            # mean point edge width\n",
    "            ecolor='black',               # error bar line color\n",
    "            elinewidth=1.6,               # error bar line width\n",
    "            capsize=5,                    # error bar cap width\n",
    "            label=method,\n",
    "            zorder=10\n",
    "        )\n",
    "        ax.plot(\n",
    "            x_jitter, means,\n",
    "            color=palette_dict[method],\n",
    "            linewidth=2.5,\n",
    "            zorder=9,\n",
    "            alpha=0.8\n",
    "        )\n",
    "\n",
    "def get_s_curve(x_center, y_center, d_x, d_y, steps=100):\n",
    "    t = np.linspace(-1, 1, steps)\n",
    "    x = x_center + t * d_x\n",
    "    y = y_center + t * d_y + 0.5 * d_y * np.sin(t * np.pi) \n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71b078e2",
   "metadata": {},
   "source": [
    "### Training params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07110043",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "# trial_list = np.arange(10)[1:2] # only plot for trial 0\n",
    "trial_list = np.arange(10)[6:7]\n",
    "# trial_list = np.arange(10)[9:10]\n",
    "\n",
    "method_list = [\"exp\", \"cubic_exp\", \"GS\", \"score\"]\n",
    "upperbound_list = [8]\n",
    "upperbound_method_list = [\"quantile\"]\n",
    "\n",
    "# temp_list = [0.02, 0.05, 0.1, 0.2, 0.5]\n",
    "temp_list = [\n",
    "    0.02, 0.05, 0.1, 0.15, 0.2, 0.25,\n",
    "    0.3, 0.35, 0.4, 0.45, 0.5,\n",
    "] # 0.02, \n",
    "# temp_list = [\n",
    "#     0.02, 0.05, 0.1, 0.2, 0.5,\n",
    "# ]\n",
    "\n",
    "n_monte_carlo_list = [1, 2, 5, 10]\n",
    "seed_list = np.arange(0, 3)\n",
    "\n",
    "method_alias = {\n",
    "    \"exp\": \"EAT_sigmoid\",\n",
    "    \"cubic_exp\": \"EAT_cubic\",\n",
    "    \"GS\": \"Gumbel-Softmax\",\n",
    "    \"score\": \"Score\",\n",
    "}\n",
    "COLORS = {\n",
    "    'EAT_sigmoid': 'C0',\n",
    "    'EAT_cubic': 'C3',\n",
    "    'Gumbel-Softmax': 'C2',\n",
    "    'Score': 'C7',\n",
    "}\n",
    "metric_list = [\"elbo\", \"mll\", \"cll\", \"hll\", \"weights_error\", \"bias_error\"]\n",
    "metric_alias = {\n",
    "    \"elbo\": \"ELBO\",\n",
    "    \"mll\": \"$\\\\ln p(\\\\boldsymbol{x};\\\\theta)$\",\n",
    "    \"cll\": \"$\\\\ln p(\\\\boldsymbol{x}|\\\\boldsymbol{z};\\\\theta)$\",\n",
    "    \"hll\": \"$\\\\ln q(\\\\boldsymbol{z}|\\\\boldsymbol{x};\\\\phi)$\",\n",
    "    \"weights_error\": \"Weights Error\",\n",
    "    \"bias_error\": \"Bias Error\",\n",
    "}\n",
    "\n",
    "palette = sns.color_palette(\"tab10\", n_colors=4, desat=0.75)\n",
    "palette_dict = {m: COLORS[method_alias[m]] for m in method_list}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60b70a26",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "# tag = \"baseline\"\n",
    "tag = \"baseline_seed_control\"\n",
    "\n",
    "df_result = []\n",
    "idx = 0\n",
    "for (\n",
    "    trial,\n",
    "    method,\n",
    "    upperbound,\n",
    "    upperbound_method,\n",
    "    temp,\n",
    "    n_monte_carlo,\n",
    "    seed,\n",
    ") in product(\n",
    "    trial_list,\n",
    "    method_list,\n",
    "    upperbound_list,\n",
    "    upperbound_method_list,\n",
    "    temp_list,\n",
    "    n_monte_carlo_list,\n",
    "    seed_list,\n",
    "):\n",
    "    name = f\"{trial}_{method}_{upperbound}_{upperbound_method}_{temp}_{n_monte_carlo}_{seed}\"\n",
    "    results_folder = f\"results_{tag}/{name}\"\n",
    "    try:\n",
    "        df_temp = pd.read_csv(f\"{results_folder}/metrics_last.csv\")\n",
    "        df_temp[\"trial\"] = trial\n",
    "        df_temp[\"method\"] = method\n",
    "        df_temp[\"upperbound\"] = upperbound\n",
    "        df_temp[\"upperbound_method\"] = upperbound_method\n",
    "        df_temp[\"temp\"] = temp\n",
    "        df_temp[\"n_monte_carlo\"] = n_monte_carlo\n",
    "        df_temp[\"seed\"] = seed\n",
    "        df_result.append(df_temp)\n",
    "    except:\n",
    "        print(\n",
    "            idx,\n",
    "            f\"{results_folder} not found\",\n",
    "        )\n",
    "    idx += 1\n",
    "\n",
    "df_result = pd.concat(df_result, ignore_index=True)\n",
    "df_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be84290b",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "df_result.to_csv(f\"benchmark_dense_{tag}.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cbbf382",
   "metadata": {},
   "source": [
    "### Plot metric vs tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f480388",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "dodge_val = 0.5\n",
    "\n",
    "# print(palette_dict)\n",
    "left_right_ratio = len(temp_list) + 0.5\n",
    "\n",
    "for metric in metric_list:\n",
    "    fig, axs = plt.subplots(\n",
    "        1, 2, \n",
    "        figsize=(12, 5), \n",
    "        sharey=True, \n",
    "        gridspec_kw={'width_ratios': [1, left_right_ratio]},\n",
    "        layout=\"constrained\"\n",
    "    )\n",
    "    ax_left, ax_right = axs[0], axs[1]\n",
    "\n",
    "    width = 0.03  # x offset for each method\n",
    "    df_result_score = df_result[(df_result[\"method\"] == \"score\") & (df_result[\"temp\"] == 0.2)]\n",
    "    x_vals = np.array(sorted(df_result_score[\"n_monte_carlo\"].unique()))\n",
    "    # only draw score method error bars\n",
    "    draw_scatter(\n",
    "        df=df_result_score,\n",
    "        ax=ax_left,\n",
    "        metric=metric,\n",
    "        method_list=df_result_score[\"method\"].unique(),\n",
    "        x_vals=x_vals,\n",
    "        width=width,\n",
    "        palette_dict=palette_dict,\n",
    "        grouped= df_result_score.groupby(['n_monte_carlo', 'method']), \n",
    "        subgroup=\"n_monte_carlo\",\n",
    "    )\n",
    "\n",
    "    grouped = df_result_score.groupby(['n_monte_carlo', 'method'])\n",
    "    x_vals = np.sort(df_result_score['n_monte_carlo'].unique())\n",
    "    draw_errorbar(\n",
    "        df=df_result_score,\n",
    "        ax=ax_left,\n",
    "        metric=metric,\n",
    "        method_list=df_result_score[\"method\"].unique(),\n",
    "        x_vals=x_vals,\n",
    "        width=width,\n",
    "        palette_dict=palette_dict,\n",
    "        grouped= grouped\n",
    "    )\n",
    "\n",
    "    ymin, ymax = ax_left.get_ylim()\n",
    "    y_range = ymax - ymin\n",
    "    for x, mc in enumerate(x_vals):\n",
    "        vals = grouped.get_group((mc, \"score\"))[metric] if (mc, \"score\") in grouped.groups else []\n",
    "        if len(vals) > 0:\n",
    "            y = np.max(vals)\n",
    "            ax_left.text(x, y + 0.05 * y_range, str(mc), ha='center', va='bottom', fontsize=14, color='black')\n",
    "\n",
    "\n",
    "    ax_left.set_xlim(-1, len(x_vals)-0.5)\n",
    "    ax_left.set_xticks(np.arange(len(x_vals)))\n",
    "    ax_left.set_xticklabels([])\n",
    "    # ax_left.tick_params(axis='x', labelsize=16)\n",
    "    # ax_left.set_xticklabels([str(t) for t in x_vals])\n",
    "    ax_left.set_xlabel(\"$N$ monte carlo\", fontsize=18)\n",
    "    ax_left.set_ylabel(metric_alias[metric], fontsize=18, )\n",
    "    ax_left.tick_params(axis='y', labelleft=True, labelsize=16, )\n",
    "    # ax_left.spines['top'].set_visible(False)\n",
    "    ax_left.spines['right'].set_visible(False)\n",
    "    ax_left.grid(True, which='both', axis='both', linestyle='--', linewidth=0.8, alpha=0.6)\n",
    "    # ax_left.legend(title=\"method\", fontsize=12, title_fontsize=12)\n",
    "    ax_left.legend().set_visible(False)\n",
    "\n",
    "\n",
    "    width = 0.2  # x offset for each method\n",
    "    df_result_rest = df_result[(df_result[\"method\"] != \"score\") & (df_result[\"n_monte_carlo\"] == 1)]\n",
    "    # only draw score method error bars\n",
    "    x_vals = np.array(sorted(df_result_rest[\"temp\"].unique()))\n",
    "    draw_scatter(\n",
    "        df=df_result_rest,\n",
    "        ax=ax_right,\n",
    "        metric=metric,\n",
    "        method_list=df_result_rest[\"method\"].unique(),\n",
    "        x_vals=x_vals,\n",
    "        width=width,\n",
    "        palette_dict=palette_dict,\n",
    "        grouped= df_result_rest.groupby(['temp', 'method']), \n",
    "        subgroup=\"temp\",\n",
    "    )\n",
    "\n",
    "    grouped = df_result_rest.groupby(['temp', 'method'])\n",
    "    x_vals = np.sort(df_result_rest['temp'].unique())\n",
    "    draw_errorbar(\n",
    "        df=df_result_rest,\n",
    "        ax=ax_right,\n",
    "        metric=metric,\n",
    "        method_list=df_result_rest[\"method\"].unique(),\n",
    "        x_vals=x_vals,\n",
    "        width=width,\n",
    "        palette_dict=palette_dict,\n",
    "        grouped= grouped\n",
    "    )\n",
    "    ax_right.spines['left'].set_color((0.5, 0.5, 0.5, 0.5))\n",
    "    ax_right.tick_params(axis='y', labelleft=False, labelsize=8)\n",
    "    ax_right.tick_params(axis='x', labelsize=16)\n",
    "    ax_right.set_xticks(np.arange(len(x_vals)))\n",
    "    ax_right.set_xticklabels([str(t) for t in x_vals])\n",
    "    # ax_right.set_xticklabels([])\n",
    "    ax_right.set_xlabel(\"Temperature ($\\\\tau$) ($N$ monte carlo = 1)\", fontsize=18, )\n",
    "    ax_right.spines['top'].set_visible(False) \n",
    "    # ax_right.spines['right'].set_visible(False)\n",
    "    ax_right.grid(True, which='both', axis='both', linestyle='--', linewidth=0.8, alpha=0.6)\n",
    "    # ax_right.legend(title=\"method\", fontsize=16, title_fontsize=16)\n",
    "\n",
    "    handles = [\n",
    "        mlines.Line2D([], [], color=palette_dict[method], marker='o', linestyle='None',\n",
    "                    markersize=14, markerfacecolor=palette_dict[method], markeredgecolor='black', markeredgewidth=1.6, \n",
    "                    label=method_alias[method])\n",
    "        for method in method_list\n",
    "    ]\n",
    "    ax_right.legend(handles=handles, fontsize=16, title_fontsize=16, loc=\"best\") #title=\"method\", \n",
    "    # ax_right.legend().set_visible(False)\n",
    "\n",
    "    # draw break lines\n",
    "    d = 0.03\n",
    "    d_x_left = d \n",
    "    d_x_right = d / left_right_ratio\n",
    "    kwargs = dict(transform=ax_left.transAxes, color='k', clip_on=False, linewidth=1.5)\n",
    "    lx, ly = get_s_curve(1, 0, d_x_left, d)\n",
    "    ax_left.plot(lx, ly, **kwargs)\n",
    "    lx, ly = get_s_curve(1, 1, d_x_left, d)\n",
    "    ax_left.plot(lx, ly, **kwargs)\n",
    "    kwargs.update(transform=ax_right.transAxes)\n",
    "    rx, ry = get_s_curve(0, 1, d_x_right, d)\n",
    "    ax_right.plot(rx, ry, **kwargs)\n",
    "    rx, ry = get_s_curve(0, 0, d_x_right, d)\n",
    "    ax_right.plot(rx, ry, **kwargs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4423cb4f",
   "metadata": {},
   "source": [
    "### Pick metric vs best tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "699ac4d8",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "best_tau_map = {\n",
    "    \"score\": 0.05,\n",
    "    \"GS\": 0.1,  \n",
    "    \"exp\": 0.02,\n",
    "    \"cubic_exp\": 0.05,\n",
    "}\n",
    "method_alias_tau = {\n",
    "    \"exp\": f\"EAT_sigmoid, $\\\\tau$={best_tau_map['exp']}\", \n",
    "    \"cubic_exp\": f\"EAT_cubic, $\\\\tau$={best_tau_map['cubic_exp']}\",\n",
    "    \"GS\": f\"Gumbel-Softmax, $\\\\tau$={best_tau_map['GS']}\",\n",
    "    \"score\": \"Score\",\n",
    "}\n",
    "\n",
    "data_mc_list = []\n",
    "for m, t in best_tau_map.items():\n",
    "    data_mc_list.append(df_result[(df_result[\"method\"] == m) & (df_result[\"temp\"] == t)])\n",
    "data_p2 = pd.concat(data_mc_list)\n",
    "print(data_p2.shape) \n",
    "\n",
    "dodge_val = 0.5\n",
    "\n",
    "data_p2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc9e6cba",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "handles = [\n",
    "    mlines.Line2D([], [], color=palette_dict[method], marker='o', linestyle='None',\n",
    "                markersize=14, markerfacecolor=palette_dict[method], markeredgecolor='black', \n",
    "                label=method_alias[method])\n",
    "    for method in list(method_alias.keys())\n",
    "]\n",
    "fig_legend = plt.figure(figsize=(3, 0.5)) \n",
    "\n",
    "leg = fig_legend.legend(\n",
    "    handles=handles, \n",
    "    loc='center', \n",
    "    ncol=len(handles), \n",
    "    fontsize=16, \n",
    "    frameon=False, \n",
    "    columnspacing=1.0, \n",
    "    handletextpad=0.5  \n",
    ")\n",
    "\n",
    "fig_legend.savefig(\n",
    "    \"legend_only.pdf\", \n",
    "    bbox_inches='tight', \n",
    "    bbox_extra_artists=[leg] \n",
    ")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82d76a38",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "for metric in metric_list:\n",
    "    fig, axs = plt.subplots(1, 1, figsize=(8, 5), layout=\"constrained\")\n",
    "\n",
    "    x_vals = np.array(sorted(data_p2[\"n_monte_carlo\"].unique()))\n",
    "    width = 0.2  # x offset for each method\n",
    "\n",
    "    draw_scatter(\n",
    "        df=data_p2,\n",
    "        ax=axs,\n",
    "        metric=metric,\n",
    "        method_list=list(method_alias.keys()), \n",
    "        x_vals=x_vals,\n",
    "        width=width,\n",
    "        palette_dict=palette_dict,\n",
    "        grouped= data_p2.groupby(['n_monte_carlo', 'method']), \n",
    "        subgroup=\"n_monte_carlo\",\n",
    "    )\n",
    "\n",
    "    axs.set_xticks(np.arange(len(x_vals)))\n",
    "    axs.set_xticklabels([str(t) for t in x_vals])\n",
    "    # axs.set_xlabel(\"$N$ monte carlo\", fontsize=18, )\n",
    "\n",
    "    draw_errorbar(\n",
    "        df=data_p2,\n",
    "        ax=axs,\n",
    "        metric=metric,\n",
    "        method_list=list(method_alias.keys()),\n",
    "        x_vals=x_vals,\n",
    "        width=width,\n",
    "        palette_dict=palette_dict,\n",
    "        grouped= data_p2.groupby(['n_monte_carlo', 'method']), \n",
    "    )\n",
    "\n",
    "    ymin, ymax = axs.get_ylim()\n",
    "    y_range = ymax - ymin\n",
    "    axs.grid(True, which='both', axis='both', linestyle='--', linewidth=0.8, alpha=0.6)\n",
    "    axs.tick_params(axis='x', labelsize=16)\n",
    "    axs.tick_params(axis='y', labelsize=16)\n",
    "    axs.set_xticks(np.arange(len(x_vals)))\n",
    "    axs.set_xticklabels([str(t) for t in x_vals])\n",
    "    # axs.set_ylim(bottom=-y_range * 0.4 + ymin) #, top=max(data_p2[metric])*0.9995\n",
    "    # axs.set_xlabel(\"$N$ monte carlo\", fontsize=18, )\n",
    "    axs.set_ylabel(metric_alias[metric], fontsize=18, )\n",
    "    # axs.set_title(f\"{metric_alias[metric]}(best $\\\\tau$)\", fontsize=20) # vs $N$ monte carlo , weight=\"bold\"\n",
    "\n",
    "    # handles = [\n",
    "    #     mlines.Line2D([], [], color=palette_dict[method], marker='o', linestyle='None',\n",
    "    #                 markersize=14, markerfacecolor=palette_dict[method], markeredgecolor='black', \n",
    "    #                 label=method_alias_tau[method])\n",
    "    #     for method in list(method_alias.keys())\n",
    "    # ]\n",
    "    # axs.legend(\n",
    "    # handles=handles,\n",
    "    # fontsize=16,\n",
    "    # title_fontsize=16,\n",
    "    # loc=\"upper center\",\n",
    "    # bbox_to_anchor=(0.5, -0.15),\n",
    "    # ncol=2,\n",
    "    # columnspacing=0.5\n",
    "    # )\n",
    "    # axs.legend(handles=handles, fontsize=16, title_fontsize=16, loc=\"\", ncol=2, columnspacing=0.5) #title=\"method\", \n",
    "    # axs.legend().set_visible(False)\n",
    "    fig.savefig(f\"lines_best_tau_POGLM_{metric}.pdf\", transparent=True,)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "csai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
