{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2dc5e63",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "from aeon.visualisation import plot_critical_difference\n",
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c3955e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment run ids; results of run <id> are read from results/exp_<id>/.\n",
    "EXP_IDS = [1, 2]\n",
    "\n",
    "# Datasets kept in the analysis; score rows of any other dataset are dropped after loading.\n",
    "DATASETS = (\n",
    "    \"AtrialFibrillation BasicMotions Cricket EigenWorms Epilepsy ERing \"\n",
    "    \"FingerMovements HandMovementDirection Handwriting Heartbeat NATOPS \"\n",
    "    \"SelfRegulationSCP1 StandWalkJump UWaveGestureLibrary\"\n",
    ").split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9941107",
   "metadata": {},
   "outputs": [],
   "source": [
    "## COLLECT DATA\n",
    "# The file name of every scores CSV of run EXP_ID ends in `_<k>.csv`; k is mapped to the\n",
    "# global fold id k + 1 + (EXP_ID - 1) * FOLDS_PER_EXP so that folds of different runs\n",
    "# get distinct ids (assumes each run has FOLDS_PER_EXP folds -- confirm).\n",
    "FOLDS_PER_EXP = 5\n",
    "\n",
    "frames = []\n",
    "for EXP_ID in EXP_IDS:\n",
    "    exp_dir = Path(f\"results/exp_{EXP_ID}\")\n",
    "    if not exp_dir.is_dir():\n",
    "        raise FileNotFoundError(f\"missing results directory: {exp_dir}\")\n",
    "    # sorted() makes the row order independent of the filesystem listing order.\n",
    "    for path in sorted(exp_dir.glob(\"*scores*.csv\")):\n",
    "        frame = pd.read_csv(path)\n",
    "        frame[\"fold_id\"] = int(path.stem.split(\"_\")[-1]) + 1 + (EXP_ID - 1) * FOLDS_PER_EXP\n",
    "        frames.append(frame)\n",
    "assert frames, f\"no '*scores*.csv' files found for EXP_IDS={EXP_IDS}\"\n",
    "\n",
    "df = pd.concat(frames)\n",
    "df = df[df[\"dataset\"].isin(DATASETS)].reset_index(drop=True)"
   ]
  },
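  {
   "cell_type": "markdown",
   "id": "500abd0b",
   "metadata": {},
   "source": [
    "## Computation time\n",
    "\n",
    "For each metric and dataset, the mean `compute_time` is divided by the square of the dataset's total number of samples and then multiplied by 1000. The mean and standard deviation of this normalised time over datasets are printed for each metric."
   ]
  },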
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89c727a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metadata for the 26 datasets in `all_datasets` (same order in every list below).\n",
    "all_datasets = (\n",
    "    \"ArticularyWordRecognition AtrialFibrillation BasicMotions Cricket DuckDuckGeese \"\n",
    "    \"EigenWorms Epilepsy EthanolConcentration ERing FaceDetection FingerMovements \"\n",
    "    \"HandMovementDirection Handwriting Heartbeat Libras LSST MotorImagery NATOPS \"\n",
    "    \"PenDigits PEMS-SF PhonemeSpectra RacketSports SelfRegulationSCP1 \"\n",
    "    \"SelfRegulationSCP2 StandWalkJump UWaveGestureLibrary\"\n",
    ").split()\n",
    "# Two rows of per-dataset split sizes (presumably train, then test -- confirm);\n",
    "# n_samples is their sum.\n",
    "split_sizes = np.array([\n",
    "    \"275 15 40 108 50 128 137 261 30 5890 316 160 150 204 180 2459 278 180 7494 267 3315 151 268 200 12 120\".split(),\n",
    "    \"300 15 40 72 50 131 138 263 270 3524 100 74 850 205 180 2466 100 180 3498 173 3353 152 293 180 15 320\".split(),\n",
    "]).astype(int)\n",
    "n_samples = split_sizes.sum(axis=0)\n",
    "# Per-dataset \"n_channels n_dims n_classes\" triples.\n",
    "shape_info = np.array([s.split() for s in [\n",
    "    \"9 144 25\", \"2 640 3\", \"6 100 4\", \"6 1197 12\", \"1345 270 5\", \"6 17984 5\", \"3 206 4\",\n",
    "    \"3 1751 4\", \"4 65 6\", \"144 62 2\", \"28 50 2\", \"10 400 4\", \"3 152 26\", \"61 405 2\",\n",
    "    \"2 45 15\", \"6 36 14\", \"64 3000 2\", \"24 51 6\", \"2 8 10\", \"963 144 7\", \"11 217 39\",\n",
    "    \"6 30 4\", \"6 896 2\", \"7 1152 2\", \"4 2500 3\", \"3 315 8\",\n",
    "]]).astype(int)\n",
    "assert len(all_datasets) == len(n_samples) == len(shape_info)\n",
    "\n",
    "# Keyed by dataset name, so no hand-maintained positional index list is needed.\n",
    "meta_df = pd.DataFrame(\n",
    "    np.column_stack((n_samples, shape_info)),\n",
    "    columns=[\"n_samples\", \"n_channels\", \"n_dims\", \"n_classes\"],\n",
    "    index=all_datasets,\n",
    ")\n",
    "\n",
    "# Fail loudly instead of letting the inner merge silently drop datasets without metadata.\n",
    "unknown = set(df[\"dataset\"]) - set(meta_df.index)\n",
    "assert not unknown, f\"no metadata for datasets: {sorted(unknown)}\"\n",
    "tdf = (\n",
    "    df.groupby([\"dataset\", \"metric\"]).compute_time.mean().reset_index()\n",
    "    .merge(meta_df, left_on=\"dataset\", right_index=True, validate=\"many_to_one\")\n",
    ")\n",
    "# NOTE(review): normalised by n_samples**2 (not n_samples) and scaled by 1000 --\n",
    "# confirm this matches the intended \"per sample\" quantity and the units of compute_time.\n",
    "tdf[\"avg_time_per_sample\"] = tdf[\"compute_time\"] / tdf[\"n_samples\"] ** 2 * 1000\n",
    "print(\"average time\\n\", tdf.groupby(\"metric\").avg_time_per_sample.mean().reset_index())\n",
    "print(\"std time\\n\", tdf.groupby(\"metric\").avg_time_per_sample.std().reset_index())"
   ]
  },
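  {
   "cell_type": "markdown",
   "id": "144e17d6",
   "metadata": {},
   "source": [
    "## Score table\n",
    "\n",
    "Mean $\\pm$ std accuracy over folds for each dataset (rows) and metric (columns), printed as a LaTeX table. On each fold, SGOT is the best-scoring metric whose name contains `Chordal`. For each dataset, the best mean is in bold and the second best is underlined; $\\emptyset$ marks a missing score."
   ]
  },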
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "050c18b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column order of the LaTeX table.\n",
    "metric_order = [\"Hilbert-Schmidt\", \"Operator\", \"Martin\", \"SOT\", \"GOT\", \"SGOT\"]\n",
    "metrics = df[\"metric\"].unique()\n",
    "chordal_metrics = [m for m in metrics if \"Chordal\" in m]\n",
    "other_metrics = [m for m in metrics if m not in chordal_metrics]\n",
    "\n",
    "# SGOT = the best-accuracy \"Chordal\" run of every (dataset, fold); groupby/idxmax keeps\n",
    "# the original column dtypes and picks the first row on ties.\n",
    "chordal_df = df[df[\"metric\"].isin(chordal_metrics)]\n",
    "best_idx = chordal_df.groupby([\"dataset\", \"fold_id\"])[\"accuracy\"].idxmax()\n",
    "tdf1 = chordal_df.loc[best_idx].assign(metric=\"SGOT\").reset_index(drop=True)\n",
    "tdf2 = df[df[\"metric\"].isin(other_metrics)]\n",
    "cdf = pd.concat([tdf1, tdf2]).reset_index(drop=True)\n",
    "\n",
    "# Table generation: mean / std accuracy per dataset (rows) and metric (columns).\n",
    "# pivot_table returns a flat column Index here, so reindex needs no `level`;\n",
    "# metrics absent from cdf become all-NaN columns.\n",
    "mean_df = (\n",
    "    cdf.pivot_table(index=\"dataset\", columns=\"metric\", values=\"accuracy\", aggfunc=\"mean\")\n",
    "    .reindex(columns=metric_order)\n",
    "    .astype(float)\n",
    ")\n",
    "std_df = (\n",
    "    cdf.pivot_table(index=\"dataset\", columns=\"metric\", values=\"accuracy\", aggfunc=\"std\")\n",
    "    .reindex(columns=metric_order)\n",
    "    .astype(float)\n",
    ")\n",
    "\n",
    "rows = []\n",
    "for dataset in mean_df.index:\n",
    "    means = mean_df.loc[dataset].to_numpy()\n",
    "    stds = std_df.loc[dataset].to_numpy()\n",
    "    is_missing = np.isnan(means)\n",
    "    # Highest mean first; missing scores sort last, so they are only highlighted\n",
    "    # when fewer than two metrics have a score.\n",
    "    order = np.argsort(np.where(is_missing, -np.inf, means))[::-1]\n",
    "    best, second = order[0], order[1]\n",
    "    cells = []\n",
    "    for i in range(len(means)):\n",
    "        if is_missing[i]:\n",
    "            cell = \"$\\\\emptyset$\"\n",
    "        else:\n",
    "            # Fixed two decimals so every entry has the same format (0.80, not 0.8).\n",
    "            cell = f\"{means[i]:.2f} $\\\\pm$ {stds[i]:.2f}\"\n",
    "        if i == best:\n",
    "            cell = \"\\\\textbf{\" + cell + \"}\"\n",
    "        elif i == second:\n",
    "            cell = \"\\\\underline{\" + cell + \"}\"\n",
    "        cells.append(cell)\n",
    "    rows.append(cells)\n",
    "\n",
    "fdf = pd.DataFrame(rows, columns=mean_df.columns, index=mean_df.index)\n",
    "print(fdf.style.to_latex(hrules=True, clines=\"skip-last;data\", multirow_align=\"t\"))"
   ]
  },
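  {
   "cell_type": "markdown",
   "id": "06792db7",
   "metadata": {},
   "source": [
    "## Critical difference diagram\n",
    "\n",
    "This section prints the average rank of each metric over all (dataset, fold) pairs as a LaTeX table row. Rank 1 is the highest accuracy, and a missing score counts as accuracy 0. It then draws the critical difference diagram and saves it to `critical_difference.pdf`."
   ]
  },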
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac879c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy per (dataset, fold) row and metric column, in metric_order.\n",
    "# pivot_table returns a flat column Index here, so reindex needs no `level`.\n",
    "rdf = (\n",
    "    cdf.pivot_table(index=[\"dataset\", \"fold_id\"], columns=\"metric\", values=\"accuracy\", aggfunc=\"first\")\n",
    "    .astype(float)\n",
    "    .reindex(columns=metric_order)\n",
    "    # NOTE(review): a missing score counts as accuracy 0, i.e. that metric is ranked\n",
    "    # (tied-)last on that (dataset, fold) -- confirm this is the intended treatment.\n",
    "    .fillna(0)\n",
    ")\n",
    "\n",
    "# Rank 1 = highest accuracy within a row; tied metrics share the smallest rank.\n",
    "ranks = rdf.rank(ascending=False, axis=1, method=\"min\")\n",
    "avg_ranks = ranks.mean(axis=0).to_numpy()\n",
    "std_ranks = ranks.std(axis=0).to_numpy()\n",
    "best = np.argmin(avg_ranks)\n",
    "second = np.argsort(avg_ranks)[1]\n",
    "\n",
    "cells = []\n",
    "for i, (rank, std) in enumerate(zip(avg_ranks, std_ranks)):\n",
    "    cell = f\"{rank:.2f} $\\\\pm$ {std:.2f}\"\n",
    "    if i == best:\n",
    "        cell = \"\\\\textbf{\" + cell + \"}\"\n",
    "    elif i == second:\n",
    "        cell = \"\\\\underline{\" + cell + \"}\"\n",
    "    cells.append(cell)\n",
    "\n",
    "line = \"avg. rank (lower is better) & \" + \" & \".join(cells) + \" \\\\\\\\\"\n",
    "print(line)\n",
    "\n",
    "# NOTE(review): aeon derives the diagram's ranks from the raw scores passed here; with\n",
    "# ties they may differ from the 'min'-tie ranks printed above -- confirm.\n",
    "fig, _ = plot_critical_difference(rdf.to_numpy(), rdf.columns.to_list())\n",
    "fig.set_figheight(1.8)\n",
    "fig.savefig(\"critical_difference.pdf\", bbox_inches=\"tight\")"
   ]
  },
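  {
   "cell_type": "markdown",
   "id": "72ce515a",
   "metadata": {},
   "source": [
    "## Score comparison\n",
    "\n",
    "Each panel plots SGOT accuracy (y-axis) against one other metric (x-axis). Blue dots are individual (dataset, fold) scores, and red triangles are per-dataset averages. Points in the shaded area above the diagonal are cases where SGOT scores higher. The figure is saved to `scatter_comparison.pdf`."
   ]
  },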
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9dd928",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib style for the paper figures.\n",
    "# NOTE(review): if Helvetica is not installed, matplotlib warns and uses its default font.\n",
    "plt.rcParams[\"font.family\"] = \"Helvetica\"\n",
    "plt.rcParams[\"xtick.labelsize\"] = 10\n",
    "plt.rcParams[\"ytick.labelsize\"] = 10\n",
    "plt.rcParams[\"axes.labelsize\"] = 14\n",
    "plt.rcParams[\"legend.fontsize\"] = 14\n",
    "plt.rcParams[\"axes.titlesize\"] = 14\n",
    "\n",
    "# Accuracy per (dataset, fold) row and metric column (NaN where a score is missing).\n",
    "pv_cdf = pd.pivot_table(cdf, index=[\"dataset\", \"fold_id\"], columns=\"metric\", values=\"accuracy\", aggfunc=\"first\").astype(float)\n",
    "\n",
    "compare_metrics = [\"Hilbert-Schmidt\", \"Operator\", \"SOT\", \"GOT\"]\n",
    "metric = \"SGOT\"  # plotted on the y-axis of every panel\n",
    "\n",
    "# squeeze=False + ravel keeps `axs` a 1-D array even for a single compared metric.\n",
    "fig, axs = plt.subplots(1, len(compare_metrics), figsize=(3 * len(compare_metrics) - 1, 3), sharey=True, squeeze=False)\n",
    "axs = axs.ravel()\n",
    "x_min, x_max = -0.05, 1.05\n",
    "y_min, y_max = -0.05, 1.05\n",
    "diag_x = [x_min, x_max]\n",
    "diag_y = [y_min, y_max]\n",
    "for ax, m in zip(axs, compare_metrics):\n",
    "    # Shade the area above the diagonal, where `metric` scores higher than `m`.\n",
    "    ax.fill_between(diag_x, diag_y, [y_max, y_max], color=\"green\", alpha=0.1, zorder=0)\n",
    "    ax.plot(diag_x, diag_y, ls=\"--\", color=\"k\", alpha=0.5)\n",
    "    ax.scatter(pv_cdf[m], pv_cdf[metric], alpha=0.5, color=\"tab:blue\", label=\"fold/dataset score\")\n",
    "    ax.set_aspect(\"equal\", adjustable=\"box\")\n",
    "    ax.set_xlabel(m)\n",
    "    ax.text(-0.01, 0.9, r\"$\\bf{SGOT}$\" + \"\\n\" + r\"$\\bf{better}$\", color=\"tab:green\", alpha=1, fontsize=9)\n",
    "    # Per-dataset averages from the score table, drawn on top of the per-fold points.\n",
    "    ax.scatter(mean_df[m], mean_df[metric], alpha=1, color=\"tab:red\", marker=\"^\", s=50, label=\"avg. per dataset\")\n",
    "\n",
    "for ax in axs:\n",
    "    # Reuse the y tick positions on the x-axis, then fix both limits.\n",
    "    ax.set_xticks(ax.get_yticks())\n",
    "    ax.set_xlim(x_min, x_max)\n",
    "    ax.set_ylim(y_min, y_max)\n",
    "\n",
    "axs[-1].legend(loc=\"lower right\", fontsize=9)\n",
    "axs[0].set_ylabel(metric)\n",
    "fig.tight_layout()\n",
    "\n",
    "fig.savefig(\"scatter_comparison.pdf\", bbox_inches=\"tight\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "KooPOT",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
