{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e738b0a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                  Method  Final Accuracy\n",
      "0  shannon_hypersphere_k1_stl10_resnet18        0.819750\n",
      "1                  VICReg_stl10_resnet18        0.816875\n",
      "2              Supervised_stl10_resnet18        0.718625\n",
      "3                  SimCLR_stl10_resnet18        0.828875\n",
      "4             BarlowTwins_stl10_resnet18        0.821875\n",
      "5                    BYOL_stl10_resnet18        0.812750\n",
      "6                    DINO_stl10_resnet18        0.797500\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load the per-epoch linear top-1 validation curves for STL-10\n",
    "df = pd.read_csv(\"stl10_lin_top1.csv\")\n",
    "\n",
    "# CSV run key -> legend label, listed in plotting order\n",
    "curve_labels = {\n",
    "    \"shannon_hypersphere_k1_stl10_resnet18\": \"Shannon Hypersphere\",\n",
    "    \"VICReg_stl10_resnet18\": \"VICReg\",\n",
    "    \"Supervised_stl10_resnet18\": \"Supervised\",\n",
    "    \"SimCLR_stl10_resnet18\": \"SimCLR\",\n",
    "    \"BarlowTwins_stl10_resnet18\": \"BarlowTwins\",\n",
    "    \"BYOL_stl10_resnet18\": \"BYOL\",\n",
    "    \"DINO_stl10_resnet18\": \"DINO\",\n",
    "}\n",
    "methods = list(curve_labels)\n",
    "\n",
    "# Draw each run's curve and note its last-row accuracy in the same pass\n",
    "plt.figure(figsize=(10, 6))\n",
    "final_acc = []\n",
    "for method in methods:\n",
    "    col = f\"{method} - val_metrics/lin_top1\"\n",
    "    plt.plot(df[\"epoch\"], df[col], label=curve_labels[method])\n",
    "    final_acc.append(df[col].iloc[-1])\n",
    "\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"val_metrics/lin_top1\")\n",
    "plt.title(\"Training Process for val_metrics/lin_top1\")\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"training_process_val_metrics_lin_top1.pdf\")\n",
    "plt.close()\n",
    "\n",
    "# Summary table: one row per run with the accuracy found in the CSV's last row\n",
    "result_df = pd.DataFrame({\"Method\": methods, \"Final Accuracy\": final_acc})\n",
    "\n",
    "print(result_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c1396bf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "from typing import Dict, List, Tuple, Optional\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def generate_plots_and_table(\n",
    "    dataset: str,\n",
    "    dataset_label: str,\n",
    "    arch_key: str,\n",
    "    arch_label: str,\n",
    "    methods_map: Dict[str, str],\n",
    "    metrics: Optional[List[str]] = None,\n",
    "    out_dir: str = \"figs\",\n",
    "    csv_dir: str = \".\",\n",
    "    epoch_col: str = \"epoch\",\n",
    "    series_prefix: str = \"\",  # prepended verbatim to each series column name; leave \"\" when there is no prefix\n",
    "    series_template: str = \"{method} - val_metrics/{metric}\",  # customize if your CSV uses a different pattern\n",
    ") -> Tuple[pd.DataFrame, str]:\n",
    "    \"\"\"\n",
    "    Generate one PDF of training curves per metric, plus a summary table of final (last-row) values.\n",
    "\n",
    "    For every metric, ``<csv_dir>/<dataset>_<metric>.csv`` is read and each method of\n",
    "    ``methods_map`` whose column exists is plotted against ``epoch_col``. Missing CSVs,\n",
    "    empty CSVs and missing method columns are skipped; the affected table cells stay NaN\n",
    "    in the DataFrame and blank in the LaTeX output.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset : str\n",
    "        Dataset name used in CSV filenames, e.g. \"stl10\".\n",
    "    dataset_label : str\n",
    "        Pretty name for the dataset as shown in the LaTeX table (e.g., \"STL-10\").\n",
    "    arch_key : str\n",
    "        Architecture key used in the output PDF filenames; not used to find columns.\n",
    "    arch_label : str\n",
    "        Pretty name for the architecture as shown in plot titles and the LaTeX table (e.g., \"RN18\").\n",
    "    methods_map : Dict[str, str]\n",
    "        Mapping from CSV method key -> pretty display label in plots/table.\n",
    "        Example: {\"shannon_hypersphere_k1_stl10_resnet18\": \"Ours\", \"VICReg_stl10_resnet18\": \"VICReg\", ...}\n",
    "    metrics : Optional[List[str]]\n",
    "        Metrics to process. Defaults to [\"lin_top1\",\"lin_top5\",\"knn_top1\",\"knn_top5\"] if None.\n",
    "        Only these four have a column in the summary table; any other metric is plotted but not tabulated.\n",
    "    out_dir : str\n",
    "        Directory to save PDFs (created if it does not exist).\n",
    "    csv_dir : str\n",
    "        Directory where CSVs live.\n",
    "    epoch_col : str\n",
    "        Name of the epoch column in CSVs.\n",
    "    series_prefix : str\n",
    "        Text prepended to the column name built from ``series_template`` (usually empty).\n",
    "    series_template : str\n",
    "        Template to build the series column name. Default matches \"<method> - val_metrics/<metric>\".\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    final_table_df : pd.DataFrame\n",
    "        Columns: [\"Dataset\",\"Method\",\"Arch\",\"Linear Top 1\",\"Linear Top 5\",\"KNN Top 1\",\"KNN Top 5\"],\n",
    "        one row per entry of ``methods_map`` (same order), holding the raw last-row values.\n",
    "    latex_tabular : str\n",
    "        A LaTeX tabular environment string with grouped headers; values are multiplied by 100\n",
    "        and printed with two decimals.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If a CSV exists but has no ``epoch_col`` column.\n",
    "    \"\"\"\n",
    "    if metrics is None:\n",
    "        metrics = [\"lin_top1\", \"lin_top5\", \"knn_top1\", \"knn_top5\"]\n",
    "\n",
    "    # Metric key -> summary-table column that receives its final value.\n",
    "    # A metric without an entry here is still plotted, just not tabulated.\n",
    "    metric_columns = {\n",
    "        \"lin_top1\": \"Linear Top 1\",\n",
    "        \"lin_top5\": \"Linear Top 5\",\n",
    "        \"knn_top1\": \"KNN Top 1\",\n",
    "        \"knn_top5\": \"KNN Top 5\",\n",
    "    }\n",
    "    table_columns = [\"Dataset\", \"Method\", \"Arch\", *metric_columns.values()]\n",
    "\n",
    "    def _pct(value) -> str:\n",
    "        \"\"\"Render a fraction as a two-decimal percentage; NaN becomes an empty cell.\"\"\"\n",
    "        return \"\" if pd.isna(value) else f\"{value * 100:.2f}\"\n",
    "\n",
    "    Path(out_dir).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # Start every method with NaN in all metric columns so missing CSVs/columns leave blanks\n",
    "    per_method_finals = {\n",
    "        mkey: {\n",
    "            \"Dataset\": dataset_label,\n",
    "            \"Method\": pretty,\n",
    "            \"Arch\": arch_label,\n",
    "            **dict.fromkeys(metric_columns.values(), float(\"nan\")),\n",
    "        }\n",
    "        for mkey, pretty in methods_map.items()\n",
    "    }\n",
    "\n",
    "    # Plot one figure per metric\n",
    "    for metric in metrics:\n",
    "        csv_path = Path(csv_dir) / f\"{dataset}_{metric}.csv\"\n",
    "        if not csv_path.exists():\n",
    "            print(f\"[WARN] Missing CSV: {csv_path} — skipping metric '{metric}'.\")\n",
    "            continue\n",
    "\n",
    "        df = pd.read_csv(csv_path)\n",
    "\n",
    "        if epoch_col not in df.columns:\n",
    "            raise ValueError(f\"Epoch column '{epoch_col}' not found in {csv_path}\")\n",
    "        if df.empty:\n",
    "            # No rows means there is no last row to read a final value from\n",
    "            print(f\"[WARN] Empty CSV: {csv_path} — skipping metric '{metric}'.\")\n",
    "            continue\n",
    "\n",
    "        # Explicit figure/axes handle: nothing depends on pyplot's implicit current figure\n",
    "        fig, ax = plt.subplots(figsize=(8, 5))\n",
    "        try:\n",
    "            plotted_any = False\n",
    "            for mkey, pretty in methods_map.items():\n",
    "                col_name = series_prefix + series_template.format(method=mkey, metric=metric)\n",
    "                if col_name not in df.columns:\n",
    "                    # silently skip if that method isn't present in this CSV\n",
    "                    continue\n",
    "                ax.plot(df[epoch_col], df[col_name], label=pretty)\n",
    "                plotted_any = True\n",
    "\n",
    "                # Record the final value (last row) for metrics that have a table column\n",
    "                if metric in metric_columns:\n",
    "                    per_method_finals[mkey][metric_columns[metric]] = float(df[col_name].iloc[-1])\n",
    "\n",
    "            if not plotted_any:\n",
    "                print(f\"[WARN] No selected methods found for metric '{metric}' in {csv_path} — no plot saved.\")\n",
    "                continue\n",
    "\n",
    "            ax.set_xlabel(\"Epoch\")\n",
    "            ax.set_ylabel(metric.replace(\"_\", \"/\"))\n",
    "            ax.set_title(f\"{dataset.upper()} {arch_label} — {metric.replace('_', ' ').title()}\")\n",
    "            ax.legend()\n",
    "            fig.tight_layout()\n",
    "            out_pdf = Path(out_dir) / f\"{dataset}_{arch_key}_{metric}.pdf\"\n",
    "            fig.savefig(out_pdf)\n",
    "            print(f\"[OK] Saved {out_pdf}\")\n",
    "        finally:\n",
    "            # Always release the figure, including on the skip path and on errors\n",
    "            plt.close(fig)\n",
    "\n",
    "    # Rows follow the insertion order of methods_map\n",
    "    final_table_df = pd.DataFrame(list(per_method_finals.values()), columns=table_columns)\n",
    "\n",
    "    # Build LaTeX tabular with grouped headers\n",
    "    # (No booktabs needed)\n",
    "    header = (\n",
    "        \"\\\\begin{tabular}{lllcccc}\\n\"\n",
    "        \"    & & &\\n\"\n",
    "        \"    \\\\multicolumn{2}{c}{\\\\bfseries Linear} & \\n\"\n",
    "        \"    \\\\multicolumn{2}{c}{\\\\bfseries K-NN} \\\\\\\\\\n\"\n",
    "        \"    \\\\multicolumn{1}{c}{\\\\bfseries Dataset} & \\n\"\n",
    "        \"    \\\\multicolumn{1}{c}{\\\\bfseries Method} & \\n\"\n",
    "        \"    \\\\multicolumn{1}{c}{\\\\bfseries Arch} &\\n\"\n",
    "        \"    \\\\multicolumn{1}{c}{\\\\bfseries Top 1} & \\\\multicolumn{1}{c}{\\\\bfseries Top 5} &\\n\"\n",
    "        \"    \\\\multicolumn{1}{c}{\\\\bfseries Top 1} & \\\\multicolumn{1}{c}{\\\\bfseries Top 5} \\\\\\\\\\n\"\n",
    "        \"    \\\\hline\\n\"\n",
    "    )\n",
    "    # Cells go through _pct() instead of nested same-quote f-strings,\n",
    "    # which only parse on Python 3.12 and later.\n",
    "    body_lines = [\n",
    "        f\"    {rec['Dataset']} & {rec['Method']} & {rec['Arch']} & \"\n",
    "        + \" & \".join(_pct(rec[column]) for column in metric_columns.values())\n",
    "        + \" \\\\\\\\\"\n",
    "        for rec in per_method_finals.values()\n",
    "    ]\n",
    "    footer = \"\\n\\\\end{tabular}\\n\"\n",
    "    latex_tabular = header + \"\\n\".join(body_lines) + footer\n",
    "\n",
    "    return final_table_df, latex_tabular\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "587b8024",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[OK] Saved figs/stl10_resnet18_lin_top1.pdf\n",
      "[OK] Saved figs/stl10_resnet18_lin_top5.pdf\n",
      "[OK] Saved figs/stl10_resnet18_knn_top1.pdf\n",
      "[OK] Saved figs/stl10_resnet18_knn_top5.pdf\n",
      "  Dataset        Method  Arch  Linear Top 1  Linear Top 5  KNN Top 1  \\\n",
      "0  STL-10    Supervised  RN18      0.718625      0.974125   0.724250   \n",
      "1  STL-10          Ours  RN18      0.819750      0.992500   0.768500   \n",
      "2  STL-10        VICReg  RN18      0.816875      0.991375   0.784750   \n",
      "3  STL-10        SimCLR  RN18      0.828875      0.993625   0.791625   \n",
      "4  STL-10  Barlow Twins  RN18      0.821875      0.990375   0.782000   \n",
      "5  STL-10          BYOL  RN18      0.812750      0.992750   0.771375   \n",
      "6  STL-10          DINO  RN18      0.797500      0.991125   0.768500   \n",
      "\n",
      "   KNN Top 5  \n",
      "0   0.876875  \n",
      "1   0.938250  \n",
      "2   0.938500  \n",
      "3   0.949125  \n",
      "4   0.939750  \n",
      "5   0.943875  \n",
      "6   0.935250  \n",
      "\\begin{tabular}{lllcccc}\n",
      "    & & &\n",
      "    \\multicolumn{2}{c}{\\bfseries Linear} & \n",
      "    \\multicolumn{2}{c}{\\bfseries K-NN} \\\\\n",
      "    \\multicolumn{1}{c}{\\bfseries Dataset} & \n",
      "    \\multicolumn{1}{c}{\\bfseries Method} & \n",
      "    \\multicolumn{1}{c}{\\bfseries Arch} &\n",
      "    \\multicolumn{1}{c}{\\bfseries Top 1} & \\multicolumn{1}{c}{\\bfseries Top 5} &\n",
      "    \\multicolumn{1}{c}{\\bfseries Top 1} & \\multicolumn{1}{c}{\\bfseries Top 5} \\\\\n",
      "    \\hline\n",
      "    STL-10 & Supervised & RN18 & 71.86 & 97.41 & 72.43 & 87.69 \\\\\n",
      "    STL-10 & Ours & RN18 & 81.98 & 99.25 & 76.85 & 93.83 \\\\\n",
      "    STL-10 & VICReg & RN18 & 81.69 & 99.14 & 78.48 & 93.85 \\\\\n",
      "    STL-10 & SimCLR & RN18 & 82.89 & 99.36 & 79.16 & 94.91 \\\\\n",
      "    STL-10 & Barlow Twins & RN18 & 82.19 & 99.04 & 78.20 & 93.98 \\\\\n",
      "    STL-10 & BYOL & RN18 & 81.28 & 99.28 & 77.14 & 94.39 \\\\\n",
      "    STL-10 & DINO & RN18 & 79.75 & 99.11 & 76.85 & 93.53 \\\\\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# CSV run keys are \"<run prefix>_stl10_resnet18\"; pairs are listed in table row order\n",
    "methods_map = {\n",
    "    f\"{run_prefix}_stl10_resnet18\": label\n",
    "    for run_prefix, label in [\n",
    "        (\"Supervised\", \"Supervised\"),\n",
    "        (\"shannon_hypersphere_k1\", \"Ours\"),\n",
    "        (\"VICReg\", \"VICReg\"),\n",
    "        (\"SimCLR\", \"SimCLR\"),\n",
    "        (\"BarlowTwins\", \"Barlow Twins\"),\n",
    "        (\"BYOL\", \"BYOL\"),\n",
    "        (\"DINO\", \"DINO\"),\n",
    "    ]\n",
    "}\n",
    "\n",
    "# STL-10 / ResNet-18: per-metric PDFs go to figs/, CSVs are read from the working directory\n",
    "df_table, latex = generate_plots_and_table(\n",
    "    dataset=\"stl10\",\n",
    "    dataset_label=\"STL-10\",\n",
    "    arch_key=\"resnet18\",  # filename component only\n",
    "    arch_label=\"RN18\",  # shown in the LaTeX table\n",
    "    methods_map=methods_map,\n",
    "    metrics=[\"lin_top1\", \"lin_top5\", \"knn_top1\", \"knn_top5\"],\n",
    "    out_dir=\"figs\",\n",
    "    csv_dir=\".\",\n",
    "    series_template=\"{method} - val_metrics/{metric}\",  # change if the CSV columns are named differently\n",
    ")\n",
    "\n",
    "print(df_table)   # summary of final values\n",
    "print(latex)      # ready to paste into the LaTeX document\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b2d5a926",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[OK] Saved cifar10-figs/cifar10_resnet18_lin_top1.pdf\n",
      "[OK] Saved cifar10-figs/cifar10_resnet18_lin_top5.pdf\n",
      "[OK] Saved cifar10-figs/cifar10_resnet18_knn_top1.pdf\n",
      "[OK] Saved cifar10-figs/cifar10_resnet18_knn_top5.pdf\n",
      "    Dataset        Method  Arch  Linear Top 1  Linear Top 5  KNN Top 1  \\\n",
      "0  CIFAR-10    Supervised  RN18        0.8417        0.9897     0.8416   \n",
      "1  CIFAR-10          Ours  RN18        0.7254        0.9797     0.6811   \n",
      "2  CIFAR-10        VICReg  RN18        0.7472        0.9814     0.7267   \n",
      "3  CIFAR-10        SimCLR  RN18        0.7258        0.9786     0.6916   \n",
      "4  CIFAR-10  Barlow Twins  RN18        0.7463        0.9791     0.7181   \n",
      "5  CIFAR-10          BYOL  RN18        0.6937        0.9781     0.6554   \n",
      "6  CIFAR-10          DINO  RN18        0.7246        0.9802     0.7007   \n",
      "\n",
      "   KNN Top 5  \n",
      "0     0.9336  \n",
      "1     0.8914  \n",
      "2     0.9059  \n",
      "3     0.9025  \n",
      "4     0.9087  \n",
      "5     0.8974  \n",
      "6     0.9034  \n",
      "\\begin{tabular}{lllcccc}\n",
      "    & & &\n",
      "    \\multicolumn{2}{c}{\\bfseries Linear} & \n",
      "    \\multicolumn{2}{c}{\\bfseries K-NN} \\\\\n",
      "    \\multicolumn{1}{c}{\\bfseries Dataset} & \n",
      "    \\multicolumn{1}{c}{\\bfseries Method} & \n",
      "    \\multicolumn{1}{c}{\\bfseries Arch} &\n",
      "    \\multicolumn{1}{c}{\\bfseries Top 1} & \\multicolumn{1}{c}{\\bfseries Top 5} &\n",
      "    \\multicolumn{1}{c}{\\bfseries Top 1} & \\multicolumn{1}{c}{\\bfseries Top 5} \\\\\n",
      "    \\hline\n",
      "    CIFAR-10 & Supervised & RN18 & 84.17 & 98.97 & 84.16 & 93.36 \\\\\n",
      "    CIFAR-10 & Ours & RN18 & 72.54 & 97.97 & 68.11 & 89.14 \\\\\n",
      "    CIFAR-10 & VICReg & RN18 & 74.72 & 98.14 & 72.67 & 90.59 \\\\\n",
      "    CIFAR-10 & SimCLR & RN18 & 72.58 & 97.86 & 69.16 & 90.25 \\\\\n",
      "    CIFAR-10 & Barlow Twins & RN18 & 74.63 & 97.91 & 71.81 & 90.87 \\\\\n",
      "    CIFAR-10 & BYOL & RN18 & 69.37 & 97.81 & 65.54 & 89.74 \\\\\n",
      "    CIFAR-10 & DINO & RN18 & 72.46 & 98.02 & 70.07 & 90.34 \\\\\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# CSV run keys are \"<run prefix>_cifar10_resnet18\"; pairs are listed in table row order\n",
    "methods_map = {\n",
    "    f\"{run_prefix}_cifar10_resnet18\": label\n",
    "    for run_prefix, label in [\n",
    "        (\"Supervised\", \"Supervised\"),\n",
    "        (\"shannon_hypersphere_k1\", \"Ours\"),\n",
    "        (\"VICReg\", \"VICReg\"),\n",
    "        (\"SimCLR\", \"SimCLR\"),\n",
    "        (\"BarlowTwins\", \"Barlow Twins\"),\n",
    "        (\"BYOL\", \"BYOL\"),\n",
    "        (\"DINO\", \"DINO\"),\n",
    "    ]\n",
    "}\n",
    "\n",
    "# CIFAR-10 / ResNet-18: per-metric PDFs go to cifar10-figs/, CSVs are read from the working directory\n",
    "df_table, latex = generate_plots_and_table(\n",
    "    dataset=\"cifar10\",\n",
    "    dataset_label=\"CIFAR-10\",\n",
    "    arch_key=\"resnet18\",  # filename component only\n",
    "    arch_label=\"RN18\",  # shown in the LaTeX table\n",
    "    methods_map=methods_map,\n",
    "    metrics=[\"lin_top1\", \"lin_top5\", \"knn_top1\", \"knn_top5\"],\n",
    "    out_dir=\"cifar10-figs\",\n",
    "    csv_dir=\".\",\n",
    "    series_template=\"{method} - val_metrics/{metric}\",  # change if the CSV columns are named differently\n",
    ")\n",
    "\n",
    "print(df_table)   # summary of final values\n",
    "print(latex)      # ready to paste into the LaTeX document\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3cb407fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[OK] Saved lc25000-figs/lc25000_resnet18_lin_top1.pdf\n",
      "[OK] Saved lc25000-figs/lc25000_resnet18_lin_top5.pdf\n",
      "[OK] Saved lc25000-figs/lc25000_resnet18_knn_top1.pdf\n",
      "[OK] Saved lc25000-figs/lc25000_resnet18_knn_top5.pdf\n",
      "    Dataset        Method  Arch  Linear Top 1  Linear Top 5  KNN Top 1  \\\n",
      "0  LC-25000    Supervised  RN18       0.99832           1.0    1.00000   \n",
      "1  LC-25000          Ours  RN18       0.99752           1.0    0.99992   \n",
      "2  LC-25000        VICReg  RN18       0.99300           1.0    1.00000   \n",
      "3  LC-25000        SimCLR  RN18       0.98348           1.0    0.99992   \n",
      "4  LC-25000  Barlow Twins  RN18       0.98092           1.0    0.99976   \n",
      "5  LC-25000          BYOL  RN18       0.95940           1.0    0.99536   \n",
      "6  LC-25000          DINO  RN18       0.98460           1.0    0.99992   \n",
      "\n",
      "   KNN Top 5  \n",
      "0        1.0  \n",
      "1        1.0  \n",
      "2        1.0  \n",
      "3        1.0  \n",
      "4        1.0  \n",
      "5        1.0  \n",
      "6        1.0  \n",
      "\\begin{tabular}{lllcccc}\n",
      "    & & &\n",
      "    \\multicolumn{2}{c}{\\bfseries Linear} & \n",
      "    \\multicolumn{2}{c}{\\bfseries K-NN} \\\\\n",
      "    \\multicolumn{1}{c}{\\bfseries Dataset} & \n",
      "    \\multicolumn{1}{c}{\\bfseries Method} & \n",
      "    \\multicolumn{1}{c}{\\bfseries Arch} &\n",
      "    \\multicolumn{1}{c}{\\bfseries Top 1} & \\multicolumn{1}{c}{\\bfseries Top 5} &\n",
      "    \\multicolumn{1}{c}{\\bfseries Top 1} & \\multicolumn{1}{c}{\\bfseries Top 5} \\\\\n",
      "    \\hline\n",
      "    LC-25000 & Supervised & RN18 & 99.83 & 100.00 & 100.00 & 100.00 \\\\\n",
      "    LC-25000 & Ours & RN18 & 99.75 & 100.00 & 99.99 & 100.00 \\\\\n",
      "    LC-25000 & VICReg & RN18 & 99.30 & 100.00 & 100.00 & 100.00 \\\\\n",
      "    LC-25000 & SimCLR & RN18 & 98.35 & 100.00 & 99.99 & 100.00 \\\\\n",
      "    LC-25000 & Barlow Twins & RN18 & 98.09 & 100.00 & 99.98 & 100.00 \\\\\n",
      "    LC-25000 & BYOL & RN18 & 95.94 & 100.00 & 99.54 & 100.00 \\\\\n",
      "    LC-25000 & DINO & RN18 & 98.46 & 100.00 & 99.99 & 100.00 \\\\\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# CSV run keys are \"<run prefix>_lc25000_resnet18\"; pairs are listed in table row order\n",
    "methods_map = {\n",
    "    f\"{run_prefix}_lc25000_resnet18\": label\n",
    "    for run_prefix, label in [\n",
    "        (\"Supervised\", \"Supervised\"),\n",
    "        (\"shannon_hypersphere_k1\", \"Ours\"),\n",
    "        (\"VICReg\", \"VICReg\"),\n",
    "        (\"SimCLR\", \"SimCLR\"),\n",
    "        (\"BarlowTwins\", \"Barlow Twins\"),\n",
    "        (\"BYOL\", \"BYOL\"),\n",
    "        (\"DINO\", \"DINO\"),\n",
    "    ]\n",
    "}\n",
    "\n",
    "# LC-25000 / ResNet-18: per-metric PDFs go to lc25000-figs/, CSVs are read from the working directory\n",
    "df_table, latex = generate_plots_and_table(\n",
    "    dataset=\"lc25000\",\n",
    "    dataset_label=\"LC-25000\",\n",
    "    arch_key=\"resnet18\",  # filename component only\n",
    "    arch_label=\"RN18\",  # shown in the LaTeX table\n",
    "    methods_map=methods_map,\n",
    "    metrics=[\"lin_top1\", \"lin_top5\", \"knn_top1\", \"knn_top5\"],\n",
    "    out_dir=\"lc25000-figs\",\n",
    "    csv_dir=\".\",\n",
    "    series_template=\"{method} - val_metrics/{metric}\",  # change if the CSV columns are named differently\n",
    ")\n",
    "\n",
    "print(df_table)   # summary of final values\n",
    "print(latex)      # ready to paste into the LaTeX document\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "20886231",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "from typing import Dict, List, Optional\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# One 9 pt size everywhere: default text, axes titles, axis labels and both tick-label sets\n",
    "plt.rcParams.update(dict.fromkeys(\n",
    "    [\"font.size\", \"axes.titlesize\", \"axes.labelsize\", \"xtick.labelsize\", \"ytick.labelsize\"],\n",
    "    9,\n",
    "))\n",
    "\n",
    "\n",
    "def plot_datasets_grid(\n",
    "    datasets: List[str],\n",
    "    metric: str,                                # \"lin_top1\", \"lin_top5\", \"knn_top1\", \"knn_top5\"\n",
    "    methods_map_by_dataset: Dict[str, Dict[str, str]],\n",
    "    csv_dir: str = \".\",\n",
    "    epoch_col: str = \"epoch\",\n",
    "    series_template: str = \"{method} - val_metrics/{metric}\",\n",
    "    out_pdf: Optional[str] = None,              # e.g. \"figs/lin_top1_3col.pdf\"\n",
    "    figsize=(5.5, 3),                         # ~one-column width; tweak height as needed\n",
    "    legend_ncol: int = 4,\n",
    "    epochs_by_dataset: Optional[Dict[str, int]] = None,  # per-dataset xmax (e.g., {\"lc25000\": 50})\n",
    "    common_arch_label: Optional[str] = None,    # if you want \"(RN18)\" in the shared title\n",
    "):\n",
    "    \"\"\"\n",
    "    Draw one subplot per dataset in a single row, with a shared title and a shared legend below.\n",
    "\n",
    "    Each subplot reads ``<csv_dir>/<dataset>_<metric>.csv`` and plots, against ``epoch_col``,\n",
    "    every method of ``methods_map_by_dataset[dataset]`` whose column exists. Values are\n",
    "    multiplied by 100 and the y axis is labelled \"Accuracy (%)\". A dataset whose CSV is\n",
    "    missing gets a \"Missing CSV\" placeholder panel. The figure is saved to ``out_pdf``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    datasets : List[str]\n",
    "        Dataset names used in CSV filenames; one subplot each, left to right. Any positive count works.\n",
    "    metric : str\n",
    "        Metric key used in the CSV filename and column names.\n",
    "    methods_map_by_dataset : Dict[str, Dict[str, str]]\n",
    "        Per dataset, a mapping from CSV method key -> legend label. Datasets without an entry plot nothing.\n",
    "    csv_dir : str\n",
    "        Directory where CSVs live.\n",
    "    epoch_col : str\n",
    "        Name of the epoch column in CSVs.\n",
    "    series_template : str\n",
    "        Template to build the series column name.\n",
    "    out_pdf : Optional[str]\n",
    "        Output path; defaults to \"<metric>_3col.pdf\". Parent directories are created.\n",
    "    figsize : tuple\n",
    "        Figure size passed to matplotlib.\n",
    "    legend_ncol : int\n",
    "        Number of columns in the shared legend.\n",
    "    epochs_by_dataset : Optional[Dict[str, int]]\n",
    "        Upper x limit per dataset. Defaults to 50 for \"lc25000\" and 200 otherwise.\n",
    "    common_arch_label : Optional[str]\n",
    "        If given, appended to the shared title after an em dash.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``datasets`` is empty, or a CSV exists but has no ``epoch_col`` column.\n",
    "    \"\"\"\n",
    "    if not datasets:\n",
    "        raise ValueError(\"datasets must contain at least one dataset name\")\n",
    "\n",
    "    pretty_metric = {\n",
    "        \"lin_top1\":  \"Linear Top-1\",\n",
    "        \"lin_top5\":  \"Linear Top-5\",\n",
    "        \"knn_top1\":  \"K-NN Top-1\",\n",
    "        \"knn_top5\":  \"K-NN Top-5\",\n",
    "    }.get(metric, metric)\n",
    "\n",
    "    if out_pdf is None:\n",
    "        out_pdf = f\"{metric}_3col.pdf\"\n",
    "\n",
    "    # default x-max: 200, except lc25000 -> 50\n",
    "    if epochs_by_dataset is None:\n",
    "        epochs_by_dataset = {ds: (50 if ds.lower() == \"lc25000\" else 200) for ds in datasets}\n",
    "\n",
    "    # Create the figure once with the right number of panels. squeeze=False always\n",
    "    # returns a 2-D array of Axes, so a single dataset no longer yields a bare Axes.\n",
    "    fig, axes_grid = plt.subplots(\n",
    "        1, len(datasets), figsize=figsize, sharex=False, sharey=True, squeeze=False\n",
    "    )\n",
    "    axes = axes_grid[0]\n",
    "\n",
    "    try:\n",
    "        # Shared title\n",
    "        shared_title = f\"{pretty_metric} — {common_arch_label}\" if common_arch_label else pretty_metric\n",
    "        fig.suptitle(shared_title, y=1.0, fontsize=9)\n",
    "\n",
    "        legend_handles, legend_labels = [], []\n",
    "        max_y_seen = None\n",
    "\n",
    "        for ax, ds in zip(axes, datasets):\n",
    "            csv_path = Path(csv_dir) / f\"{ds}_{metric}.csv\"\n",
    "            ax.set_title(ds.upper(), fontsize=9)  # dataset only\n",
    "            ax.set_xlabel(\"Epoch\")\n",
    "            ax.grid(True, alpha=0.2, linewidth=0.5)\n",
    "            ax.set_xlim(0, epochs_by_dataset.get(ds, 200))\n",
    "\n",
    "            if not csv_path.exists():\n",
    "                ax.text(0.5, 0.5, \"Missing CSV\", ha=\"center\", va=\"center\", transform=ax.transAxes, fontsize=9)\n",
    "                continue\n",
    "\n",
    "            df = pd.read_csv(csv_path)\n",
    "            if epoch_col not in df.columns:\n",
    "                raise ValueError(f\"Epoch column '{epoch_col}' not found in {csv_path}\")\n",
    "\n",
    "            selected = methods_map_by_dataset.get(ds, {})\n",
    "            for method_key, pretty_label in selected.items():\n",
    "                col = series_template.format(method=method_key, metric=metric)\n",
    "                if col not in df.columns:\n",
    "                    continue\n",
    "                y = df[col] * 100.0\n",
    "                h, = ax.plot(df[epoch_col], y, label=pretty_label, linewidth=1)\n",
    "\n",
    "                # The shared legend shows each label once, using the first line drawn for it\n",
    "                if pretty_label not in legend_labels:\n",
    "                    legend_labels.append(pretty_label)\n",
    "                    legend_handles.append(h)\n",
    "\n",
    "                # An empty or all-NaN series has a NaN maximum; ignore it so the\n",
    "                # y-limit computation below never sees NaN\n",
    "                y_max_here = float(y.max())\n",
    "                if not pd.isna(y_max_here):\n",
    "                    max_y_seen = y_max_here if max_y_seen is None else max(max_y_seen, y_max_here)\n",
    "\n",
    "            # Re-apply after plotting: drawing lines can autoscale the x axis\n",
    "            ax.set_xlim(0, epochs_by_dataset.get(ds, 200))\n",
    "\n",
    "        # Shared Y label on left-most plot\n",
    "        axes[0].set_ylabel(\"Accuracy (%)\")\n",
    "\n",
    "        # Harmonize Y limits across subplots (sharey propagates from axes[0]):\n",
    "        # next multiple of 5 above the largest value, at least 60, capped at 100\n",
    "        if max_y_seen is not None:\n",
    "            ylim_top = min(100.0, max(60.0, (int(max_y_seen / 5) + 1) * 5))\n",
    "            axes[0].set_ylim(0, ylim_top)\n",
    "\n",
    "        # Layout: add space for the shared legend at the bottom\n",
    "        fig.subplots_adjust(bottom=0.30, wspace=0.25)\n",
    "\n",
    "        # Shared legend\n",
    "        if legend_handles:\n",
    "            fig.legend(\n",
    "                handles=legend_handles, labels=legend_labels,\n",
    "                loc=\"lower center\", ncol=legend_ncol, frameon=False\n",
    "            )\n",
    "\n",
    "        fig.tight_layout(rect=[0, 0.12, 1, 0.98])  # leave room for legend + suptitle\n",
    "        Path(os.path.dirname(out_pdf) or \".\").mkdir(parents=True, exist_ok=True)\n",
    "        fig.savefig(out_pdf, bbox_inches=\"tight\")\n",
    "        print(f\"[OK] Saved {out_pdf}\")\n",
    "    finally:\n",
    "        # Release the figure even when reading or plotting raised\n",
    "        plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "c2363221",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[OK] Saved figs/lin_top1_3col.pdf\n"
     ]
    }
   ],
   "source": [
    "datasets = [\"cifar10\", \"stl10\", \"lc25000\"]\n",
    "metric = \"lin_top1\"\n",
    "\n",
    "methods_map_by_dataset = {\n",
    "    \"cifar10\": {\n",
    "        \"Supervised_cifar10_resnet18\": \"Supervised\",\n",
    "        \"shannon_hypersphere_k1_cifar10_resnet18\": \"Ours\",\n",
    "        \"VICReg_cifar10_resnet18\": \"VICReg\",\n",
    "        \"SimCLR_cifar10_resnet18\": \"SimCLR\",\n",
    "        \"BarlowTwins_cifar10_resnet18\": \"Barlow Twins\",\n",
    "        \"BYOL_cifar10_resnet18\": \"BYOL\",\n",
    "        \"DINO_cifar10_resnet18\": \"DINO\",\n",
    "    },\n",
    "    \"stl10\": {\n",
    "        \"Supervised_stl10_resnet18\": \"Supervised\",\n",
    "        \"shannon_hypersphere_k1_stl10_resnet18\": \"Ours\",\n",
    "        \"VICReg_stl10_resnet18\": \"VICReg\",\n",
    "        \"SimCLR_stl10_resnet18\": \"SimCLR\",\n",
    "        \"BarlowTwins_stl10_resnet18\": \"Barlow Twins\",\n",
    "        \"BYOL_stl10_resnet18\": \"BYOL\",\n",
    "        \"DINO_stl10_resnet18\": \"DINO\",\n",
    "    },\n",
    "    \"lc25000\": {\n",
    "        \"Supervised_lc25000_resnet18\": \"Supervised\",\n",
    "        \"shannon_hypersphere_k1_lc25000_resnet18\": \"Ours\",\n",
    "        \"VICReg_lc25000_resnet18\": \"VICReg\",\n",
    "        \"SimCLR_lc25000_resnet18\": \"SimCLR\",\n",
    "        \"BarlowTwins_lc25000_resnet18\": \"Barlow Twins\",\n",
    "        \"BYOL_lc25000_resnet18\": \"BYOL\",\n",
    "        \"DINO_lc25000_resnet18\": \"DINO\",\n",
    "    },\n",
    "}\n",
    "\n",
    "plot_datasets_grid(\n",
    "    datasets=datasets,\n",
    "    metric=metric,\n",
    "    methods_map_by_dataset=methods_map_by_dataset,\n",
    "    csv_dir=\".\",\n",
    "    out_pdf=f\"figs/{metric}_3col.pdf\",\n",
    "    legend_ncol=4,\n",
    "    epochs_by_dataset={\"cifar10\": 200, \"stl10\": 200, \"lc25000\": 50},\n",
    "    common_arch_label=None,  # or \"RN18\" if you want it in the shared title\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "ef550d0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "from typing import Dict, List, Optional\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_datasets_vertical(\n",
    "    datasets: List[str],                         # e.g., [\"cifar10\",\"stl10\",\"lc25000\"]\n",
    "    metric: str,                                 # \"lin_top1\", \"lin_top5\", \"knn_top1\", \"knn_top5\"\n",
    "    methods_map_by_dataset: Dict[str, Dict[str, str]],\n",
    "    csv_dir: str = \".\",\n",
    "    epoch_col: str = \"epoch\",\n",
    "    series_template: str = \"{method} - val_metrics/{metric}\",\n",
    "    out_pdf: Optional[str] = None,               # e.g., \"figs/lin_top1_vertical.pdf\"\n",
    "    figsize=(5.5, 7.6),\n",
    "    legend_ncol: int = 4,\n",
    "    epochs_by_dataset: Optional[Dict[str, int]] = None,  # {\"cifar10\":200,\"stl10\":200,\"lc25000\":50}\n",
    "    ylim_overrides: Optional[Dict[str, tuple]] = None,   # e.g., {\"lc25000\": (80,100)}\n",
    "):\n",
    "    \"\"\"Plot one accuracy-vs-epoch panel per dataset, stacked vertically, and save the figure.\n",
    "\n",
    "    For each dataset ``ds`` the curves are read from ``<csv_dir>/<ds>_<metric>.csv``.\n",
    "    Every requested run is looked up under the column name produced by\n",
    "    ``series_template``; its values are multiplied by 100 before plotting.\n",
    "    A single legend, de-duplicated by display label, is placed below the panels.\n",
    "    There is no suptitle; a panel whose CSV is absent shows \"Missing CSV\".\n",
    "\n",
    "    Args:\n",
    "        datasets: Dataset names, one subplot row each (top to bottom).\n",
    "        metric: Metric key used in the CSV file name and in column names.\n",
    "        methods_map_by_dataset: Per dataset, maps run name -> legend label.\n",
    "        csv_dir: Directory containing the CSV files.\n",
    "        epoch_col: Name of the x-axis column in each CSV.\n",
    "        series_template: Format string with ``{method}`` and ``{metric}`` fields\n",
    "            that yields the column name of a run.\n",
    "        out_pdf: Output path; defaults to ``\"<metric>_vertical.pdf\"``.\n",
    "        figsize: Figure size passed to ``plt.subplots``.\n",
    "        legend_ncol: Number of columns in the shared legend.\n",
    "        epochs_by_dataset: Upper x-limit per dataset. Defaults to 50 for\n",
    "            \"lc25000\" and 200 otherwise; datasets missing from a supplied\n",
    "            mapping fall back to 200.\n",
    "        ylim_overrides: Optional ``(ymin, ymax)`` per dataset. Defaults to\n",
    "            (80, 100) for \"lc25000\" only.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``datasets`` is empty or a CSV lacks ``epoch_col``.\n",
    "    \"\"\"\n",
    "    if not datasets:\n",
    "        raise ValueError(\"datasets must contain at least one dataset name\")\n",
    "\n",
    "    if out_pdf is None:\n",
    "        out_pdf = f\"{metric}_vertical.pdf\"\n",
    "\n",
    "    if epochs_by_dataset is None:\n",
    "        epochs_by_dataset = {ds: (50 if ds.lower() == \"lc25000\" else 200) for ds in datasets}\n",
    "\n",
    "    if ylim_overrides is None:\n",
    "        ylim_overrides = {ds: (80, 100) for ds in datasets if ds.lower() == \"lc25000\"}\n",
    "\n",
    "    # squeeze=False always returns a 2-D array of axes, so one dataset needs no special case.\n",
    "    fig, axes_grid = plt.subplots(\n",
    "        len(datasets), 1, figsize=figsize, sharex=False, sharey=False, squeeze=False\n",
    "    )\n",
    "\n",
    "    try:\n",
    "        legend_handles, legend_labels = [], []\n",
    "\n",
    "        for ax, ds in zip(axes_grid[:, 0], datasets):\n",
    "            csv_path = Path(csv_dir) / f\"{ds}_{metric}.csv\"\n",
    "\n",
    "            # Titles & labels with padding to avoid collisions\n",
    "            ax.set_title(ds.upper(), fontsize=10, pad=4)           # small upward pad\n",
    "            ax.set_xlabel(\"Epoch\", labelpad=1)                     # bring label closer\n",
    "            ax.set_ylabel(\"Accuracy (%)\", labelpad=2)\n",
    "            ax.grid(True, alpha=0.2, linewidth=0.5)\n",
    "            ax.set_xlim(0, epochs_by_dataset.get(ds, 200))\n",
    "            # Explicit limits switch off y autoscaling, so setting them before\n",
    "            # plotting covers both the missing-CSV and the plotted case.\n",
    "            if ds in ylim_overrides:\n",
    "                ax.set_ylim(*ylim_overrides[ds])\n",
    "\n",
    "            if not csv_path.exists():\n",
    "                ax.text(0.5, 0.5, \"Missing CSV\", ha=\"center\", va=\"center\",\n",
    "                        transform=ax.transAxes, fontsize=9)\n",
    "                continue\n",
    "\n",
    "            df = pd.read_csv(csv_path)\n",
    "            if epoch_col not in df.columns:\n",
    "                raise ValueError(f\"Epoch column '{epoch_col}' not found in {csv_path}\")\n",
    "\n",
    "            for method_key, pretty_label in methods_map_by_dataset.get(ds, {}).items():\n",
    "                col = series_template.format(method=method_key, metric=metric)\n",
    "                if col not in df.columns:\n",
    "                    # Report the gap (e.g. a run-name typo) instead of dropping the curve silently.\n",
    "                    print(f\"[WARN] Column '{col}' not found in {csv_path}; skipping\")\n",
    "                    continue\n",
    "                # Scale by 100 to match the \"Accuracy (%)\" axis label.\n",
    "                line, = ax.plot(df[epoch_col], df[col] * 100.0, label=pretty_label, linewidth=1)\n",
    "\n",
    "                # Keep one legend entry per label even though it recurs in every panel.\n",
    "                if pretty_label not in legend_labels:\n",
    "                    legend_labels.append(pretty_label)\n",
    "                    legend_handles.append(line)\n",
    "\n",
    "        # Tighter spacing between subplots; smaller bottom margin\n",
    "        fig.subplots_adjust(hspace=0.32, bottom=0.12)\n",
    "\n",
    "        # Shared legend closer to the last plot\n",
    "        if legend_handles:\n",
    "            fig.legend(\n",
    "                legend_handles, legend_labels,\n",
    "                loc=\"lower center\", ncol=legend_ncol, frameon=False,\n",
    "                bbox_to_anchor=(0.5, 0.01)  # pull legend up closer to plots\n",
    "            )\n",
    "\n",
    "        out_path = Path(out_pdf)\n",
    "        out_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "        # Save this figure explicitly rather than whatever pyplot considers current.\n",
    "        fig.savefig(out_path, bbox_inches=\"tight\")\n",
    "    finally:\n",
    "        # Release the figure on errors too, so it does not linger in the kernel.\n",
    "        plt.close(fig)\n",
    "\n",
    "    print(f\"[OK] Saved {out_pdf}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "1713b68e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[OK] Saved figs/lin_top1_vertical.pdf\n"
     ]
    }
   ],
   "source": [
    "datasets = [\"cifar10\", \"stl10\", \"lc25000\"]\n",
    "metric = \"lin_top1\"\n",
    "\n",
    "# Run names follow \"<prefix>_<dataset>_resnet18\" and every dataset uses the same\n",
    "# seven prefixes, so the per-dataset mapping is generated instead of spelled out.\n",
    "methods_map_by_dataset = {\n",
    "    ds: {\n",
    "        f\"{run_prefix}_{ds}_resnet18\": legend_label\n",
    "        for run_prefix, legend_label in (\n",
    "            (\"Supervised\", \"Supervised\"),\n",
    "            (\"shannon_hypersphere_k1\", \"Ours\"),\n",
    "            (\"VICReg\", \"VICReg\"),\n",
    "            (\"SimCLR\", \"SimCLR\"),\n",
    "            (\"BarlowTwins\", \"Barlow Twins\"),\n",
    "            (\"BYOL\", \"BYOL\"),\n",
    "            (\"DINO\", \"DINO\"),\n",
    "        )\n",
    "    }\n",
    "    for ds in datasets\n",
    "}\n",
    "\n",
    "# Write the stacked per-dataset figure for the chosen metric.\n",
    "plot_datasets_vertical(\n",
    "    datasets=datasets,\n",
    "    metric=metric,\n",
    "    methods_map_by_dataset=methods_map_by_dataset,\n",
    "    csv_dir=\".\",\n",
    "    out_pdf=f\"figs/{metric}_vertical.pdf\",\n",
    "    legend_ncol=4,\n",
    "    epochs_by_dataset={\"cifar10\": 200, \"stl10\": 200, \"lc25000\": 50},\n",
    "    ylim_overrides={\"lc25000\": (80, 100)},  # LC25000 panel shows only the 80-100% range\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv (3.12.3)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
