{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9211b890",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import json\n",
    "import matplotlib.patches as patches\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcfcfe1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the JSON Lines results log (one JSON record per line) of the cifar10_all_layers run.\n",
    "file_path = (\n",
    "    '../../results/data_results/cifar10_all_layers/'\n",
    "    'cifar10_results.jsonl'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ab460c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epochs whose distance statistics are compared. The worst/mid/best names are labels for\n",
    "# these fixed epochs (first, middle, final); no score-based selection happens here.\n",
    "WORST_EPOCH, MID_EPOCH, BEST_EPOCH = 1, 45, 90\n",
    "\n",
    "# Index every JSONL record by its epoch; if an epoch is logged twice, the later record wins.\n",
    "records_by_epoch = {}\n",
    "with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "    for line in f:\n",
    "        if not line.strip():\n",
    "            continue  # tolerate blank lines (e.g. a trailing newline) instead of crashing json.loads\n",
    "        item = json.loads(line)\n",
    "        records_by_epoch[item['epoch']] = item\n",
    "\n",
    "missing_epochs = [e for e in (WORST_EPOCH, MID_EPOCH, BEST_EPOCH) if e not in records_by_epoch]\n",
    "if missing_epochs:\n",
    "    # Fail here with a clear message rather than later with an opaque KeyError in get_plot_dict.\n",
    "    raise ValueError(f\"epochs {missing_epochs} not found in {file_path}\")\n",
    "\n",
    "\n",
    "def _checkpoint_stats(epoch):\n",
    "    \"\"\"Return (mean_f1, mean_acc, epoch, distance) from the record logged for `epoch`.\"\"\"\n",
    "    record = records_by_epoch[epoch]\n",
    "    return record['mean_f1'], record['mean_acc'], record['epoch'], record['distance']\n",
    "\n",
    "\n",
    "worst_mean_f1, worst_acc, worst_epoch, worst_data = _checkpoint_stats(WORST_EPOCH)\n",
    "mid_mean_f1, mid_acc, mid_epoch, mid_data = _checkpoint_stats(MID_EPOCH)\n",
    "best_mean_f1, best_acc, best_epoch, best_data = _checkpoint_stats(BEST_EPOCH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3341dfad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch, mean F1 and mean accuracy of the three selected checkpoints (set by the loading cell).\n",
    "{\n",
    "    \"worst\": {\"epoch\": worst_epoch, \"mean_f1\": worst_mean_f1, \"mean_acc\": worst_acc},\n",
    "    \"mid\": {\"epoch\": mid_epoch, \"mean_f1\": mid_mean_f1, \"mean_acc\": mid_acc},\n",
    "    \"best\": {\"epoch\": best_epoch, \"mean_f1\": best_mean_f1, \"mean_acc\": best_acc},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a2a5cf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_plot_dict(data,\n",
    "                  layers=('input', 'layer1', 'layer2', 'layer3', 'layer4', 'proj', 'fc'),\n",
    "                  num_classes=10):\n",
    "    \"\"\"Average per-class distance statistics into one value per layer for plotting.\n",
    "\n",
    "    Args:\n",
    "        data: dict with 'intra' and 'inter' entries, each indexed as\n",
    "            entry[class_id][layer][field] with string class ids and fields 'mean' / 'std'.\n",
    "        layers: layer keys to aggregate, in plotting order\n",
    "            (default: input, layer1..layer4, proj, fc).\n",
    "        num_classes: class ids looked up are '0' .. str(num_classes - 1) (default 10).\n",
    "\n",
    "    Returns:\n",
    "        dict with 'intra_mean', 'inter_mean', 'intra_std', 'inter_std' (per-layer averages\n",
    "        over classes), 'gap_mean' (inter_mean - intra_mean per layer) and 'gap_std' (None).\n",
    "    \"\"\"\n",
    "    intra = data['intra']\n",
    "    inter = data['inter']\n",
    "    class_ids = [str(i) for i in range(num_classes)]\n",
    "\n",
    "    def as_number(v):\n",
    "        # Missing statistics -- NaN, or None from a JSON null -- are counted as 0.0.\n",
    "        # NOTE(review): zero-filling pulls the class average down; np.nanmean would skip them.\n",
    "        return 0.0 if v is None or np.isnan(v) else v\n",
    "\n",
    "    def compute_agg(dct, field):\n",
    "        # One entry per layer: the mean of `field` across all classes.\n",
    "        return [np.mean([as_number(dct[c][ly][field]) for c in class_ids]) for ly in layers]\n",
    "\n",
    "    intra_mean = compute_agg(intra, 'mean')\n",
    "    inter_mean = compute_agg(inter, 'mean')\n",
    "    intra_std  = compute_agg(intra, 'std')\n",
    "    inter_std  = compute_agg(inter, 'std')\n",
    "    gap_mean   = (np.array(inter_mean) - np.array(intra_mean)).tolist()\n",
    "\n",
    "    plot_dict = {\n",
    "        \"intra_mean\": intra_mean,\n",
    "        \"inter_mean\": inter_mean,\n",
    "        \"intra_std\": intra_std,\n",
    "        \"inter_std\": inter_std,\n",
    "        \"gap_mean\": gap_mean,\n",
    "        \"gap_std\": None\n",
    "    }\n",
    "    return plot_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9820bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate each checkpoint's per-class distance statistics into per-layer curves.\n",
    "best_plot_dict, worst_plot_dict, mid_plot_dict = map(\n",
    "    get_plot_dict, (best_data, worst_data, mid_data)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea06bc53",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_split(ax, xs, ys, color, marker, label, lw=2, alpha_tail=0.15, split_idx=5):\n",
    "    \"\"\"Draw only the first `split_idx` points of (xs, ys) on `ax` as one marked line.\n",
    "\n",
    "    Points at index `split_idx` and beyond are not drawn. `alpha_tail` is accepted\n",
    "    but currently unused.\n",
    "    \"\"\"\n",
    "    head = slice(None, split_idx)\n",
    "    ax.plot(\n",
    "        xs[head], ys[head],\n",
    "        label=label, color=color, marker=marker, linewidth=lw,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bf04162",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_state_lines_intra_inter(worst_plot_dict, mid_plot_dict, best_plot_dict,\n",
    "                                worst_epoch=None, mid_epoch=None, best_epoch=None,\n",
    "                                worst_f1=None, mid_f1=None, best_f1=None,\n",
    "                                save_path=None, title=\"\", metric=\"intra_mean\"):\n",
    "    \"\"\"Plot one per-layer curve of `metric` for each of three checkpoints and save it.\n",
    "\n",
    "    Args:\n",
    "        worst_plot_dict, mid_plot_dict, best_plot_dict: dicts as returned by get_plot_dict,\n",
    "            drawn in this order, each with its own fixed color and marker.\n",
    "        worst_epoch, mid_epoch, best_epoch: epoch numbers shown in the legend.\n",
    "        worst_f1, mid_f1, best_f1: mean F1 shown in the legend, rounded to 4 decimals;\n",
    "            left out of the label when None.\n",
    "        save_path: if not None, the figure is saved there (parent directories are created).\n",
    "        title: axes title.\n",
    "        metric: key of the plot dicts to draw (default \"intra_mean\"; e.g. \"inter_mean\").\n",
    "\n",
    "    The figure is always closed before returning, so it is saved rather than kept open.\n",
    "    \"\"\"\n",
    "    # Tick labels for the 7 layer positions; the last two (proj, fc) stay blank because\n",
    "    # they are covered by the \"Projection + FC\" box drawn below.\n",
    "    layers = [r\"$\\mathbf{C}_{s}$\", r\"$\\mathbf{C}_{1}$\", r\"$\\mathbf{C}_{2}$\", r\"$\\mathbf{C}_{3}$\", r\"$\\mathbf{C}_{4}$\", r\"\", r\"\"]\n",
    "    xs = list(range(len(layers)))\n",
    "\n",
    "    # Fetch every curve before creating the figure so a bad `metric` key leaks no figure.\n",
    "    series = [np.array(d[metric]) for d in (worst_plot_dict, mid_plot_dict, best_plot_dict)]\n",
    "    epochs = (worst_epoch, mid_epoch, best_epoch)\n",
    "    f1s = (worst_f1, mid_f1, best_f1)\n",
    "    styles = ((\"#BF1E2E\", 'o'), (\"#FF9E02\", 's'), (\"#136783\", '^'))\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(4, 3))\n",
    "    try:\n",
    "        for ys, epoch, f1, (color, marker) in zip(series, epochs, f1s, styles):\n",
    "            # np.round(None) raises, so omit the F1 part when no score was given.\n",
    "            label = f\"Epoch {epoch}\" if f1 is None else f\"Epoch {epoch} (F1={np.round(f1, 4)})\"\n",
    "            plot_split(ax, xs, ys, color=color, marker=marker, label=label)\n",
    "\n",
    "        # Grey box over the proj/fc positions (x 4.15-5.95, y 0.1-0.9 in data units).\n",
    "        # NOTE(review): the y extent is hard-coded and assumes distances roughly in [0.1, 0.9].\n",
    "        ax.add_patch(patches.Rectangle(\n",
    "            (4.15, 0.1), 1.8, 0.8,\n",
    "            linewidth=1, edgecolor='lightgray', facecolor='whitesmoke', alpha=1,\n",
    "        ))\n",
    "        # Raw string for the TeX part: \"\\m\" inside a normal literal is an invalid escape\n",
    "        # (SyntaxWarning on Python 3.12+); the rendered text is unchanged.\n",
    "        ax.text(\n",
    "            x=5.05, y=0.5,\n",
    "            s=\"Projection\\n+\\nFC\\n\" r\"($\\mathcal{F}_f$)\",\n",
    "            fontsize=12, rotation=0,\n",
    "            ha=\"center\", va=\"center\", multialignment=\"center\",\n",
    "        )\n",
    "\n",
    "        ax.set_xticks(xs)\n",
    "        ax.set_xticklabels(layers)\n",
    "        ax.grid(True, alpha=0.3)\n",
    "        ax.legend()\n",
    "        ax.set_title(title)\n",
    "\n",
    "        fig.tight_layout()\n",
    "        if save_path is not None:\n",
    "            save_dir = os.path.dirname(save_path)\n",
    "            if save_dir:  # os.makedirs('') raises when save_path is a bare filename\n",
    "                os.makedirs(save_dir, exist_ok=True)\n",
    "            fig.savefig(save_path, dpi=200, bbox_inches=\"tight\")\n",
    "    finally:\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab506be1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "049fef4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the per-layer intra-class mean-distance figure for the three checkpoints and save it.\n",
    "plot_state_lines_intra_inter(\n",
    "    worst_plot_dict,\n",
    "    mid_plot_dict,\n",
    "    best_plot_dict,\n",
    "    worst_epoch=worst_epoch,\n",
    "    mid_epoch=mid_epoch,\n",
    "    best_epoch=best_epoch,\n",
    "    worst_f1=worst_mean_f1,\n",
    "    mid_f1=mid_mean_f1,\n",
    "    best_f1=best_mean_f1,\n",
    "    save_path=\"../../results/figure_results/2_in_one.pdf\",\n",
    "    title=\"Intra-class mean distance\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
