{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "785988d8-c974-449e-848a-4b12bfb25362",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "import time\n",
    "from datetime import datetime\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "# Figure/axes that the plotting cell below draws the individual runs into.\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "# Root directory scanned for experiment results; resolve() makes it absolute so\n",
    "# that get_result_config can compute paths relative to it.\n",
    "result_root_dir: Path = Path(\"results\").resolve()\n",
    "\n",
    "\n",
    "def _parse_config_string(config_str: str) -> dict[str, str]:\n",
    "    \"\"\"Parse a directory name of the form ``key@value~key@value~...`` into a dict.\n",
    "\n",
    "    Values are kept as strings. Only the first ``@`` of each part separates the\n",
    "    key from the value, so values that themselves contain ``@`` are preserved.\n",
    "    \"\"\"\n",
    "    config: dict[str, str] = {}\n",
    "    for part in config_str.split(\"~\"):\n",
    "        key, value = part.split(\"@\", 1)\n",
    "        config[key] = value\n",
    "    return config\n",
    "\n",
    "\n",
    "def get_result_config(dir_path):\n",
    "    \"\"\"Decode the experiment configuration encoded in a result directory path.\n",
    "\n",
    "    The path relative to ``result_root_dir`` is expected to look like\n",
    "    ``{OBJECTIVE_NAME}/{OBJECTIVE_CONFIG}/NoConstraint/{ALGORITHM_NAME}/{ALGORITHM_CONFIG}/``\n",
    "    where each ``*_CONFIG`` component is a ``~``-separated list of ``key@value`` pairs.\n",
    "\n",
    "    Args:\n",
    "        dir_path: result directory (``str`` or ``Path``) located under ``result_root_dir``.\n",
    "\n",
    "    Returns:\n",
    "        ``(objective_name, objective_config_dict, algorithm_name,\n",
    "        algorithm_config_dict, seconds_to_now)``; the config dicts map str -> str and\n",
    "        ``seconds_to_now`` is the time in seconds since the directory was last modified.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``dir_path`` is not inside ``result_root_dir``, has fewer than\n",
    "            five path components below it, or a config part lacks an ``@``.\n",
    "    \"\"\"\n",
    "    dir_path = Path(str(dir_path)).resolve()\n",
    "    relative_path = dir_path.relative_to(result_root_dir)\n",
    "    parts: tuple[str, ...] = relative_path.parts\n",
    "    if len(parts) < 5:\n",
    "        raise ValueError(\n",
    "            f\"unexpected result directory layout (expected at least 5 components): {relative_path}\"\n",
    "        )\n",
    "    # parts[2] is the constraint component (\"NoConstraint\"), which is not returned.\n",
    "    objective_name, objective_config, _, algorithm_name, algorithm_config = parts[:5]\n",
    "\n",
    "    objective_config_dict = _parse_config_string(objective_config)\n",
    "    algorithm_config_dict = _parse_config_string(algorithm_config)\n",
    "\n",
    "    # Age of the result, measured from the directory's last-modification time.\n",
    "    seconds_to_now = time.time() - dir_path.stat().st_mtime\n",
    "\n",
    "    return (\n",
    "        objective_name,\n",
    "        objective_config_dict,\n",
    "        algorithm_name,\n",
    "        algorithm_config_dict,\n",
    "        seconds_to_now,\n",
    "    )\n",
    "\n",
    "def get_filtered_results(is_valid, required_files=(\"func_values.npy\", \"grad_norm.npy\", \"time.npy\")):\n",
    "    \"\"\"Collect result directories whose configuration passes ``is_valid``.\n",
    "\n",
    "    Every directory under ``result_root_dir`` that contains ``func_values.npy`` is a\n",
    "    candidate. It is kept only if all ``required_files`` exist in it (the plotting\n",
    "    code loads ``grad_norm.npy`` and ``time.npy`` unconditionally) and\n",
    "    ``is_valid(*get_result_config(result_dir))`` is truthy.\n",
    "\n",
    "    Args:\n",
    "        is_valid: predicate ``(obj_name, obj_config, alg_name, alg_config, seconds_to_now) -> bool``.\n",
    "        required_files: file names that must all be present in a result directory.\n",
    "\n",
    "    Returns:\n",
    "        List of matching result directories (``Path``), in ``glob`` order.\n",
    "    \"\"\"\n",
    "    results = []\n",
    "    for result_file in result_root_dir.glob(\"**/func_values.npy\"):\n",
    "        result_dir = result_file.parent\n",
    "        # Skip incomplete result directories instead of crashing later in np.load.\n",
    "        if not all((result_dir / name).is_file() for name in required_files):\n",
    "            continue\n",
    "        config = get_result_config(result_dir)\n",
    "        if is_valid(*config):\n",
    "            results.append(result_dir)\n",
    "    return results\n",
    "\n",
    "def get_keys_with_diff(dict_list: list[dict]) -> set[str]:\n",
    "    \"\"\"Return the keys whose values are not all equal across ``dict_list``.\n",
    "\n",
    "    A dict that lacks a key is treated as holding ``None`` for it, so a key present\n",
    "    in only some of the dicts counts as differing (unless its value is ``None``).\n",
    "    \"\"\"\n",
    "    union_of_keys = set().union(*(d.keys() for d in dict_list))\n",
    "    return {\n",
    "        key\n",
    "        for key in union_of_keys\n",
    "        if len({d.get(key) for d in dict_list}) > 1\n",
    "    }\n",
    "\n",
    "\n",
    "def take_average(fig, ax, groupings: dict[str, range | list[int]], colors: dict[str, np.ndarray], linestyles: dict[str, str]):\n",
    "    \"\"\"Plot the per-group mean of lines already on ``ax`` onto a new axes, with a +/- 1 std band.\n",
    "\n",
    "    For each group, the lines ``ax.lines[i]`` (``i`` in the group) are linearly\n",
    "    interpolated onto the union of their x values; the pointwise mean is drawn as a\n",
    "    line and the pointwise (population, ddof=0) standard deviation as a shaded band.\n",
    "\n",
    "    Args:\n",
    "        fig: figure owning ``ax``; unused, kept for call compatibility.\n",
    "        ax: axes whose ``lines`` hold the individual runs.\n",
    "        groupings: group label -> indices into ``ax.lines`` to average together.\n",
    "        colors: group label -> color of the mean line and its band.\n",
    "        linestyles: group label -> linestyle; labels not present are drawn with ``\"-\"``.\n",
    "\n",
    "    Returns:\n",
    "        ``(fig_ave, ax_ave)``: the newly created figure and axes.\n",
    "    \"\"\"\n",
    "    fig_ave, ax_ave = plt.subplots()\n",
    "\n",
    "    for method, group in groupings.items():\n",
    "        xs = [ax.lines[i].get_xdata() for i in group]\n",
    "        ys = [ax.lines[i].get_ydata() for i in group]\n",
    "\n",
    "        all_x = np.unique(np.concatenate(xs))\n",
    "        # One float row per run; outside a run's x range np.interp repeats its end values.\n",
    "        interpolated = np.stack([np.interp(all_x, x, y) for x, y in zip(xs, ys)])\n",
    "        mean = interpolated.mean(axis=0)\n",
    "        # np.std avoids the cancellation-prone sqrt(E[y^2] - E[y]^2), which can come\n",
    "        # out slightly negative from rounding and turn into NaN.\n",
    "        std = interpolated.std(axis=0)\n",
    "\n",
    "        # The linestyle is looked up by the group label.\n",
    "        ax_ave.plot(all_x, mean, label=method, color=colors[method], linestyle=linestyles.get(method, \"-\"))\n",
    "        ax_ave.fill_between(all_x, mean - std, mean + std, alpha=0.3, color=colors[method])\n",
    "    return fig_ave, ax_ave\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7da76c8-5b66-48f0-8458-f8909c466224",
   "metadata": {},
   "outputs": [],
   "source": [
    "###########################\n",
    "# Define your filter here #\n",
    "###########################\n",
    "\n",
    "# CIFAR-10 stochastic settings\n",
    "def is_valid(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"Keep GD / RSHTR runs on objectives with ``data_size == \"large\"`` and no ``random_seed``.\n",
    "\n",
    "    RSHTR runs additionally need ``delta == Delta == \"0.0001\"``. Note that neither\n",
    "    ``obj_name`` nor ``data_name`` is checked here.\n",
    "    \"\"\"\n",
    "    if \"random_seed\" in obj_config or obj_config.get(\"data_size\") != \"large\":\n",
    "        return False\n",
    "    if alg_name == \"GD\":\n",
    "        return True\n",
    "    if alg_name != \"RSHTR\":\n",
    "        return False\n",
    "    return alg_config[\"delta\"] == \"0.0001\" and alg_config[\"Delta\"] == \"0.0001\"\n",
    "\n",
    "# Rosenbrock\n",
    "def is_valid_rosenbrock(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"RosenbrockRankDeficient with dim 10000 and rank 50, excluding KCRN and HSODM.\n",
    "\n",
    "    Only ``reduced_dim`` 100 (or unset) is kept; RSHTR additionally needs\n",
    "    ``lanczos_precision == \"single\"`` and ``delta == \"0.001\"``.\n",
    "    \"\"\"\n",
    "    objective_ok = (\n",
    "        obj_name == \"RosenbrockRankDeficient\"\n",
    "        and obj_config[\"dim\"] == \"10000\"\n",
    "        and obj_config[\"rank\"] == \"50\"\n",
    "    )\n",
    "    if not objective_ok:\n",
    "        return False\n",
    "    if alg_name in (\"KCRN\", \"HSODM\"):\n",
    "        return False\n",
    "    if alg_config.get(\"reduced_dim\") not in (None, \"100\"):\n",
    "        return False\n",
    "    if alg_name != \"RSHTR\":\n",
    "        return True\n",
    "    return alg_config[\"lanczos_precision\"] == \"single\" and alg_config[\"delta\"] == \"0.001\"\n",
    "\n",
    "def is_valid_rosenbrock_multi_rank(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"RSHTR (delta 0.0001) on RosenbrockRankDeficient, dim 10000, rank in {25, 50, 100, 150}.\"\"\"\n",
    "    return (\n",
    "        obj_name == \"RosenbrockRankDeficient\"\n",
    "        and obj_config[\"dim\"] == \"10000\"\n",
    "        and obj_config[\"rank\"] in (\"25\", \"50\", \"100\", \"150\")\n",
    "        # optional extra restriction: obj_config[\"matrix_seed\"] == \"1\"\n",
    "        and alg_name == \"RSHTR\"\n",
    "        and alg_config[\"delta\"] == \"0.0001\"\n",
    "    )\n",
    "\n",
    "def _is_valid_dataset_eps(obj_name, obj_config, alg_name, alg_config, expected_obj_name, expected_data_name):\n",
    "    \"\"\"Shared rule: objective and dataset match; runs other than GD/HSODM need ``eps == \"1e-06\"``.\"\"\"\n",
    "    if obj_name != expected_obj_name or obj_config[\"data_name\"] != expected_data_name:\n",
    "        return False\n",
    "    # GD and HSODM are accepted regardless of eps.\n",
    "    return alg_name in (\"GD\", \"HSODM\") or alg_config[\"eps\"] == \"1e-06\"\n",
    "\n",
    "\n",
    "def is_valid_logistic_internet_ads(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"``Logistic`` objective with ``data_name == \"ad\"``.\"\"\"\n",
    "    return _is_valid_dataset_eps(obj_name, obj_config, alg_name, alg_config, \"Logistic\", \"ad\")\n",
    "\n",
    "\n",
    "def is_valid_logistic_news20_binary(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"``Logistic`` objective with ``data_name == \"news20\"``.\"\"\"\n",
    "    return _is_valid_dataset_eps(obj_name, obj_config, alg_name, alg_config, \"Logistic\", \"news20\")\n",
    "\n",
    "\n",
    "def is_valid_logistic_rcv1(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"``Logistic`` objective with ``data_name == \"rcv1\"``.\"\"\"\n",
    "    return _is_valid_dataset_eps(obj_name, obj_config, alg_name, alg_config, \"Logistic\", \"rcv1\")\n",
    "\n",
    "\n",
    "def is_valid_softmax_news20(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"``Softmax`` objective with ``data_name == \"news20\"``.\"\"\"\n",
    "    return _is_valid_dataset_eps(obj_name, obj_config, alg_name, alg_config, \"Softmax\", \"news20\")\n",
    "\n",
    "def is_valid_softmax_scotus(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"``Softmax`` on ``scotus``; runs whose ``random_matrix_seed`` is set to anything but \"0\" are dropped.\n",
    "\n",
    "    Unlike the other dataset filters, no ``eps`` restriction is applied here.\n",
    "    \"\"\"\n",
    "    is_target = obj_name == \"Softmax\" and obj_config[\"data_name\"] == \"scotus\"\n",
    "    return is_target and alg_config.get(\"random_matrix_seed\") in (None, \"0\")\n",
    "\n",
    "def is_valid_mf_movie_lens(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"``MatrixFactorization`` on ``movie`` with rank 50.\n",
    "\n",
    "    Keeps ``reduced_dim`` 100 (or unset); runs other than GD/HSODM need\n",
    "    ``eps == \"1e-06\"``; runs that set ``random_matrix_seed`` must use \"0\".\n",
    "    \"\"\"\n",
    "    if obj_name != \"MatrixFactorization\":\n",
    "        return False\n",
    "    if obj_config[\"data_name\"] != \"movie\":\n",
    "        return False\n",
    "    if obj_config[\"rank\"] != \"50\":\n",
    "        return False\n",
    "    if alg_config.get(\"reduced_dim\") not in (None, \"100\"):\n",
    "        return False\n",
    "    if alg_name not in (\"GD\", \"HSODM\"):\n",
    "        if alg_config[\"eps\"] != \"1e-06\":\n",
    "            return False\n",
    "    # .get (as for reduced_dim above and in is_valid_softmax_scotus): algorithms that\n",
    "    # presumably have no random_matrix_seed key (e.g. GD/HSODM) would otherwise\n",
    "    # raise KeyError here instead of being accepted.\n",
    "    if alg_config.get(\"random_matrix_seed\") not in (None, \"0\"):\n",
    "        return False\n",
    "    return True\n",
    "\n",
    "def is_valid_mf_movie_lens_masked(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"``MatrixFactorizationMasked`` on ``movie`` with rank 50.\n",
    "\n",
    "    Keeps ``reduced_dim`` 100 (or unset); runs other than GD/HSODM need ``eps == \"1e-06\"``.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        obj_name == \"MatrixFactorizationMasked\"\n",
    "        and obj_config[\"data_name\"] == \"movie\"\n",
    "        and obj_config[\"rank\"] == \"50\"\n",
    "        and alg_config.get(\"reduced_dim\") in (None, \"100\")\n",
    "        and (alg_name in (\"GD\", \"HSODM\") or alg_config[\"eps\"] == \"1e-06\")\n",
    "    )\n",
    "\n",
    "\n",
    "def is_valid_mf_multi_s(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"RSGD / RSRN / RSHTR on ``MatrixFactorization`` (``movie``, rank 50) with reduced_dim in {50, 100, 200}.\"\"\"\n",
    "    objective_matches = (\n",
    "        obj_name == \"MatrixFactorization\"\n",
    "        and obj_config[\"data_name\"] == \"movie\"\n",
    "        and obj_config[\"rank\"] == \"50\"\n",
    "    )\n",
    "    if not objective_matches:\n",
    "        return False\n",
    "    return alg_name in (\"RSGD\", \"RSRN\", \"RSHTR\") and alg_config[\"reduced_dim\"] in (\"50\", \"100\", \"200\")\n",
    "\n",
    "def is_valid_cifar10_stochastic(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"MLPNET on cifar10 (data_size \"large\", batch_size 1024), GD and RSHTR only.\n",
    "\n",
    "    RSHTR runs must also have a ``random_seed`` objective entry and use\n",
    "    ``delta == Delta == \"0.0001\"``.\n",
    "    \"\"\"\n",
    "    if obj_name != \"MLPNET\":\n",
    "        return False\n",
    "    # alternative batch_size set: (\"1\", \"32\", \"1024\")\n",
    "    required = {\"data_name\": \"cifar10\", \"data_size\": \"large\", \"batch_size\": \"1024\"}\n",
    "    if any(obj_config[key] != value for key, value in required.items()):\n",
    "        return False\n",
    "    if alg_name == \"GD\":\n",
    "        return True\n",
    "    return (\n",
    "        alg_name == \"RSHTR\"\n",
    "        and obj_config.get(\"random_seed\") is not None\n",
    "        and alg_config[\"delta\"] == \"0.0001\"\n",
    "        and alg_config[\"Delta\"] == \"0.0001\"\n",
    "    )\n",
    "\n",
    "def is_valid_mnist_stochastic(obj_name, obj_config, alg_name, alg_config, seconds_to_now):\n",
    "    \"\"\"MLPNET on mnist (data_size \"full\", batch_size 1024), GD and RSHTR only.\n",
    "\n",
    "    RSHTR runs must additionally use ``delta == Delta == \"0.001\"``.\n",
    "    \"\"\"\n",
    "    if obj_name != \"MLPNET\":\n",
    "        return False\n",
    "    # alternative batch_size set: (\"1\", \"32\", \"1024\")\n",
    "    required = {\"data_name\": \"mnist\", \"data_size\": \"full\", \"batch_size\": \"1024\"}\n",
    "    if any(obj_config[key] != value for key, value in required.items()):\n",
    "        return False\n",
    "    if alg_name == \"GD\":\n",
    "        return True\n",
    "    return alg_name == \"RSHTR\" and alg_config[\"delta\"] == \"0.001\" and alg_config[\"Delta\"] == \"0.001\"\n",
    "\n",
    "\n",
    "\n",
    "#################\n",
    "# plot settings #\n",
    "#################\n",
    "results = get_filtered_results(is_valid_rosenbrock_multi_rank)\n",
    "# None: plot func_values.npy as \"function value\"; True: plot func_values.npy as the\n",
    "# mini-batch loss; False: plot full_batch_loss.npy as the full-batch loss.\n",
    "plot_mini_batch_loss = None\n",
    "vs_time = False  # x axis: elapsed time from time.npy (True) or iteration index (False)\n",
    "y_log = True  # log-scale y axis\n",
    "\n",
    "\n",
    "# Collect the configs per objective / algorithm name and find the keys whose values\n",
    "# differ within a group; only those keys are shown in the plot labels.\n",
    "obj_name_to_obj_configs = defaultdict(list)\n",
    "alg_name_to_alg_configs = defaultdict(list)\n",
    "for result_dir in results:\n",
    "    run_obj_name, run_obj_config, run_alg_name, run_alg_config, _ = get_result_config(result_dir)\n",
    "    obj_name_to_obj_configs[run_obj_name].append(run_obj_config)\n",
    "    alg_name_to_alg_configs[run_alg_name].append(run_alg_config)\n",
    "\n",
    "obj_config_keys_with_diff = set().union(*map(get_keys_with_diff, obj_name_to_obj_configs.values()))\n",
    "alg_config_keys_with_diff = set().union(*map(get_keys_with_diff, alg_name_to_alg_configs.values()))\n",
    "print(\"obj_config keys with diff:\", obj_config_keys_with_diff)\n",
    "print(\"alg_config keys with diff:\", alg_config_keys_with_diff)\n",
    "\n",
    "# Label each run with the config keys that differ between runs.\n",
    "def gen_label(result_dir, obj_keys, alg_keys):\n",
    "    \"\"\"Return ``\"ALG\"`` or ``\"ALG (k=v, ...)\"`` describing a result directory.\n",
    "\n",
    "    Keys from ``obj_keys`` are looked up in the objective config, then keys from\n",
    "    ``alg_keys`` in the algorithm config; a missing key is shown as ``-``.\n",
    "    \"\"\"\n",
    "    _, obj_config, alg_name, alg_config, _ = get_result_config(result_dir)\n",
    "    pairs = [(k, obj_config.get(k, '-')) for k in obj_keys]\n",
    "    pairs += [(k, alg_config.get(k, '-')) for k in alg_keys]\n",
    "    config_str = \", \".join(f\"{k}={v}\" for k, v in pairs)\n",
    "    return f\"{alg_name} ({config_str})\" if config_str else f\"{alg_name}\"\n",
    "\n",
    "\n",
    "# =========================\n",
    "# final_results_dir = Path(\"final_results\").resolve()\n",
    "# final_results_dir.mkdir(exist_ok=True)\n",
    "# for result_dir in sorted_results:\n",
    "#     result_dir_rel = result_dir.relative_to(result_root_dir)\n",
    "#     final_result_dir = final_results_dir / result_dir_rel\n",
    "#     final_result_dir.parent.mkdir(parents=True, exist_ok=True)\n",
    "#     for src in result_dir.glob(\"*\"):\n",
    "#         dst = final_result_dir / src.relative_to(result_dir)\n",
    "#         if src.is_dir():\n",
    "#             dst.mkdir(exist_ok=True)\n",
    "#         else:\n",
    "#             dst.parent.mkdir(parents=True, exist_ok=True)\n",
    "#             dst.write_bytes(src.read_bytes())\n",
    "# ===============================\n",
    "\n",
    "\n",
    "# Visualization\n",
    "# Create a fresh figure on every run of this cell: the averaging cell below rebinds\n",
    "# `fig, ax` to the averaged plot, so re-running this cell with the global `ax` would\n",
    "# draw the individual runs onto the averaged figure.\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "# Runs are truncated right after they first drop below this function value\n",
    "# (objectives with rank=150 are always shown in full).\n",
    "convergence_threshold = 1e-4\n",
    "\n",
    "sorted_results = sorted(results, key=lambda result_dir: gen_label(result_dir, obj_config_keys_with_diff, alg_config_keys_with_diff))\n",
    "for result_dir in sorted_results:\n",
    "    obj_name, obj_config, alg_name, alg_config, _ = get_result_config(result_dir)\n",
    "\n",
    "    if plot_mini_batch_loss is None or plot_mini_batch_loss:\n",
    "        func_values = np.load(result_dir / \"func_values.npy\")\n",
    "    else:\n",
    "        func_values = np.load(result_dir / \"full_batch_loss.npy\")\n",
    "    grad_norm = np.load(result_dir / \"grad_norm.npy\")\n",
    "    times = np.load(result_dir / \"time.npy\")\n",
    "\n",
    "    # Drop non-positive values (they cannot be shown on a log-scale y axis).\n",
    "    valid_indices = np.where(func_values > 0)[0]\n",
    "    # Truncate after the first value below the threshold, but only if there is one:\n",
    "    # np.argmax of an all-False mask is 0, which would cut a run that never reaches\n",
    "    # the threshold down to its first point. .get() keeps objectives without a\n",
    "    # \"rank\" entry from raising KeyError.\n",
    "    below_threshold = func_values[valid_indices] < convergence_threshold\n",
    "    if obj_config.get(\"rank\") != \"150\" and below_threshold.any():\n",
    "        valid_indices = valid_indices[:np.argmax(below_threshold) + 1]\n",
    "\n",
    "    func_values = func_values[valid_indices]\n",
    "    grad_norm = grad_norm[valid_indices]\n",
    "    times = times[valid_indices]\n",
    "\n",
    "    label = gen_label(result_dir, obj_config_keys_with_diff, alg_config_keys_with_diff)\n",
    "    linestyle = {\"RSHTR\": \"-\", \"GD\": \"-\"}.get(alg_name, \"-\")\n",
    "\n",
    "    if vs_time:\n",
    "        ax.plot(times, func_values, label=label, linestyle=linestyle)\n",
    "    else:\n",
    "        ax.plot(func_values, label=label, linestyle=linestyle)\n",
    "\n",
    "if y_log:\n",
    "    ax.set_yscale(\"log\")\n",
    "\n",
    "# The title uses the objective name of the last plotted run.\n",
    "if plot_mini_batch_loss is None:\n",
    "    ax.set_title(f\"{obj_name}\")\n",
    "    ax.set_ylabel(\"function value\")\n",
    "elif plot_mini_batch_loss:\n",
    "    ax.set_title(f\"{obj_name} (mini-batch loss)\")\n",
    "    ax.set_ylabel(\"mini batch loss\")\n",
    "else:\n",
    "    ax.set_title(f\"{obj_name} (full-batch loss)\")\n",
    "    ax.set_ylabel(\"full batch loss\")\n",
    "\n",
    "ax.set_xlabel(\"time (s)\" if vs_time else \"iterations\")\n",
    "\n",
    "ax.legend()\n",
    "# Detach the figure from pyplot so that, with the inline backend, it is rendered\n",
    "# only once (by the bare `fig` below) rather than also at the end of the cell.\n",
    "plt.close(fig)\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d326ba2f-567f-4b6e-ab57-20f3c4e00bf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Groups of line indices (into `ax.lines`, i.e. the label-sorted plotting order of\n",
    "# the previous cell) that are averaged together. Alternative groupings:\n",
    "#   \"GD\": range(1), \"HSODM\": range(1, 2), \"RSGD\": range(2, 7),\n",
    "#   \"RSHTR\": range(7, 12), \"RSRN\": range(12, 17),\n",
    "#   \"RSHTR (s=50)\": range(3, 15, 3), \"RSHTR (s=100)\": range(0, 15, 3),\n",
    "#   \"RSHTR (s=200)\": range(1, 15, 3)\n",
    "groupings = {\n",
    "    f\"rank={rank}\": range(first, first + 4)\n",
    "    for rank, first in ((\"100\", 0), (\"150\", 4), (\"25\", 8), (\"50\", 12))\n",
    "}\n",
    "\n",
    "# Color per group label (every rank group starts in the RSHTR red).\n",
    "default_colors = {\n",
    "    \"GD\": np.array([0, 0.75, 0.75]),\n",
    "    \"HSODM\": np.array([1, 0.5, 0]),\n",
    "    \"RSGD\": np.array([0, 0.5, 0]),\n",
    "    \"RSRN\": np.array([0.25, 0.25, 0.95]),\n",
    "    \"RSHTR\": np.array([0.95, 0.25, 0.25]),\n",
    "    # \"RSHTR (s=50)\" / \"RSHTR (s=100)\" / \"RSHTR (s=200)\": np.array([0.3, 0.3, 1])\n",
    "    **{\n",
    "        name: np.array([0.95, 0.25, 0.25])\n",
    "        for name in (\"rank=100\", \"rank=150\", \"rank=25\", \"rank=50\")\n",
    "    },\n",
    "}\n",
    "\n",
    "# Optional linestyle per group label, passed to take_average. Example entries:\n",
    "#   \"GD\": \":\", \"RSHTR (s=50)\": \"-\", \"RSHTR (s=100)\": \"--\", \"RSHTR (s=200)\": \"-.\"\n",
    "default_linestyles = {}\n",
    "\n",
    "# From here on `fig, ax` refer to the averaged plot returned by take_average.\n",
    "fig, ax = take_average(fig, ax, groupings, default_colors, default_linestyles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b376309-ae82-4ce5-aa47-97271766537e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paper-style formatting of the averaged plot created by the previous cell.\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "\n",
    "ax.set_xlim(0, 1200)\n",
    "ax.set_ylim(1e-4, 400)\n",
    "\n",
    "ax.set_xlabel(\"time (s)\" if vs_time else \"iterations\", fontsize=24)\n",
    "\n",
    "# No title is set here: it is cleared with ax.set_title(\"\") further below.\n",
    "if plot_mini_batch_loss is None:\n",
    "    ax.set_ylabel(\"function value\", fontsize=24)\n",
    "elif plot_mini_batch_loss:\n",
    "    ax.set_ylabel(\"mini batch loss\", fontsize=24)\n",
    "else:\n",
    "    ax.set_ylabel(\"full batch loss\", fontsize=24)\n",
    "\n",
    "\n",
    "###########################\n",
    "# settings for mf_multi_s #\n",
    "###########################\n",
    "\n",
    "# Reorder the legend entries\n",
    "# indices = [0, 1, 2, 6, 7, 8, 3, 4, 5]\n",
    "# handles = [handles[i] for i in indices]\n",
    "# labels = [labels[i] for i in indices]\n",
    "\n",
    "# Rename labels\n",
    "# labels = [label.replace(\"reduced_dim\", \"s\") for label in labels]\n",
    "\n",
    "# Change colors\n",
    "# shades = [0.5, 0.7, 0.9]\n",
    "# maps = [\"Greens\", \"Blues\", \"Reds\"]\n",
    "# for i, handle in enumerate(handles):\n",
    "#     cmap = plt.colormaps[maps[i // 3]]\n",
    "#     color = plt.get_cmap(cmap)(shades[i % 3])\n",
    "#     handle.set_color(color)\n",
    "\n",
    "# Rescale the values\n",
    "# for i, handle in enumerate(handles):\n",
    "#     handle.set_ydata(handle.get_ydata() / (943 * 1682))\n",
    "\n",
    "\n",
    "######################################\n",
    "# settings for rosenbrock_multi_rank #\n",
    "######################################\n",
    "\n",
    "# Reorder the legend entries\n",
    "indices = [2, 3, 0, 1]\n",
    "handles = [handles[i] for i in indices]\n",
    "labels = [labels[i] for i in indices]\n",
    "\n",
    "# Recolor each group with a shade of \"Reds\". take_average draws exactly one\n",
    "# fill_between band per mean line, in the same order, so the matching band is\n",
    "# recolored too (otherwise every band keeps the original group color).\n",
    "shades = [0.5, 0.7, 0.9, 1.0]\n",
    "reds = plt.colormaps[\"Reds\"]\n",
    "all_lines = list(ax.lines)\n",
    "bands_match_lines = len(ax.collections) == len(all_lines)\n",
    "for i, handle in enumerate(handles):\n",
    "    color = reds(shades[i])\n",
    "    handle.set_color(color)\n",
    "    if bands_match_lines:\n",
    "        ax.collections[all_lines.index(handle)].set_color(color)\n",
    "\n",
    "# Shorten \"rank\" to \"r\" in the legend labels\n",
    "labels = [label.replace(\"rank\", \"r\") for label in labels]\n",
    "\n",
    "# Show the legend (alternative placements kept below)\n",
    "ax.legend(handles, labels, fontsize=24)\n",
    "# ax.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.23), ncol=3)\n",
    "# ax.legend(handles, labels, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=16)\n",
    "\n",
    "# Line width\n",
    "for line in ax.lines:\n",
    "    line.set_linewidth(3)\n",
    "\n",
    "ax.tick_params(labelsize=16)\n",
    "ax.set_title(\"\")\n",
    "\n",
    "# Set the y scale before configuring tick locators: set_yscale() resets the axis\n",
    "# locator to the scale's default, so a locator installed earlier is discarded.\n",
    "if y_log:\n",
    "    ax.set_yscale(\"log\")\n",
    "\n",
    "# Use at most 4 major ticks. A log y axis keeps its default log locator, since\n",
    "# linearly spaced MaxNLocator ticks do not suit a log scale.\n",
    "ax.xaxis.set_major_locator(plt.MaxNLocator(4))\n",
    "if not y_log:\n",
    "    ax.yaxis.set_major_locator(plt.MaxNLocator(4))\n",
    "\n",
    "fig.set_size_inches(6, 4.5)\n",
    "\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2468ac04-56e8-40d0-87a9-e34935cb56a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.patches as mpatches\n",
    "\n",
    "# Stand-alone legend figure: one color patch per method, using the colors from\n",
    "# `default_colors`. Separate names keep the averaged plot's `fig, ax` intact, so\n",
    "# the styling cell above can still be re-run after this one.\n",
    "legend_entries = [\n",
    "    (\"GD\", \"GD\"),\n",
    "    (\"HSODM\", \"HSODM\"),\n",
    "    (\"RSGD\", \"RSGD (s=100)\"),\n",
    "    (\"RSRN\", \"RSRN (s=100)\"),\n",
    "    (\"RSHTR\", \"RSHTR (s=100)\"),\n",
    "]\n",
    "legend_fig, legend_ax = plt.subplots()\n",
    "legend_ax.axis(\"off\")\n",
    "legend_ax.legend(\n",
    "    handles=[mpatches.Patch(color=default_colors[key], label=label) for key, label in legend_entries],\n",
    "    loc=\"center\",\n",
    "    ncol=len(legend_entries),\n",
    ")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
