{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1fcf518",
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'seaborn'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 7\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpandas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mseaborn\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01msns\u001b[39;00m\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcalculate_frobenius_norm\u001b[39m(sgd_history, zo_history):\n\u001b[1;32m     11\u001b[0m     sgd \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(sgd_history)\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'seaborn'"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import glob\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "def calculate_frobenius_norm(sgd_history, zo_history):\n",
    "    \"\"\"Frobenius-norm distance between the FO-SGD and ZO-SGD trajectories.\n",
    "\n",
    "    Both histories are truncated to their common length before differencing,\n",
    "    so runs with a different number of recorded steps can still be compared.\n",
    "\n",
    "    Args:\n",
    "        sgd_history: per-step probability history of first-order SGD.\n",
    "        zo_history: per-step probability history of zeroth-order SGD.\n",
    "\n",
    "    Returns:\n",
    "        The 2-norm of the flattened difference over the shared prefix. For\n",
    "        2-D (steps x classes) histories this is the Frobenius norm; it is\n",
    "        0.0 when either history is empty.\n",
    "    \"\"\"\n",
    "    sgd = np.asarray(sgd_history)\n",
    "    zo = np.asarray(zo_history)\n",
    "    min_len = min(len(sgd), len(zo))\n",
    "    if min_len == 0:\n",
    "        # Nothing to compare; also avoids broadcasting a (0,) array against (0, C).\n",
    "        return 0.0\n",
    "    # ord=None flattens the input: identical to ord=\"fro\" for 2-D arrays, but\n",
    "    # it does not raise ValueError on 1-D histories.\n",
    "    return np.linalg.norm(sgd[:min_len] - zo[:min_len])\n",
    "\n",
    "\n",
    "def parse_and_load_data(data_dir=\"llm_results\"):\n",
    "    \"\"\"Load every {model}_{dist}_{k}.json result file into a summary table.\n",
    "\n",
    "    For each file, the Frobenius-norm error between the FO-SGD and ZO-SGD\n",
    "    probability trajectories is averaged over the samples that carry both\n",
    "    histories. Files that cannot be processed are reported and skipped.\n",
    "\n",
    "    Args:\n",
    "        data_dir: directory holding the result files.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns \"Model\", \"Distribution\",\n",
    "        \"Perturbation Size (k)\" and \"Frobenius Norm Error\"; one row per file\n",
    "        with at least one usable sample (empty if there are none).\n",
    "    \"\"\"\n",
    "    data_list = []\n",
    "\n",
    "    # Sorted so rows (and any printed errors) come out in a stable order.\n",
    "    for f in sorted(glob.glob(os.path.join(data_dir, \"*.json\"))):\n",
    "        try:\n",
    "            # Parse only the bare file name: the model name then carries no\n",
    "            # directory prefix, whatever the OS path separator is.\n",
    "            stem = os.path.splitext(os.path.basename(f))[0]\n",
    "            parts = stem.split(\"_\")\n",
    "            k = int(parts[-1])\n",
    "            dist_type = parts[-2]\n",
    "            model_name = \"_\".join(parts[:-2])\n",
    "\n",
    "            with open(f, \"r\") as file:\n",
    "                content = json.load(file)\n",
    "\n",
    "            file_errors = [\n",
    "                calculate_frobenius_norm(s[\"sgd_probs_history\"], s[\"zo_probs_history\"])\n",
    "                for s in content.get(\"samples\", [])\n",
    "                if \"sgd_probs_history\" in s and \"zo_probs_history\" in s\n",
    "            ]\n",
    "\n",
    "            if file_errors:\n",
    "                data_list.append(\n",
    "                    {\n",
    "                        \"Model\": model_name,\n",
    "                        \"Distribution\": dist_type,\n",
    "                        \"Perturbation Size (k)\": k,\n",
    "                        \"Frobenius Norm Error\": np.mean(file_errors),\n",
    "                    }\n",
    "                )\n",
    "\n",
    "        except Exception as e:\n",
    "            # Best effort: one malformed file must not abort the whole load.\n",
    "            print(f\"Error processing {f}: {e}\")\n",
    "\n",
    "    return pd.DataFrame(data_list)\n",
    "\n",
    "\n",
    "def _plot_error_by_model(subset, title, save_path=None):\n",
    "    \"\"\"Plot error vs. P with one line per model; do nothing for an empty subset.\n",
    "\n",
    "    Args:\n",
    "        subset: rows of the summary table for a single perturbation distribution.\n",
    "        title: figure title.\n",
    "        save_path: if given, the figure is also written to this file.\n",
    "    \"\"\"\n",
    "    if subset.empty:\n",
    "        return\n",
    "    # Create the figure only when there is data, so no blank figure is left open.\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    sns.lineplot(\n",
    "        data=subset,\n",
    "        x=\"Perturbation Size (k)\",\n",
    "        y=\"Frobenius Norm Error\",\n",
    "        hue=\"Model\",\n",
    "        marker=\"o\",\n",
    "        linewidth=2.5,\n",
    "        markersize=10,\n",
    "    )\n",
    "    plt.xscale(\"log\")\n",
    "    plt.title(title)\n",
    "    plt.ylabel(\"Trajectory Error (Frobenius Norm)\")\n",
    "    plt.xlabel(\"Number of Perturbations (P)\")\n",
    "    plt.tight_layout()\n",
    "    if save_path is not None:\n",
    "        plt.savefig(save_path, bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def plot_results(df):\n",
    "    \"\"\"Draw the Frobenius-norm error figures for a parse_and_load_data() table.\n",
    "\n",
    "    1. Error vs. P per perturbation distribution, saved to\n",
    "       frobenius_hint1_hint3.pdf.\n",
    "    2. Error vs. P per model for Gaussian perturbations, saved to\n",
    "       frobenius_hint2_dimension_gaussian.pdf.\n",
    "    3. The same for Rademacher perturbations (shown, not saved).\n",
    "\n",
    "    Args:\n",
    "        df: table with \"Model\", \"Distribution\", \"Perturbation Size (k)\" and\n",
    "            \"Frobenius Norm Error\" columns. It is not modified; relabelling\n",
    "            happens on a copy.\n",
    "    \"\"\"\n",
    "    if df.empty:\n",
    "        print(\"No data loaded!\")\n",
    "        return\n",
    "\n",
    "    df = df.copy()\n",
    "\n",
    "    # Result files call the Rademacher distribution \"bernoulli\"; use display names.\n",
    "    df[\"Distribution\"] = df[\"Distribution\"].replace(\n",
    "        {\n",
    "            \"bernoulli\": \"Rademacher\",\n",
    "            \"gaussian\": \"Gaussian\",\n",
    "        }\n",
    "    )\n",
    "    # Drop a directory prefix if the loader left one in the model name.\n",
    "    df[\"Model\"] = df[\"Model\"].str.replace(\"llm_results/\", \"\", regex=False)\n",
    "    sns.set(style=\"whitegrid\", context=\"talk\")\n",
    "\n",
    "    # Figure 1: seaborn aggregates the models at each P (mean and a\n",
    "    # confidence band by default).\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    sns.lineplot(\n",
    "        data=df,\n",
    "        x=\"Perturbation Size (k)\",\n",
    "        y=\"Frobenius Norm Error\",\n",
    "        hue=\"Distribution\",\n",
    "        style=\"Distribution\",\n",
    "        markers=True,\n",
    "        dashes=False,\n",
    "        linewidth=2.5,\n",
    "        markersize=10,\n",
    "    )\n",
    "    plt.xscale(\"log\")\n",
    "    plt.title(\"ZO Error Convergence (Frobenius Norm)\")\n",
    "    plt.ylabel(\"Trajectory Error (Frobenius Norm)\")\n",
    "    plt.xlabel(\"Number of Perturbations (P)\")\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"frobenius_hint1_hint3.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "    _plot_error_by_model(\n",
    "        df[df[\"Distribution\"] == \"Gaussian\"],\n",
    "        \"Error vs. Model Dimension (Gaussian Perturbation)\",\n",
    "        save_path=\"frobenius_hint2_dimension_gaussian.pdf\",\n",
    "    )\n",
    "    # Shown only; pass save_path=\"frobenius_hint2_dimension_rademacher.pdf\"\n",
    "    # to also keep a PDF of this figure.\n",
    "    _plot_error_by_model(\n",
    "        df[df[\"Distribution\"] == \"Rademacher\"],\n",
    "        \"Error vs. Model Dimension (Rademacher Perturbation)\",\n",
    "    )\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # One row per result file with usable samples, then render the figures.\n",
    "    df = parse_and_load_data()\n",
    "    plot_results(df=df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81d32360",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import glob\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "DATA_DIR = \"llm_results\"  # directory holding the {model}_{dist}_{k}.json files\n",
    "K_VALUES_TO_PLOT = [1, 5, 10, 20, 50, 100]  # P values, one subplot each\n",
    "CLASS_IDX = 1  # class whose predicted probability is plotted\n",
    "# Line colour per optimiser / ZO perturbation distribution.\n",
    "colors = dict(\n",
    "    SGD=\"#2d3436\",\n",
    "    Gaussian=\"#0984e3\",\n",
    "    Rademacher=\"#d63031\",\n",
    ")\n",
    "\n",
    "\n",
    "def get_file_path(model, dist, k):\n",
    "    \"\"\"Path of the result file for *model*, perturbation *dist* and P=*k*.\"\"\"\n",
    "    file_name = \"{}_{}_{}.json\".format(model, dist, k)\n",
    "    return os.path.join(DATA_DIR, file_name)\n",
    "\n",
    "\n",
    "def get_num_samples(model):\n",
    "    \"\"\"Number of samples in the first result file found for *model*.\n",
    "\n",
    "    Files are probed from the largest P down, Gaussian before \"bernoulli\";\n",
    "    returns 1 when no result file exists.\n",
    "    \"\"\"\n",
    "    candidate_paths = (\n",
    "        get_file_path(model, dist, k)\n",
    "        for k in [100, 50, 20, 10, 5, 1]\n",
    "        for dist in [\"gaussian\", \"bernoulli\"]\n",
    "    )\n",
    "    first_existing = next((p for p in candidate_paths if os.path.exists(p)), None)\n",
    "    if first_existing is None:\n",
    "        return 1\n",
    "    with open(first_existing, \"r\") as handle:\n",
    "        return len(json.load(handle).get(\"samples\", []))\n",
    "\n",
    "\n",
    "def plot_model_trajectories(model_name):\n",
    "    \"\"\"Save one 2x3 trajectory grid per sample of *model_name*.\n",
    "\n",
    "    Each subplot is one P value from K_VALUES_TO_PLOT and overlays the\n",
    "    predicted probability of class CLASS_IDX over iterations for ZO-SGD\n",
    "    (Gaussian and Rademacher perturbations) and, once, first-order SGD.\n",
    "    Figures are written to trajectory_grid_{model}_sample{i}.pdf.\n",
    "    \"\"\"\n",
    "    num_samples = get_num_samples(model_name)\n",
    "    print(f\"Found {num_samples} samples for model {model_name}\")\n",
    "\n",
    "    # File tag -> display name; result files call Rademacher \"bernoulli\".\n",
    "    display_names = {\"gaussian\": \"Gaussian\", \"bernoulli\": \"Rademacher\"}\n",
    "    zo_styles = {\"gaussian\": \"--\", \"bernoulli\": \":\"}\n",
    "\n",
    "    # Parse each result file once, instead of once per sample and subplot.\n",
    "    # (All of this model's files stay in memory for the whole loop.)\n",
    "    samples_by_file = {}\n",
    "    for k in K_VALUES_TO_PLOT:\n",
    "        for dist in display_names:\n",
    "            path = get_file_path(model_name, dist, k)\n",
    "            if os.path.exists(path):\n",
    "                with open(path, \"r\") as f:\n",
    "                    samples_by_file[(dist, k)] = json.load(f).get(\"samples\", [])\n",
    "\n",
    "    for sample_idx in range(num_samples):\n",
    "        fig, axes = plt.subplots(2, 3, figsize=(18, 9), sharex=True, sharey=True)\n",
    "        axes = axes.flatten()\n",
    "\n",
    "        for i, k in enumerate(K_VALUES_TO_PLOT):\n",
    "            ax = axes[i]\n",
    "\n",
    "            # Invisible handle so the legend starts with the P value.\n",
    "            p_label = rf\"$P={k}$\"\n",
    "            handles = [Line2D([], [], color=\"none\", label=p_label)]\n",
    "            labels = [p_label]\n",
    "\n",
    "            for dist, display_name in display_names.items():\n",
    "                samples = samples_by_file.get((dist, k), [])\n",
    "                if sample_idx >= len(samples):\n",
    "                    continue\n",
    "                sample = samples[sample_idx]\n",
    "\n",
    "                # ZO trajectory\n",
    "                if \"zo_probs_history\" in sample:\n",
    "                    zo_traj = np.array(sample[\"zo_probs_history\"])[:, CLASS_IDX]\n",
    "                    label_text = f\"ZO SGD ({display_name})\"\n",
    "                    (h,) = ax.plot(\n",
    "                        zo_traj,\n",
    "                        linestyle=zo_styles[dist],\n",
    "                        linewidth=2.5,\n",
    "                        color=colors[display_name],\n",
    "                        alpha=0.9,\n",
    "                        label=label_text,\n",
    "                    )\n",
    "                    handles.append(h)\n",
    "                    labels.append(label_text)\n",
    "\n",
    "                # FO-SGD reference: drawn once per subplot, from the first file\n",
    "                # that has it.\n",
    "                if \"sgd_probs_history\" in sample and \"FO SGD\" not in labels:\n",
    "                    sgd_traj = np.array(sample[\"sgd_probs_history\"])[:, CLASS_IDX]\n",
    "                    (h,) = ax.plot(\n",
    "                        sgd_traj,\n",
    "                        linestyle=\"-\",\n",
    "                        linewidth=3.0,\n",
    "                        color=colors[\"SGD\"],\n",
    "                        alpha=0.6,\n",
    "                        label=\"FO SGD\",\n",
    "                    )\n",
    "                    handles.append(h)\n",
    "                    labels.append(\"FO SGD\")\n",
    "\n",
    "            # Axes styling: x label on the bottom row, y label on the left column.\n",
    "            ax.set_ylim(0.0, 1.0)\n",
    "            ax.grid(True, linestyle=\":\", alpha=0.6)\n",
    "            if i >= 3:\n",
    "                ax.set_xlabel(\"Iterations\")\n",
    "            if i % 3 == 0:\n",
    "                ax.set_ylabel(\"Predicted Probability\")\n",
    "\n",
    "            ax.legend(handles, labels, loc=\"lower right\", fontsize=11, frameon=True)\n",
    "\n",
    "        fig.suptitle(\n",
    "            f\"ZO Trajectory Convergence: {model_name.upper()} + SST2 (Sample {sample_idx})\",\n",
    "            fontsize=22,\n",
    "            y=0.97,\n",
    "        )\n",
    "\n",
    "        fig.tight_layout()\n",
    "        fig.savefig(\n",
    "            f\"trajectory_grid_{model_name}_sample{sample_idx}.pdf\", bbox_inches=\"tight\"\n",
    "        )\n",
    "        # Close this specific figure so the per-sample loop does not accumulate\n",
    "        # open figures.\n",
    "        plt.close(fig)\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Entry point: write the per-sample trajectory grids for OPT-350M.\"\"\"\n",
    "    target_model = \"opt-350m\"\n",
    "    plot_model_trajectories(target_model)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2c827b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import glob\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "DATA_DIR = \"llm_results\"  # directory holding the {model}_{dist}_{k}.json files\n",
    "K_VALUES_TO_PLOT = [1, 5, 10, 20, 50, 100]  # P values, one subplot each\n",
    "CLASS_IDX = 1  # class whose predicted probability is plotted\n",
    "\n",
    "\n",
    "def get_file_path(model, dist, k):\n",
    "    \"\"\"Path of the result file for *model*, perturbation *dist* and P=*k*.\"\"\"\n",
    "    file_name = \"{}_{}_{}.json\".format(model, dist, k)\n",
    "    return os.path.join(DATA_DIR, file_name)\n",
    "\n",
    "\n",
    "def find_representative_sample_idx(model):\n",
    "    \"\"\"Return the index of the sample whose FO-SGD trajectory varies the most.\n",
    "\n",
    "    The reference file is the first existing one when probing from the largest\n",
    "    P down, Gaussian before \"bernoulli\". For each sample, the range (max - min)\n",
    "    of the FO-SGD probability of class CLASS_IDX is computed and the index with\n",
    "    the largest range is returned.\n",
    "\n",
    "    Returns:\n",
    "        The selected sample index; 0 when no usable sample exists or the\n",
    "        reference file cannot be read (the error is printed).\n",
    "    \"\"\"\n",
    "    candidates = [\n",
    "        get_file_path(model, d, k)\n",
    "        for k in [100, 50, 20, 10, 5, 1]\n",
    "        for d in [\"gaussian\", \"bernoulli\"]\n",
    "    ]\n",
    "    # The first existing candidate wins, i.e. the largest available P. If none\n",
    "    # exists, keep the Gaussian P=100 path so open() below fails and is reported.\n",
    "    ref_path = next((p for p in candidates if os.path.exists(p)), candidates[0])\n",
    "\n",
    "    try:\n",
    "        with open(ref_path, \"r\") as f:\n",
    "            data = json.load(f)\n",
    "\n",
    "        max_range = -1\n",
    "        best_idx = 0\n",
    "\n",
    "        for i, sample in enumerate(data.get(\"samples\", [])):\n",
    "            if \"sgd_probs_history\" not in sample:\n",
    "                continue\n",
    "            probs = np.asarray(sample[\"sgd_probs_history\"])\n",
    "            # Skip empty or non-2-D histories instead of aborting the whole search.\n",
    "            if probs.ndim != 2 or probs.shape[0] == 0 or probs.shape[1] <= CLASS_IDX:\n",
    "                continue\n",
    "            probs_c = probs[:, CLASS_IDX]\n",
    "            prob_range = np.max(probs_c) - np.min(probs_c)\n",
    "            if prob_range > max_range:\n",
    "                max_range = prob_range\n",
    "                best_idx = i\n",
    "        return best_idx\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"Error searching sample: {e}\")\n",
    "        return 0\n",
    "\n",
    "\n",
    "def _load_sample(path, sample_idx):\n",
    "    \"\"\"Return samples[sample_idx] from the JSON result file at *path*.\n",
    "\n",
    "    Returns None when the file does not exist or holds too few samples.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(path):\n",
    "        return None\n",
    "    with open(path, \"r\") as f:\n",
    "        samples = json.load(f).get(\"samples\", [])\n",
    "    return samples[sample_idx] if sample_idx < len(samples) else None\n",
    "\n",
    "\n",
    "def plot_model_trajectories(model_name):\n",
    "    \"\"\"Save a 2x3 trajectory grid for one selected sample of *model_name*.\n",
    "\n",
    "    The sample index comes from find_representative_sample_idx. Each subplot\n",
    "    is one P value from K_VALUES_TO_PLOT and overlays the predicted probability\n",
    "    of class CLASS_IDX for ZO-SGD (Gaussian, Rademacher) and first-order SGD.\n",
    "    The figure is written to trajectory_grid_{model}.pdf.\n",
    "    \"\"\"\n",
    "    sample_idx = find_representative_sample_idx(model_name)\n",
    "\n",
    "    fig, axes = plt.subplots(2, 3, figsize=(18, 9), sharex=True, sharey=True)\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    colors = {\n",
    "        \"SGD\": \"#000000\",  # black\n",
    "        \"Gaussian\": \"#0057B8\",  # deep blue\n",
    "        \"Rademacher\": \"#C1121F\",  # deep red\n",
    "    }\n",
    "\n",
    "    markers = {\n",
    "        \"SGD\": \"o\",\n",
    "        \"Gaussian\": \"^\",\n",
    "        \"Rademacher\": \"s\",\n",
    "    }\n",
    "\n",
    "    # (file tag, display name, ZO line style); files call Rademacher \"bernoulli\".\n",
    "    dist_specs = [\n",
    "        (\"gaussian\", \"Gaussian\", \"--\"),\n",
    "        (\"bernoulli\", \"Rademacher\", \":\"),\n",
    "    ]\n",
    "    # Styling shared by every trajectory line.\n",
    "    line_style = dict(linewidth=5, markersize=8, markevery=20)\n",
    "\n",
    "    for i, k in enumerate(K_VALUES_TO_PLOT):\n",
    "        ax = axes[i]\n",
    "\n",
    "        # Invisible handles so the legend starts with the model name and P.\n",
    "        model_label = model_name.upper()\n",
    "        p_label = rf\"$P={k}$\"\n",
    "        handles = [\n",
    "            Line2D([], [], color=\"none\", label=model_label),\n",
    "            Line2D([], [], color=\"none\", label=p_label),\n",
    "        ]\n",
    "        labels = [model_label, p_label]\n",
    "\n",
    "        for dist, display_name, zo_linestyle in dist_specs:\n",
    "            s = _load_sample(get_file_path(model_name, dist, k), sample_idx)\n",
    "            if s is None:\n",
    "                continue\n",
    "\n",
    "            if \"zo_probs_history\" in s:\n",
    "                zo_label = f\"ZO SGD ({display_name})\"\n",
    "                (h,) = ax.plot(\n",
    "                    np.array(s[\"zo_probs_history\"])[:, CLASS_IDX],\n",
    "                    linestyle=zo_linestyle,\n",
    "                    color=colors[display_name],\n",
    "                    alpha=1,\n",
    "                    marker=markers[display_name],\n",
    "                    label=zo_label,\n",
    "                    **line_style,\n",
    "                )\n",
    "                handles.append(h)\n",
    "                labels.append(zo_label)\n",
    "\n",
    "            # FO-SGD reference, drawn once per subplot from the first file that\n",
    "            # has it, so it still appears when the Gaussian file is missing.\n",
    "            if \"sgd_probs_history\" in s and \"FO SGD\" not in labels:\n",
    "                (h,) = ax.plot(\n",
    "                    np.array(s[\"sgd_probs_history\"])[:, CLASS_IDX],\n",
    "                    linestyle=\"-\",\n",
    "                    color=colors[\"SGD\"],\n",
    "                    alpha=0.6,\n",
    "                    marker=markers[\"SGD\"],\n",
    "                    label=\"FO SGD\",\n",
    "                    **line_style,\n",
    "                )\n",
    "                handles.append(h)\n",
    "                labels.append(\"FO SGD\")\n",
    "\n",
    "        # Axes styling: x label on the bottom row, y label on the left column.\n",
    "        ax.set_ylim(0, 1)\n",
    "        ax.grid(True, linestyle=\":\", alpha=0.6)\n",
    "\n",
    "        if i >= 3:\n",
    "            ax.set_xlabel(\"Iterations\")\n",
    "        if i % 3 == 0:\n",
    "            ax.set_ylabel(\"Predicted Probability\")\n",
    "\n",
    "        ax.legend(handles, labels, fontsize=14, frameon=True)\n",
    "\n",
    "    fig.suptitle(\n",
    "        f\"ZO Trajectory Convergence: {model_name.upper()} + SST2\",\n",
    "        fontsize=22,\n",
    "        y=0.97,\n",
    "    )\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(f\"trajectory_grid_{model_name}.pdf\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Entry point: write the representative trajectory grid for OPT-1.3B.\"\"\"\n",
    "    target_model = \"opt-1.3b\"\n",
    "    plot_model_trajectories(target_model)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ea859c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "DATA_DIR = \"llm_results\"  # directory holding the {model}_{dist}_{k}.json files\n",
    "K_VALUES_TO_PLOT = [1, 5, 10, 20, 50, 100]  # P values, one column each\n",
    "CLASS_IDX = 1  # class whose predicted probability is plotted\n",
    "MODELS = \"opt-125m opt-350m opt-1.3b\".split()  # one row per model\n",
    "\n",
    "\n",
    "def get_file_path(model, dist, k):\n",
    "    \"\"\"Path of the result file for *model*, perturbation *dist* and P=*k*.\"\"\"\n",
    "    file_name = \"{}_{}_{}.json\".format(model, dist, k)\n",
    "    return os.path.join(DATA_DIR, file_name)\n",
    "\n",
    "\n",
    "def find_representative_sample_idx(model):\n",
    "    \"\"\"Return the index of the sample whose FO-SGD trajectory varies the most.\n",
    "\n",
    "    The reference file is the first existing one when probing from the largest\n",
    "    P down, Gaussian before \"bernoulli\". For each sample, the range (max - min)\n",
    "    of the FO-SGD probability of class CLASS_IDX is computed and the index with\n",
    "    the largest range is returned.\n",
    "\n",
    "    Returns:\n",
    "        The selected sample index; 0 when no usable sample exists or the\n",
    "        reference file cannot be read (the error is printed).\n",
    "    \"\"\"\n",
    "    candidates = [\n",
    "        get_file_path(model, d, k)\n",
    "        for k in [100, 50, 20, 10, 5, 1]\n",
    "        for d in [\"gaussian\", \"bernoulli\"]\n",
    "    ]\n",
    "    # The first existing candidate wins, i.e. the largest available P. If none\n",
    "    # exists, keep the Gaussian P=100 path so open() below fails and is reported.\n",
    "    ref_path = next((p for p in candidates if os.path.exists(p)), candidates[0])\n",
    "\n",
    "    try:\n",
    "        with open(ref_path, \"r\") as f:\n",
    "            data = json.load(f)\n",
    "\n",
    "        max_range = -1\n",
    "        best_idx = 0\n",
    "        for i, sample in enumerate(data.get(\"samples\", [])):\n",
    "            if \"sgd_probs_history\" not in sample:\n",
    "                continue\n",
    "            probs = np.asarray(sample[\"sgd_probs_history\"])\n",
    "            # Skip empty or non-2-D histories instead of aborting the whole search.\n",
    "            if probs.ndim != 2 or probs.shape[0] == 0 or probs.shape[1] <= CLASS_IDX:\n",
    "                continue\n",
    "            probs_c = probs[:, CLASS_IDX]\n",
    "            prob_range = np.max(probs_c) - np.min(probs_c)\n",
    "            if prob_range > max_range:\n",
    "                max_range = prob_range\n",
    "                best_idx = i\n",
    "        return best_idx\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"Error searching sample: {e}\")\n",
    "        return 0\n",
    "\n",
    "\n",
    "def _load_sample(path, sample_idx):\n",
    "    \"\"\"Return samples[sample_idx] from the JSON result file at *path*.\n",
    "\n",
    "    Returns None when the file does not exist or holds too few samples.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(path):\n",
    "        return None\n",
    "    with open(path, \"r\") as f:\n",
    "        samples = json.load(f).get(\"samples\", [])\n",
    "    return samples[sample_idx] if sample_idx < len(samples) else None\n",
    "\n",
    "\n",
    "def plot_all_models():\n",
    "    \"\"\"Save a (model x P) trajectory grid to trajectory_all_models.pdf.\n",
    "\n",
    "    Rows follow MODELS (each model at the sample index chosen by\n",
    "    find_representative_sample_idx); columns follow K_VALUES_TO_PLOT. Each\n",
    "    panel overlays the predicted probability of class CLASS_IDX for ZO-SGD\n",
    "    (Gaussian, Rademacher) and first-order SGD.\n",
    "    \"\"\"\n",
    "    n_rows, n_cols = len(MODELS), len(K_VALUES_TO_PLOT)\n",
    "    # squeeze=False keeps a 2-D axes array even for a single model or P value;\n",
    "    # the size works out to (24, 8) for the default 3 x 6 grid.\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows,\n",
    "        n_cols,\n",
    "        figsize=(4 * n_cols, 8 * n_rows / 3),\n",
    "        sharex=True,\n",
    "        sharey=True,\n",
    "        squeeze=False,\n",
    "    )\n",
    "\n",
    "    colors = {\n",
    "        \"SGD\": \"#000000\",  # black\n",
    "        \"Gaussian\": \"#0057B8\",  # deep blue\n",
    "        \"Rademacher\": \"#C1121F\",  # deep red\n",
    "    }\n",
    "\n",
    "    markers = {\n",
    "        \"SGD\": \"o\",\n",
    "        \"Gaussian\": \"^\",\n",
    "        \"Rademacher\": \"s\",\n",
    "    }\n",
    "\n",
    "    # (file tag, display name, ZO line style, legend label); files call\n",
    "    # Rademacher \"bernoulli\".\n",
    "    dist_specs = [\n",
    "        (\"gaussian\", \"Gaussian\", \"--\", \"ZO SGD (Gaus.)\"),\n",
    "        (\"bernoulli\", \"Rademacher\", \":\", \"ZO SGD (Rade.)\"),\n",
    "    ]\n",
    "    # Styling shared by every trajectory line.\n",
    "    line_style = dict(linewidth=3, markersize=6, markevery=20)\n",
    "\n",
    "    for row_idx, model_name in enumerate(MODELS):\n",
    "        sample_idx = find_representative_sample_idx(model_name)\n",
    "\n",
    "        for col_idx, k in enumerate(K_VALUES_TO_PLOT):\n",
    "            ax = axes[row_idx, col_idx]\n",
    "\n",
    "            # Invisible handle so each legend starts with the model name.\n",
    "            model_label = model_name.upper()\n",
    "            handles = [Line2D([], [], color=\"none\", label=model_label)]\n",
    "            labels = [model_label]\n",
    "\n",
    "            for dist, display_name, zo_linestyle, zo_label in dist_specs:\n",
    "                s = _load_sample(get_file_path(model_name, dist, k), sample_idx)\n",
    "                if s is None:\n",
    "                    continue\n",
    "\n",
    "                if \"zo_probs_history\" in s:\n",
    "                    (h,) = ax.plot(\n",
    "                        np.array(s[\"zo_probs_history\"])[:, CLASS_IDX],\n",
    "                        linestyle=zo_linestyle,\n",
    "                        color=colors[display_name],\n",
    "                        alpha=1,\n",
    "                        marker=markers[display_name],\n",
    "                        label=zo_label,\n",
    "                        **line_style,\n",
    "                    )\n",
    "                    handles.append(h)\n",
    "                    labels.append(zo_label)\n",
    "\n",
    "                # FO-SGD reference, drawn once per panel from the first file\n",
    "                # that has it, so it still appears when the Gaussian file is missing.\n",
    "                if \"sgd_probs_history\" in s and \"FO SGD\" not in labels:\n",
    "                    (h,) = ax.plot(\n",
    "                        np.array(s[\"sgd_probs_history\"])[:, CLASS_IDX],\n",
    "                        linestyle=\"-\",\n",
    "                        color=colors[\"SGD\"],\n",
    "                        alpha=0.6,\n",
    "                        marker=markers[\"SGD\"],\n",
    "                        label=\"FO SGD\",\n",
    "                        **line_style,\n",
    "                    )\n",
    "                    handles.append(h)\n",
    "                    labels.append(\"FO SGD\")\n",
    "\n",
    "            ax.set_ylim(0, 1)\n",
    "            ax.grid(True, linestyle=\":\", alpha=0.6)\n",
    "\n",
    "            # P as column titles on the top row; axis labels on the outer panels.\n",
    "            if row_idx == 0:\n",
    "                ax.set_title(f\"P={k}\", fontsize=20, fontweight=\"bold\")\n",
    "            if row_idx == n_rows - 1:\n",
    "                ax.set_xlabel(\"Iterations\", fontsize=15)\n",
    "            if col_idx == 0:\n",
    "                ax.set_ylabel(\"Predicted Probability\", fontsize=15)\n",
    "\n",
    "            ax.legend(handles, labels, fontsize=13, frameon=True)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(\"trajectory_all_models.pdf\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Entry point: write the combined all-models trajectory figure.\"\"\"\n",
    "    return plot_all_models()\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "zo-llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
