{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "171796cf-b1f5-40af-a84b-59741680c9ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import utils_exp as ue\n",
    "import importlib\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import glob\n",
    "import utils_exp as ue\n",
    "from matplotlib.lines import Line2D\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf3bfbb8-adfc-4b77-bc51-218df0edd7dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Line colour per method; the plotting cell reads this only for the EEM curve.\n",
    "color_dict = {\"SGD\":\"Black\", \"EEM\":\"Green\"}\n",
    "\n",
    "# Line colour per (method, learning rate). Every learning rate listed in `lrs`\n",
    "# for a non-EEM method needs an entry here, otherwise plotting raises KeyError.\n",
    "color_dict_lr = {\n",
    "    \"SGD\":{0.1:\"lightgrey\", 1:\"silver\", 10:\"grey\", 100:\"dimgrey\", 1000:\"black\"},\n",
    "    \"Adam\":{1:\"firebrick\", 0.1:\"brown\", 0.01:\"indianred\", 0.001:\"lightcoral\"},\n",
    "    \"RMSprop\":{0.1:\"violet\", 0.01:\"purple\", 0.001:\"orchid\", 0.0001:\"magenta\"},\n",
    "    \"Adagrad\":{0.1:\"navy\", 0.01:\"mediumblue\", 0.001:\"deepskyblue\", 0.0001:\"skyblue\"},\n",
    "}\n",
    "\n",
    "# Line style per method.\n",
    "style_dict = {\"SGD\":\"--\", \"EEM\":\"-\", \"Adam\":\":\", \"RMSprop\":\"-.\", \"Adagrad\":\"-.\"}\n",
    "\n",
    "# Number of result files (repetitions) plotted per configuration.\n",
    "rep_times = 3\n",
    "\n",
    "αs = [0.3, 0.5, 0.7, 0.9]\n",
    "\n",
    "# Learning rates per method. Result files are looked up by the *position* of a\n",
    "# learning rate in its list (and of α in `αs`), not by its value, so reordering\n",
    "# or inserting entries changes which files are loaded.\n",
    "lrs = {\n",
    "       \"SGD\":[1000, 100, 10, 1, 0.1],\n",
    "       \"Adam\":[1, 0.1, 0.01, 0.001],\n",
    "       \"RMSprop\":[0.1, 0.01, 0.001, 0.0001],\n",
    "       \"Adagrad\":[0.1, 0.01, 0.001, 0.0001],\n",
    "       \"EEM\":[None],\n",
    "      }\n",
    "\n",
    "photo_name = \"house\"\n",
    "\n",
    "# Methods to plot, in drawing order (this is also the order legend entries are created).\n",
    "methods = [\"Adagrad\", \"RMSprop\", \"SGD\", \"Adam\", \"EEM\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8372c42f-52b1-4456-804c-3b8ce44949cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Font sizes applied to every figure drawn after this cell runs.\n",
    "_font_sizes = {\n",
    "    \"font.size\": 14,\n",
    "    \"axes.titlesize\": 16,\n",
    "    \"axes.labelsize\": 14,\n",
    "    \"xtick.labelsize\": 12,\n",
    "    \"ytick.labelsize\": 12,\n",
    "    \"legend.fontsize\": 12,\n",
    "}\n",
    "plt.rcParams.update(_font_sizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "341d2909-47c1-4206-b4d0-10b70476cdd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_pkl_path(photo_name, method, idα, idlr):\n",
    "    \"\"\"Return the sorted result-pickle paths of one (method, α index, lr index) setting.\n",
    "\n",
    "    Files are matched against\n",
    "    ../results/exp_loss/{photo_name}/{method}/a{idα}_lr{idlr}_*.pkl\n",
    "    relative to the current working directory.\n",
    "\n",
    "    The result is sorted because glob.glob returns matches in arbitrary order,\n",
    "    and callers pick files by position. An empty list is returned when nothing matches.\n",
    "    \"\"\"\n",
    "    pattern = f\"../results/exp_loss/{photo_name}/{method}/a{idα}_lr{idlr}_*.pkl\"\n",
    "    return sorted(glob.glob(pattern))\n",
    "\n",
    "def collect_pkl_path_eem(photo_name,idα):\n",
    "    \"\"\"Return the sorted EEM result-pickle paths for one α index.\n",
    "\n",
    "    Files are matched against ../results/exp_loss/{photo_name}/eem/a{idα}_*.pkl\n",
    "    relative to the current working directory.\n",
    "\n",
    "    The result is sorted because glob.glob returns matches in arbitrary order,\n",
    "    and callers pick files by position. An empty list is returned when nothing matches.\n",
    "    \"\"\"\n",
    "    pattern = f\"../results/exp_loss/{photo_name}/eem/a{idα}_*.pkl\"\n",
    "    return sorted(glob.glob(pattern))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "441c22bc-7c7c-4d79-850f-7899493d1e20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# [method, α, lr] triples to leave out of the plots.\n",
    "exclude_list = []\n",
    "\n",
    "def exclude(method, α, lr):\n",
    "    \"\"\"Return True when [method, α, lr] is listed in exclude_list, else False.\"\"\"\n",
    "    return [method, α, lr] in exclude_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ed8c1b8-e0ff-4a4c-b01a-e911fc222be9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Legend label of the EEM curve; shared by the plotting loop and the legend\n",
    "# reordering below so the two can never drift apart.\n",
    "eem_label = r\"E$^2$MCPTuckerTTB\"\n",
    "\n",
    "fig = plt.figure(figsize=(12,5))\n",
    "\n",
    "# Legend label -> one representative line. The same label recurs in every\n",
    "# subplot, so later subplots simply overwrite the handle for that label.\n",
    "legend_handles = {}\n",
    "\n",
    "for idα, α in enumerate(αs):\n",
    "    # One subplot per α value, all in a single row.\n",
    "    ax = plt.subplot(1, len(αs), idα + 1)\n",
    "    ax.set_title(f\"α: {α}\")\n",
    "    ax.set_xlabel(\"Iteration\")\n",
    "    ax.set_ylabel(\"α-div.\")\n",
    "    ax.set_ylim(0.0, 0.1)\n",
    "    ax.grid()\n",
    "\n",
    "    for method in methods:\n",
    "        is_eem = method == \"EEM\"\n",
    "        for idlr, lr in enumerate(lrs[method]):\n",
    "            if exclude(method, α, lr):\n",
    "                continue\n",
    "\n",
    "            # Result files are addressed by the indices of α and lr, not their values.\n",
    "            if is_eem:\n",
    "                load_paths = collect_pkl_path_eem(photo_name, idα)\n",
    "            else:\n",
    "                load_paths = collect_pkl_path(photo_name, method, idα, idlr)\n",
    "\n",
    "            # glob order is arbitrary; sort so the same files are plotted on every run.\n",
    "            load_paths = sorted(load_paths)\n",
    "            if len(load_paths) < rep_times:\n",
    "                raise FileNotFoundError(\n",
    "                    f\"expected {rep_times} result files for {method} \"\n",
    "                    f\"(α index {idα}, lr index {idlr}), found {len(load_paths)}\"\n",
    "                )\n",
    "\n",
    "            label = eem_label if is_eem else f\"{method} (lr={lr})\"\n",
    "            color = color_dict[method] if is_eem else color_dict_lr[method][lr]\n",
    "\n",
    "            for rep in range(rep_times):\n",
    "                # NOTE(review): ue.pickle_load presumably unpickles the file; only load trusted results.\n",
    "                res = ue.pickle_load(load_paths[rep])\n",
    "                loss = res[\"alpha\"] if is_eem else res[0][\"alpha_div\"]\n",
    "\n",
    "                if rep == 0:\n",
    "                    # Only the first repetition carries the label, so each\n",
    "                    # configuration contributes a single legend entry.\n",
    "                    line, = ax.plot(loss, c=color, ls=style_dict[method], lw=1.5, label=label)\n",
    "                    legend_handles[label] = line\n",
    "                else:\n",
    "                    ax.plot(loss, c=color, ls=style_dict[method], lw=1.5)\n",
    "\n",
    "# Move the EEM entry to the end of the legend (dicts keep insertion order).\n",
    "if eem_label in legend_handles:\n",
    "    legend_handles[eem_label] = legend_handles.pop(eem_label)\n",
    "\n",
    "# Split the mapping into parallel label / handle lists.\n",
    "labels = list(legend_handles.keys())\n",
    "handles = list(legend_handles.values())\n",
    "\n",
    "# Pad every label with spaces to a common length.\n",
    "maxlen = max((len(label) for label in labels), default=0)\n",
    "labels = [label.ljust(maxlen) for label in labels]\n",
    "\n",
    "# Insert breaks between groups (e.g. group 1: entries 0-4, group 2: 5-9, group 3: 10-).\n",
    "grouped_handles = []\n",
    "grouped_labels = []\n",
    "line_breaks = [5, 10]  # number of entries after which a break is inserted\n",
    "\n",
    "for i, (h, l) in enumerate(zip(handles, labels)):\n",
    "    grouped_handles.append(h)\n",
    "    grouped_labels.append(l)\n",
    "    if i + 1 in line_breaks:\n",
    "        # NOTE(review): matplotlib may warn about None handles and drop them instead\n",
    "        # of rendering a blank entry - confirm these placeholders have the intended effect.\n",
    "        grouped_handles.append(None)  # placeholder handle for a blank entry\n",
    "        grouped_labels.append(\"\")     # empty label for the placeholder\n",
    "\n",
    "\n",
    "# Draw one shared legend below the subplots.\n",
    "legend = fig.legend(\n",
    "    handles=grouped_handles,\n",
    "    labels=grouped_labels,\n",
    "    loc=\"lower center\",\n",
    "    bbox_to_anchor=(0.5, -0.2),\n",
    "    ncol=5,\n",
    "    frameon=False,\n",
    "    handletextpad=1.5,\n",
    "    columnspacing=1.2,\n",
    "    borderaxespad=0.5,\n",
    "    alignment=\"left\"  # matplotlib 3.7+ required\n",
    ")\n",
    "\n",
    "# Thicken the sample lines shown in the legend.\n",
    "for line in legend.legend_handles:\n",
    "    if line is not None:\n",
    "        line.set_linewidth(2)\n",
    "\n",
    "\n",
    "plt.tight_layout(rect=[0, 0.02, 1, 1])  # leave extra room at the bottom\n",
    "plt.savefig(\"figs/exp_loss/loss_sm.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
