{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import logging\n",
    "from ccmm.utils.plot import Palette\n",
    "from tqdm import tqdm\n",
    "from wandb.sdk.wandb_run import Run\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import numpy as np\n",
    "import plotly.graph_objs as go\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from nn_core.common import PROJECT_ROOT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pylogger = logging.getLogger(__name__)\n",
    "\n",
    "\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"text.usetex\": True,\n",
    "        \"font.family\": \"serif\",\n",
    "    }\n",
    ")\n",
    "sns.set_context(\"talk\")\n",
    "\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "palette = Palette(f\"{PROJECT_ROOT}/misc/palette2.json\")\n",
    "palette"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "api = wandb.Api()\n",
    "entity, project = \"ANONYMIZED\", \"cycle-consistent-model-merging\"  # set to your entity and project"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_runs(entity, project, positive_tags, negative_tags):\n",
    "    filters_pos_tags = {\"$and\": [{\"tags\": {\"$eq\": pos_tag}} for pos_tag in positive_tags]}\n",
    "    filters_neg_tags = {}\n",
    "\n",
    "    print(filters_pos_tags)\n",
    "    filters = {**filters_pos_tags, **filters_neg_tags}\n",
    "    runs = api.runs(entity + \"/\" + project, filters=filters)\n",
    "\n",
    "    print(f\"There are {len(runs)} runs respecting these conditions.\")\n",
    "    return runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tags = [\n",
    "    \"mlp\",\n",
    "    \"cifar10\",\n",
    "    \"match_two_models\",\n",
    "]  # [\"matching\", \"pairwise\", \"final\", \"resnet20\", \"8x\", \"cifar100\"]  # \"mlp\", [\"resnet20\", \"2x\"], \"vgg\", \"batch_norm\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "runs = get_runs(entity, project, positive_tags=tags, negative_tags=[])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_pairs = [(1, 2), (2, 3), (1, 3)]\n",
    "all_seeds = range(1, 5)\n",
    "matchers = [\"git_rebasin\", \"frank_wolfe\", \"naive\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exps = {matcher: {pair: {seed: {} for seed in all_seeds} for pair in model_pairs} for matcher in matchers}\n",
    "print(exps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_key = \"matching/seed_index\"\n",
    "model_pair_key = \"matching/model_seeds\"\n",
    "\n",
    "matcher_key = \"matching/matcher/_target_\"\n",
    "\n",
    "alternating_diff_classname = \"ccmm.matching.matcher.AlternatingDiffusionMatcher\"\n",
    "gitrebasin_classname = \"ccmm.matching.matcher.GitRebasinMatcher\"\n",
    "quadratic_classname = \"ccmm.matching.matcher.QuadraticMatcher\"\n",
    "frankwolfe_classname = \"ccmm.matching.matcher.FrankWolfeMatcher\"\n",
    "naive_classname = \"ccmm.matching.matcher.DummyMatcher\"\n",
    "\n",
    "model_key = \"model/name\"\n",
    "matcher_mapping = {\n",
    "    alternating_diff_classname: \"alternating_diffusion\",\n",
    "    gitrebasin_classname: \"git_rebasin\",\n",
    "    quadratic_classname: \"quadratic\",\n",
    "    frankwolfe_classname: \"frank_wolfe\",\n",
    "    naive_classname: \"naive\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = \"MLP\"  # VGG, MLP, ResNet20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_nones(array):\n",
    "    return np.array([x for x in array if x is not None])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for run in tqdm(runs):\n",
    "    run: Run\n",
    "    cfg = run.config\n",
    "\n",
    "    if len(cfg) == 0:\n",
    "        pylogger.warning(\"Runs are still running, skipping\")\n",
    "        continue\n",
    "\n",
    "    if cfg[model_key] != model:\n",
    "        continue\n",
    "\n",
    "    seed = cfg[seed_key]\n",
    "    model_pair = cfg[model_pair_key]\n",
    "\n",
    "    matcher_classname = cfg[matcher_key]\n",
    "    matcher_mapped = matcher_mapping[matcher_classname]\n",
    "\n",
    "    hist = run.scan_history()\n",
    "\n",
    "    train_acc_curve = remove_nones(np.array([row[\"train_acc\"] for row in hist if \"train_acc\" in row]))\n",
    "    test_acc_curve = remove_nones(np.array([row[\"test_acc\"] for row in hist if \"test_acc\" in row]))\n",
    "\n",
    "    train_loss_curve = remove_nones(np.array([row[\"train_loss\"] for row in hist if \"train_loss\" in row]))\n",
    "    test_loss_curve = remove_nones(np.array([row[\"test_loss\"] for row in hist if \"test_loss\" in row]))\n",
    "\n",
    "    test_loss_barrier = run.history(keys=[\"test_loss_barrier\"])[\"test_loss_barrier\"][0]\n",
    "    train_loss_barrier = run.history(keys=[\"train_loss_barrier\"])[\"train_loss_barrier\"][0]\n",
    "\n",
    "    exps[matcher_mapped][tuple(model_pair)][seed] = {\n",
    "        \"train_acc_curve\": train_acc_curve,\n",
    "        \"test_acc_curve\": test_acc_curve,\n",
    "        \"train_loss_curve\": train_loss_curve,\n",
    "        \"test_loss_curve\": test_loss_curve,\n",
    "        \"test_loss_barrier\": test_loss_barrier,\n",
    "        \"train_loss_barrier\": train_loss_barrier,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "records = []\n",
    "\n",
    "for matcher_name, matcher_data in exps.items():\n",
    "    for pair, pair_data in matcher_data.items():\n",
    "        for seed, metrics in pair_data.items():\n",
    "            if metrics:\n",
    "                record = {\n",
    "                    \"matcher\": matcher_name,\n",
    "                    \"pair\": pair,\n",
    "                    \"seed\": seed,\n",
    "                    \"train_loss_barrier\": metrics[\"train_loss_barrier\"],\n",
    "                    \"test_loss_barrier\": metrics[\"test_loss_barrier\"],\n",
    "                }\n",
    "\n",
    "                records.append(record)\n",
    "\n",
    "df = pd.DataFrame(records)\n",
    "df[\"pair\"] = df[\"pair\"].astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate mean test and train loss barrier for each matcher\n",
    "mean_metrics = df.groupby([\"matcher\", \"pair\"]).mean().reset_index()\n",
    "mean_metrics\n",
    "\n",
    "\n",
    "# add a column with the standard deviation\n",
    "std_metrics = df.groupby([\"matcher\", \"pair\"]).std().reset_index()\n",
    "mean_metrics[\"test_loss_barrier_std\"] = std_metrics[\"test_loss_barrier\"]\n",
    "mean_metrics[\"train_loss_barrier_std\"] = std_metrics[\"train_loss_barrier\"]\n",
    "mean_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "numeric_cols = df.select_dtypes(include=[np.number]).columns\n",
    "\n",
    "# Calculate mean test and train loss barrier for each matcher\n",
    "mean_metrics = df.groupby(\"matcher\")[numeric_cols].mean().reset_index()\n",
    "mean_metrics\n",
    "\n",
    "\n",
    "# add a column with the standard deviation\n",
    "std_metrics = df.groupby(\"matcher\")[numeric_cols].std().reset_index()\n",
    "mean_metrics[\"test_loss_barrier_std\"] = std_metrics[\"test_loss_barrier\"]\n",
    "mean_metrics[\"train_loss_barrier_std\"] = std_metrics[\"train_loss_barrier\"]\n",
    "mean_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matcher_to_latex_map = {\n",
    "    \"frank_wolfe\": r\"\\texttt{Frank-Wolfe}\",\n",
    "    \"git_rebasin\": r\"\\texttt{Git-Rebasin}\",\n",
    "    \"naive\": r\"\\texttt{Naive}\",\n",
    "}\n",
    "\n",
    "header = r\"\"\"\n",
    "\\begin{table}\n",
    "    \\begin{center}\n",
    "        \\begin{tabular}{lccc}\n",
    "        \\toprule\n",
    "        \\textbf{Matcher}        & \\multicolumn{2}{c}{\\textbf{Barrier}}                   \\\\\n",
    "                                & \\textbf{Train}                       & \\textbf{Test}   \\\\\n",
    "        \\midrule\n",
    "        \"\"\"\n",
    "\n",
    "\n",
    "body = \"\"\n",
    "\n",
    "for row in mean_metrics.iterrows():\n",
    "    row = row[1]\n",
    "    matcher = row[\"matcher\"]\n",
    "    test_loss_barrier = row[\"test_loss_barrier\"]\n",
    "    test_loss_barrier_std = row[\"test_loss_barrier_std\"]\n",
    "    train_loss_barrier = row[\"train_loss_barrier\"]\n",
    "    train_loss_barrier_std = row[\"train_loss_barrier_std\"]\n",
    "    body += f\"\"\"\n",
    "            & {matcher_to_latex_map[matcher]} & {test_loss_barrier:.2f} $\\pm$ {test_loss_barrier_std:.2f} & {train_loss_barrier:.2f} $\\pm$ {train_loss_barrier_std:.2f} \\\\\\\\\n",
    "    \"\"\"\n",
    "\n",
    "footer = r\"\"\"\n",
    "        \\bottomrule\n",
    "        \\end{tabular}\n",
    "    \\end{center}\n",
    "    \\caption{Mean and standard deviation of the test and train loss barrier for each matcher.}\n",
    "    \\label{tab:MLP_loss_barrier}\n",
    "\\end{table}\"\"\"\n",
    "\n",
    "table = header + body + footer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss_curves = {\n",
    "    seed: {pair: {matcher_name: {} for matcher_name in matchers} for pair in model_pairs} for seed in all_seeds\n",
    "}\n",
    "\n",
    "test_acc_curves = {\n",
    "    seed: {pair: {matcher_name: {} for matcher_name in matchers} for pair in model_pairs} for seed in all_seeds\n",
    "}\n",
    "\n",
    "train_loss_curves = {\n",
    "    seed: {pair: {matcher_name: {} for matcher_name in matchers} for pair in model_pairs} for seed in all_seeds\n",
    "}\n",
    "\n",
    "train_acc_curves = {\n",
    "    seed: {pair: {matcher_name: {} for matcher_name in matchers} for pair in model_pairs} for seed in all_seeds\n",
    "}\n",
    "\n",
    "for matcher_name, matcher_data in exps.items():\n",
    "    for pair, pair_data in matcher_data.items():\n",
    "        for seed, metrics in pair_data.items():\n",
    "            if metrics:\n",
    "                train_loss_curve = metrics[\"train_loss_curve\"]\n",
    "                test_loss_curve = metrics[\"test_loss_curve\"]\n",
    "\n",
    "                train_acc_curve = metrics[\"train_acc_curve\"]\n",
    "                test_acc_curve = metrics[\"test_acc_curve\"]\n",
    "\n",
    "                test_loss_curves[seed][pair][matcher_name] = test_loss_curve\n",
    "                test_acc_curves[seed][pair][matcher_name] = test_acc_curve\n",
    "\n",
    "                train_acc_curves[seed][pair][matcher_name] = train_acc_curve\n",
    "                train_loss_curves[seed][pair][matcher_name] = train_loss_curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Naive and Frank-Wolfe matchers can be ran just once as there is no randomness involved; sometimes however we still run all the seeds\n",
    "only_ran_once = True\n",
    "if only_ran_once:\n",
    "    for seed in all_seeds:\n",
    "        for pair in model_pairs:\n",
    "            for matcher_name in matchers:\n",
    "                if matcher_name == \"naive\" or matcher_name == \"frank_wolfe\":\n",
    "                    test_loss_curves[seed][pair][matcher_name] = test_loss_curves[1][pair][matcher_name]\n",
    "                    test_acc_curves[seed][pair][matcher_name] = test_acc_curves[1][pair][matcher_name]\n",
    "                    train_loss_curves[seed][pair][matcher_name] = train_loss_curves[1][pair][matcher_name]\n",
    "                    train_acc_curves[seed][pair][matcher_name] = train_acc_curves[1][pair][matcher_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_averaged_over_seeds = {pair: {matcher_name: {} for matcher_name in matchers} for pair in model_pairs}\n",
    "\n",
    "for matcher_name in matchers:\n",
    "    for pair in model_pairs:\n",
    "        test_loss_curves_pair = np.array([test_loss_curves[seed][pair][matcher_name] for seed in all_seeds])\n",
    "        test_acc_curves_pair = np.array([test_acc_curves[seed][pair][matcher_name] for seed in all_seeds])\n",
    "\n",
    "        train_loss_curves_pair = np.array([train_loss_curves[seed][pair][matcher_name] for seed in all_seeds])\n",
    "        train_acc_curves_pair = np.array([train_acc_curves[seed][pair][matcher_name] for seed in all_seeds])\n",
    "\n",
    "        test_loss_curves_mean = np.mean(test_loss_curves_pair, axis=0)\n",
    "        test_acc_curves_mean = np.mean(test_acc_curves_pair, axis=0)\n",
    "\n",
    "        train_loss_curves_mean = np.mean(train_loss_curves_pair, axis=0)\n",
    "        train_acc_curves_mean = np.mean(train_acc_curves_pair, axis=0)\n",
    "\n",
    "        results_averaged_over_seeds[pair][matcher_name] = {\n",
    "            \"test_loss_curve\": test_loss_curves_mean,\n",
    "            \"test_acc_curve\": test_acc_curves_mean,\n",
    "            \"train_loss_curve\": train_loss_curves_mean,\n",
    "            \"train_acc_curve\": train_acc_curves_mean,\n",
    "        }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot loss and acc interp curves with insets "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# seed = 1\n",
    "lambdas = np.linspace(0, 1, 10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset\n",
    "\n",
    "# Define the region for zoom-in\n",
    "x_zoom_min, x_zoom_max = 0.5, 0.6\n",
    "y_zoom_min, y_zoom_max = 0.08, 0.12\n",
    "\n",
    "colors = {\"git_rebasin\": palette[\"light red\"], \"frank_wolfe\": palette[\"green\"], \"naive\": palette[\"dark blue\"]}\n",
    "\n",
    "for pair in model_pairs:\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.set_title(f\"Matching algorithms over seeds {pair[0]} and {pair[1]}\")\n",
    "    ax.set_xlabel(r\"$ \\alpha $\")\n",
    "    ax.set_ylabel(\"Loss\")\n",
    "\n",
    "    # bounds (x0, y0, width, height)\n",
    "    # ax_inset = ax.inset_axes([0.1, 0.7, 0.2, 0.2], xticklabels=[], yticklabels=[])\n",
    "    # ax_inset.set_xlim(x_zoom_min, x_zoom_max)\n",
    "    # ax_inset.set_ylim(y_zoom_min, y_zoom_max)\n",
    "    # ax_inset.set_title(\"\")\n",
    "\n",
    "    for matcher_name in matchers:\n",
    "        test_loss_curve = results_averaged_over_seeds[pair][matcher_name][\"test_loss_curve\"]\n",
    "        if len(test_loss_curve) == 0:\n",
    "            continue\n",
    "\n",
    "        ax.plot(lambdas, test_loss_curve, label=matcher_name, color=colors[matcher_name])\n",
    "        # ax_inset.plot(lambdas, test_loss_curve, label=matcher_name)\n",
    "\n",
    "    # Use mark_inset to indicate the zoomed area instead of manually adding a rectangle\n",
    "    # mark_inset(ax, ax_inset, loc1=2, loc2=4, fc=\"none\", ec=\"0.5\")\n",
    "\n",
    "    # Draw the figure explicitly (this can help in some environments)\n",
    "\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset\n",
    "\n",
    "# Define the region for zoom-in\n",
    "x_zoom_min, x_zoom_max = 0.45, 0.55\n",
    "y_zoom_min, y_zoom_max = 0.95, 1\n",
    "\n",
    "for pair in model_pairs:\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.set_title(f\"Comparison of Different Matching Algorithms over seeds {pair[0]} and {pair[1]}\")\n",
    "    ax.set_xlabel(r\"$ \\alpha $\")\n",
    "    ax.set_ylabel(\"Loss\")\n",
    "\n",
    "    # bounds (x0, y0, width, height)\n",
    "    ax_inset = ax.inset_axes([0.7, 0.1, 0.2, 0.2], xticklabels=[], yticklabels=[])\n",
    "    ax_inset.set_xlim(x_zoom_min, x_zoom_max)\n",
    "    ax_inset.set_ylim(y_zoom_min, y_zoom_max)\n",
    "    ax_inset.set_title(\"\")\n",
    "\n",
    "    for matcher_name in matchers:\n",
    "        test_acc_curve = test_acc_curves[seed][pair][matcher_name]\n",
    "        if len(test_acc_curve) == 0:\n",
    "            continue\n",
    "\n",
    "        ax.plot(lambdas, test_acc_curve, label=matcher_name)\n",
    "        ax_inset.plot(lambdas, test_acc_curve, label=matcher_name)\n",
    "\n",
    "    # Use mark_inset to indicate the zoomed area instead of manually adding a rectangle\n",
    "    mark_inset(ax, ax_inset, loc1=2, loc2=4, fc=\"none\", ec=\"0.5\")\n",
    "\n",
    "    # Draw the figure explicitly (this can help in some environments)\n",
    "\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hic sunt leones"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert dictionary to DataFrame\n",
    "records = []\n",
    "for algo, algo_data in exps.items():\n",
    "    for pair, pair_data in algo_data.items():\n",
    "        for seed, metrics in pair_data.items():\n",
    "            if metrics:  # Check if metrics are not empty\n",
    "                record = {\n",
    "                    \"algorithm\": algo,\n",
    "                    \"pair_seed\": f\"Pair {pair} - seed {seed}\",  # Combining pair and seed\n",
    "                    \"accuracy\": metrics[\"acc\"],\n",
    "                    \"loss\": metrics[\"loss\"],\n",
    "                }\n",
    "                records.append(record)\n",
    "\n",
    "df = pd.DataFrame(records)\n",
    "\n",
    "# Plotting\n",
    "# Accuracy Bar Plot\n",
    "fig_acc = px.bar(\n",
    "    df, x=\"pair_seed\", y=\"accuracy\", color=\"algorithm\", barmode=\"group\", title=\"Accuracy by Algorithm, Pair, and seed\"\n",
    ")\n",
    "fig_acc.show()\n",
    "\n",
    "# Loss Bar Plot\n",
    "fig_loss = px.bar(\n",
    "    df, x=\"pair_seed\", y=\"loss\", color=\"algorithm\", barmode=\"group\", title=\"Loss by Algorithm, Pair, and seed\"\n",
    ")\n",
    "fig_loss.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for matcher in matcher_mapping.values():\n",
    "    mean_approach = df[df[\"algorithm\"] == matcher][\"accuracy\"].mean()\n",
    "    var_approach = df[df[\"algorithm\"] == matcher][\"accuracy\"].var()\n",
    "\n",
    "    print(f\"{matcher} diff: mean {mean_approach}, var {var_approach}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert dictionary to DataFrame\n",
    "records = []\n",
    "for algo, algo_data in exps.items():\n",
    "    for pair, pair_data in algo_data.items():\n",
    "        for seed, metrics in pair_data.items():\n",
    "            if metrics:  # Check if metrics are not empty\n",
    "                record = {\n",
    "                    \"algorithm\": algo,\n",
    "                    \"pair\": f\"{pair[0]}-{pair[1]}\",\n",
    "                    \"seed\": seed,\n",
    "                    \"accuracy\": metrics[\"acc\"],\n",
    "                    \"loss\": metrics[\"loss\"],\n",
    "                }\n",
    "                records.append(record)\n",
    "\n",
    "df = pd.DataFrame(records)\n",
    "\n",
    "# Sort by 'pair' and 'seed'\n",
    "df[\"sort_key\"] = df[\"pair\"] + \"-Seed\" + df[\"seed\"].astype(str)\n",
    "df.sort_values(by=\"sort_key\", inplace=True)\n",
    "\n",
    "# Pivot the DataFrame to calculate differences\n",
    "pivot_df = df.pivot_table(index=\"sort_key\", columns=\"algorithm\", values=\"accuracy\")\n",
    "\n",
    "pivot_df[\"accuracy_diff\"] = (\n",
    "    pivot_df.iloc[:, 0] - pivot_df.iloc[:, 1]\n",
    ")  # Assuming the first column is 'alternating_diffusion'\n",
    "\n",
    "pivot_df.reset_index(inplace=True)\n",
    "\n",
    "total_diff = pivot_df[\"accuracy_diff\"].mean()\n",
    "total_diff_row = pd.DataFrame(\n",
    "    [{\"sort_key\": \"Total\", \"accuracy_diff\": total_diff, \"color\": \"green\" if total_diff > 0 else \"red\"}]\n",
    ")\n",
    "\n",
    "# Concatenate the total difference row to the existing DataFrame\n",
    "pivot_df = pd.concat([pivot_df, total_diff_row], ignore_index=True)\n",
    "\n",
    "\n",
    "# Determine the color based on which algorithm performs better\n",
    "pivot_df[\"color\"] = pivot_df[\"accuracy_diff\"].apply(lambda x: \"green\" if x > 0 else \"red\")\n",
    "\n",
    "# Prepare data for plotting\n",
    "plot_data = pivot_df[[\"sort_key\", \"accuracy_diff\", \"color\"]]\n",
    "\n",
    "# # Manually insert space every 4 positions in x-axis labels\n",
    "# plot_data['x_label'] = plot_data['sort_key']\n",
    "# for i in range(3, len(plot_data['x_label']), 4):\n",
    "#     plot_data['x_label'].iloc[i] = ''  # Inserting empty string to create a gap\n",
    "\n",
    "# Plotting the differences\n",
    "fig = px.bar(plot_data, x=\"sort_key\", y=\"accuracy_diff\", color=\"color\", title=\"Performance Difference in Accuracy\")\n",
    "fig.update_xaxes(type=\"category\")  # Setting x-axis as category type\n",
    "fig.update_layout(\n",
    "    xaxis={\"categoryorder\": \"array\", \"categoryarray\": plot_data[\"sort_key\"]}\n",
    ")  # Explicitly setting the order\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "\n",
    "# Create the figure\n",
    "fig = go.Figure()\n",
    "\n",
    "# Track the current index for spacing\n",
    "current_index = 0\n",
    "\n",
    "# Iterate over the DataFrame to add bars individually\n",
    "for index, row in plot_data.iterrows():\n",
    "    # Add a bar for each data point\n",
    "    fig.add_trace(\n",
    "        go.Bar(x=[row[\"sort_key\"]], y=[row[\"accuracy_diff\"]], marker_color=row[\"color\"], name=row[\"sort_key\"])\n",
    "    )\n",
    "\n",
    "    # Increment index\n",
    "    current_index += 1\n",
    "\n",
    "    # Add an invisible bar (for spacing) every 4 bars\n",
    "    if current_index % 4 == 0:\n",
    "        fig.add_trace(\n",
    "            go.Bar(\n",
    "                x=[f\"Space-{current_index // 4}\"],\n",
    "                y=[None],\n",
    "                marker=dict(color=\"rgba(255, 255, 255, 0)\"),  # Invisible bar\n",
    "                showlegend=False,\n",
    "            )\n",
    "        )\n",
    "\n",
    "# Update the layout to adjust the bar width and space between bars\n",
    "fig.update_traces(marker_line_width=1.5, width=0.4)  # Adjust the bar width as needed\n",
    "fig.update_layout(\n",
    "    title=\"Performance Difference in Accuracy\",\n",
    "    xaxis_title=\"Pair-Seed\",\n",
    "    yaxis_title=\"Accuracy Difference\",\n",
    "    barmode=\"group\",\n",
    ")\n",
    "\n",
    "# Show the figure\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
