{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fcf65487-d61d-4039-8f19-863d6b1192c9",
   "metadata": {},
   "source": [
    "# Contains Code for Plots in the Report\n",
    "Contains simple plots for quick visualization of results during the semester.\n",
    "Run the first code cell to import the packages and define the shared plotting helpers. All other plots can then be run individually.\n",
    "As for parameters, the following most often holds:\\\n",
    "#### dataset: \n",
    "[mnist, fmnist, kmnist, qmnist, emnist_letters, emnist_balanced, cifar10, cifar100, imagenet32, synthetic, custom]\\\n",
    "'custom' refers to the scalable MNIST dataset with up to 67 classes.\\\n",
    "#### file: \n",
    "The file names of the most important experiments are already filled in. For other or earlier experiments, the experiments file lists all of them with a very short description. The Python files of all experiments are stored in the experiments folder.\\\n",
    "#### extensions:\n",
    "usually refers to different variants of an experiment. If the variants only differ in their seed, they are simply numbered from 0 on, but they can also be tau values, layer sizes, dropout rates, etc. \\\n",
    "#### label:\n",
    "refers to the label of the curve in the plot, e.g. 'Tau' leads to curves labelled 'Tau={ext}'.\n",
    "#### Else\n",
    "`ylim` is the lower and `ylim_max` the upper limit of the y-axis (curves are plotted in percent).\\\n",
    "Setting the `baseline` parameter to True draws the baseline for the current metric and dataset as a dashed red line.\n",
    "\n",
    "The most recent experiments store test and binary test metrics in their results files in `../results`. Older experiments, some of which are still used for report plots and figures, do not. I have rerun most of these models to recreate their test and binary test performance; those files are in a separate folder called `../results_temp`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d53dff6d-ce80-431f-9546-841131a93281",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Import packages and initialize simple plotting functions\n",
    "\n",
    "import json\n",
    "import argparse\n",
    "import numpy as np\n",
    "from utils import *\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib.colors as mcolors\n",
    "from matplotlib.colors import LogNorm\n",
    "import pandas as pd\n",
    "\n",
    "def get_baseline(dataset, metric, n_runs=3, results_dir=\"../results\"):\n",
    "    \"\"\"Return the mean baseline value of `metric` for `dataset`, multiplied by 100.\n",
    "\n",
    "    Reads `{results_dir}/{dataset}_baseline_final_{i}.json` for i in range(n_runs)\n",
    "    and averages the entry stored under the key `metric + \"_\"`.\n",
    "\n",
    "    Runs whose file cannot be loaded or that lack the key are reported and skipped.\n",
    "    Returns None if no run provides the value, so callers can skip the baseline line.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the trailing underscore selects a separately stored variant of the\n",
    "    # metric in the baseline files (presumably a final/best scalar) -- TODO confirm.\n",
    "    key = metric + \"_\"\n",
    "    values = []\n",
    "    for i in range(n_runs):\n",
    "        path = f\"{results_dir}/{dataset}_baseline_final_{i}.json\"\n",
    "        try:\n",
    "            values.append(load_json(path)[key])\n",
    "        except (OSError, KeyError, ValueError) as err:\n",
    "            # Assumes load_json raises OSError for missing files and ValueError for bad JSON.\n",
    "            print(f\"Warning: no baseline '{key}' in {path}: {err!r}\")\n",
    "    if not values:\n",
    "        return None\n",
    "    return 100 * sum(values) / len(values)\n",
    "\n",
    "def plot_metrics(\n",
    "    dataset,\n",
    "    file,\n",
    "    extensions,\n",
    "    label,\n",
    "    metrics = [\"train_acc_train_mode\", \"train_acc_eval_mode\", \"valid_acc_train_mode\", \"valid_acc_eval_mode\", \"bin_acc_train_mode\", \"bin_acc_eval_mode\",\n",
    "              \"test_acc_eval_mode\",\n",
    "              \"test_bin_acc_eval_mode\"],\n",
    "    # One title per metric above, in the same order.\n",
    "    titles = [\"Training Accuracy (train mode)\", \"Training Accuracy (eval mode)\", \"Validation Accuracy (train mode)\", \"Validation Accuracy (eval mode)\",\n",
    "             \"Binary Val Accuracy (train mode)\", \"Binary Val Accuracy (eval mode)\", \"Test Accuracy (eval mode)\", \"Binary Test Accuracy (eval mode)\"],\n",
    "    ylabels = [\"Accuracy (%)\"] * 8,\n",
    "    ylim=0.0,\n",
    "    ylim_max=1.0,\n",
    "    baseline=True\n",
    "):\n",
    "    \"\"\"Plot training curves of several experiment variants, one panel per metric.\n",
    "\n",
    "    For every `ext` in `extensions`, loads `../results/{dataset}{file}{ext}.json` and\n",
    "    draws each metric (multiplied by 100, i.e. in percent) against the epoch.\n",
    "\n",
    "    Args:\n",
    "        dataset: dataset prefix of the result files, e.g. \"mnist\".\n",
    "        file: experiment name between dataset and extension, e.g. \"_different_tau_final_\".\n",
    "        extensions: variant suffixes (seeds, tau values, ...); one curve each.\n",
    "        label: legend prefix; curves are labelled f\"{label}={ext}\".\n",
    "        metrics, titles, ylabels: parallel lists, one entry per panel.\n",
    "        ylim, ylim_max: y-axis limits in the same (percent) units as the curves.\n",
    "            The defaults (0.0, 1.0) are too small for percent curves; pass e.g. 90 and 100.\n",
    "        baseline: if truthy, draw the dataset baseline of each metric (via get_baseline).\n",
    "    \"\"\"\n",
    "    y_range = [ylim, ylim_max]\n",
    "    # Two panels per row; 8 metrics reproduce the 4x2 layout with figsize (14, 18).\n",
    "    n_rows = max(1, (len(metrics) + 1) // 2)\n",
    "    fig, axs = plt.subplots(n_rows, 2, figsize=(14, 4.5 * n_rows))\n",
    "    axs = axs.flatten()  # Flatten to loop easily\n",
    "\n",
    "    # One colour per variant, sampled evenly from viridis\n",
    "    # (plt.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9).\n",
    "    colors = plt.get_cmap('viridis', max(1, len(extensions)))\n",
    "\n",
    "    for ax, metric, title, ylabel in zip(axs, metrics, titles, ylabels):\n",
    "        # Separate name so the `baseline` flag is never overwritten by the value.\n",
    "        if baseline:\n",
    "            baseline_value = get_baseline(dataset, metric)\n",
    "            if baseline_value is not None:\n",
    "                ax.axhline(baseline_value, color='red', linestyle='--', label=f'Baseline: {baseline_value:.2f}')\n",
    "\n",
    "        max_len = 0  # longest curve in this panel, used for the x ticks\n",
    "        for i_idx, ext in enumerate(extensions):\n",
    "            path = f\"../results/{dataset}{file}{ext}.json\"\n",
    "            try:\n",
    "                json_data = load_json(path)\n",
    "            except (OSError, ValueError) as err:\n",
    "                # Report missing/unreadable runs instead of silently dropping them.\n",
    "                print(f\"Warning: could not load {path}: {err!r}\")\n",
    "                continue\n",
    "            data = json_data.get(metric, [])\n",
    "            max_len = max(max_len, len(data))\n",
    "            # Consecutive entries are drawn 4 epochs apart (presumably the logging interval -- TODO confirm).\n",
    "            ax.plot([4 * k for k in range(len(data))], [100 * d for d in data], label=f\"{label}={ext}\", color=colors(i_idx))\n",
    "\n",
    "        ax.set_xlabel(\"Epoch\", fontsize=24)\n",
    "        ax.set_ylabel(ylabel, fontsize=24)\n",
    "        ax.set_ylim(y_range)\n",
    "        ax.set_title(title)\n",
    "        ax.set_xticks(np.arange(0, 4 * max_len + 1, 50))\n",
    "        ax.set_yticks(np.arange(y_range[0], y_range[1] + 1, 5))\n",
    "        ax.tick_params(axis='both', which='major', labelsize=18)\n",
    "        ax.grid(True, linestyle='--', alpha=0.5)\n",
    "        ax.legend(loc='lower right', fontsize=14)  # Change position if needed\n",
    "\n",
    "    # Hide the unused panel when the number of metrics is odd.\n",
    "    for ax in axs[len(metrics):]:\n",
    "        ax.set_visible(False)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfd1daae-5ff3-40f8-ad11-5bb832fdbcac",
   "metadata": {},
   "source": [
    "### Different Tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf7e1e05-857d-47eb-bd14-5806074cc127",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Test models with varying tau\n",
    "# For earlier versions, different tau values are used, e.g. _different_tau_02_ uses [1, 5, 10, 20, 50, 80, 100, 200]\n",
    "tau_extensions = [1, 3, 10, 30, 100]\n",
    "plot_metrics(\n",
    "    \"mnist\",                  # dataset\n",
    "    \"_different_tau_final_\",  # file\n",
    "    tau_extensions,           # extensions\n",
    "    \"Tau\",                    # label\n",
    "    ylim=90,\n",
    "    ylim_max=100,\n",
    "    baseline=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58027c1c-5a63-421a-99de-67f2ff86094f",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Different-tau plot for a single dataset (made for FMNIST in the report)\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "extensions=[1, 5, 10, 20, 50, 200]\n",
    "# extensions=[1, 3, 10, 30, 100, 300]\n",
    "# extensions = [0.1, 1, 5, 10, 20, 50, 200]\n",
    "label=r\"$\\tau$\"\n",
    "\n",
    "metric=\"valid_acc_eval_mode\"\n",
    "\n",
    "dataset=\"cifar10\"\n",
    "file=\"_different_tau_03_\"\n",
    "# Derive the output name from `dataset` so a figure is never saved under another dataset's name.\n",
    "out_pdf = f\"{dataset}_different_tau.pdf\"\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 5))\n",
    "\n",
    "# Unique color for each extension (plt.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9)\n",
    "colors = plt.get_cmap('viridis', len(extensions))\n",
    "\n",
    "max_len = 0  # longest curve, used for the x ticks\n",
    "for i_idx, ext in enumerate(extensions):\n",
    "    path = f\"../results/{dataset}{file}{ext}.json\"\n",
    "    try:\n",
    "        json_data = load_json(path)\n",
    "    except (OSError, ValueError) as err:\n",
    "        # Report missing/unreadable runs instead of silently dropping them.\n",
    "        print(f\"Warning: could not load {path}: {err!r}\")\n",
    "        continue\n",
    "    data = json_data.get(metric, [])\n",
    "    max_len = max(max_len, len(data))\n",
    "    ax.plot(\n",
    "        range(len(data)),\n",
    "        [(d * 100) for d in data],\n",
    "        label=f\"{label}={ext}\",\n",
    "        color=colors(i_idx)\n",
    "    )\n",
    "\n",
    "ax.set_xlabel(\"Epoch\", fontsize=22)\n",
    "ax.set_ylabel(\"Accuracy (%)\", fontsize=22)\n",
    "ax.set_ylim(75, 95)\n",
    "ax.set_xticks(np.arange(0, max_len + 1, 50))\n",
    "ax.set_yticks(np.arange(75, 95+1, 5))\n",
    "ax.tick_params(axis='both', which='major', labelsize=17)\n",
    "ax.grid(True, linestyle='--', alpha=0.5)\n",
    "ax.legend(loc='lower right', fontsize=14)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(out_pdf)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70480170-1e23-4855-8391-f8b5a7575ca4",
   "metadata": {},
   "source": [
    "### Augmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23f7b1ba-8078-4bcb-a89e-827bb6fdca9f",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Augmentation vs No Augmentation\n",
    "# Only run with validation dataset\n",
    "import os\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Result-file location, compared metric and number of runs per configuration\n",
    "RESULTS_DIR, METRIC, NUM_RUNS = \"../results\", \"valid_acc_eval_mode\", 5\n",
    "\n",
    "DATASETS = \"mnist fmnist kmnist emnist_balanced emnist_letters qmnist\".split()\n",
    "\n",
    "# Result-file tag -> legend label\n",
    "BASELINES = dict(\n",
    "    baseline_01=\"No Aug\",\n",
    "    baseline_04=\"With Aug\",\n",
    ")\n",
    "\n",
    "def load_average_best_accuracy(dataset, baseline):\n",
    "    \"\"\"Mean, over NUM_RUNS runs, of the highest METRIC value reached in each run.\n",
    "\n",
    "    Reads RESULTS_DIR/{dataset}_{baseline}_{run}.json; files that do not exist are\n",
    "    reported and skipped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a file lacks METRIC, or if not a single run file exists.\n",
    "    \"\"\"\n",
    "    peaks = []\n",
    "    for run_idx in range(NUM_RUNS):\n",
    "        name = f\"{dataset}_{baseline}_{run_idx}.json\"\n",
    "        path = os.path.join(RESULTS_DIR, name)\n",
    "        if os.path.isfile(path):\n",
    "            with open(path, \"r\") as fh:\n",
    "                run_data = json.load(fh)\n",
    "            if METRIC not in run_data:\n",
    "                raise ValueError(f\"{METRIC} not found in {name}\")\n",
    "            peaks.append(max(run_data[METRIC]))\n",
    "        else:\n",
    "            print(f\"Warning: {name} not found. Skipping.\")\n",
    "\n",
    "    if len(peaks) == 0:\n",
    "        raise ValueError(f\"No valid runs found for {dataset} {baseline}\")\n",
    "\n",
    "    return sum(peaks) / len(peaks)\n",
    "\n",
    "def plot_comparison(results):\n",
    "    \"\"\"Grouped bar chart of mean best validation accuracy without vs. with augmentation.\n",
    "\n",
    "    Args:\n",
    "        results: {dataset: {\"baseline_01\": acc, \"baseline_04\": acc}} with accuracies\n",
    "            as fractions (multiplied by 100 for display).\n",
    "\n",
    "    Only datasets that have a short display name below are drawn (qmnist has none).\n",
    "    Saves the figure to augmentation.pdf.\n",
    "    \"\"\"\n",
    "    display_names = {\n",
    "        \"mnist\": \"Digits\",\n",
    "        \"fmnist\": \"Fashion\",\n",
    "        \"kmnist\": \"Kuzushiji\",\n",
    "        \"emnist_balanced\": \"EMNIST-Bal\",\n",
    "        \"emnist_letters\": \"EMNIST-Let\",\n",
    "    }\n",
    "    # Select datasets by name (not by position) so tick labels always match their bars,\n",
    "    # even if DATASETS is reordered or extended.\n",
    "    datasets = [ds for ds in results if ds in display_names]\n",
    "    x = np.arange(len(datasets))\n",
    "    bar_width = 0.2\n",
    "\n",
    "    acc_no_aug = [100 * results[ds][\"baseline_01\"] for ds in datasets]\n",
    "    acc_aug = [100 * results[ds][\"baseline_04\"] for ds in datasets]\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.bar(x - bar_width/2, acc_no_aug, width=bar_width, label=\"No Aug\", color=\"skyblue\")\n",
    "    ax.bar(x + bar_width/2, acc_aug, width=bar_width, label=\"With Aug\", color=\"orange\")\n",
    "\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels([display_names[ds] for ds in datasets], rotation=45, ha='right', fontsize=18)\n",
    "    ax.set_ylabel(\"Accuracy (%)\", fontsize=22)\n",
    "    ax.set_ylim(80, 100)\n",
    "    ax.set_yticks(np.arange(80, 100+1, 5))\n",
    "    ax.tick_params(axis='y', labelsize=18)\n",
    "    ax.legend(loc='lower right')\n",
    "    ax.grid(True, axis='y', linestyle='--', alpha=0.5)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(\"augmentation.pdf\")\n",
    "    plt.show()\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # {dataset: {config: mean best accuracy}}\n",
    "    results = {}\n",
    "    for dataset in DATASETS:\n",
    "        results[dataset] = per_config = {}\n",
    "        for baseline in BASELINES:\n",
    "            per_config[baseline] = load_average_best_accuracy(dataset, baseline)\n",
    "\n",
    "    plot_comparison(results)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfa21b8e-988e-4188-9b22-5c75ded85477",
   "metadata": {},
   "source": [
    "### Dropout vs. No Dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe85fba6-7079-40fc-85be-86ada2eb5e25",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Dropout vs No Dropout\n",
    "import os\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "RESULTS_DIR = \"../results\"\n",
    "# Reported value: METRIC taken at the epoch where EVAL_METRIC peaks\n",
    "METRIC, EVAL_METRIC = \"test_bin_acc_eval_mode\", \"bin_acc_eval_mode\"\n",
    "NUM_RUNS = 3  # runs per multi-seed config (files suffixed _0 .. _{NUM_RUNS-1})\n",
    "\n",
    "DATASETS = \"mnist fmnist kmnist emnist_balanced emnist_letters qmnist\".split()\n",
    "\n",
    "# Multi-seed configs, averaged over NUM_RUNS files: result-file tag -> legend label\n",
    "BASELINES = {\"baseline_final\": \"No Dropout\"}\n",
    "\n",
    "# Single-run configs: result-file tag -> legend label\n",
    "DROPOUTS = {\n",
    "    f\"groupsum_dropout_final_{rate}\": f\"Dropout {rate}\"\n",
    "    for rate in (\"0.1\", \"0.3\", \"0.5\", \"0.7\", \"0.9\")\n",
    "}\n",
    "\n",
    "ALL_CONFIGS = dict(BASELINES)\n",
    "ALL_CONFIGS.update(DROPOUTS)\n",
    "\n",
    "def load_baseline_average(dataset, config):\n",
    "    \"\"\"Mean over NUM_RUNS runs of METRIC taken at the epoch where EVAL_METRIC peaks.\n",
    "\n",
    "    Reads RESULTS_DIR/{dataset}_{config}_{run}.json for run in range(NUM_RUNS);\n",
    "    missing files are reported and skipped.\n",
    "\n",
    "    Returns:\n",
    "        The mean selected value, or None if no run file exists.\n",
    "    Raises:\n",
    "        ValueError: if a file lacks METRIC or EVAL_METRIC, or EVAL_METRIC is empty.\n",
    "    \"\"\"\n",
    "    best_vals = []\n",
    "    for run in range(NUM_RUNS):\n",
    "        filename = f\"{dataset}_{config}_{run}.json\"\n",
    "        filepath = os.path.join(RESULTS_DIR, filename)\n",
    "        if not os.path.isfile(filepath):\n",
    "            print(f\"Warning: {filename} not found. Skipping.\")\n",
    "            continue\n",
    "        with open(filepath, \"r\") as f:\n",
    "            data = json.load(f)\n",
    "        # Validate both keys up front; otherwise a missing EVAL_METRIC surfaces as a bare KeyError.\n",
    "        for key in (METRIC, EVAL_METRIC):\n",
    "            if key not in data:\n",
    "                raise ValueError(f\"{key} not found in {filename}\")\n",
    "        if not data[EVAL_METRIC]:\n",
    "            raise ValueError(f\"{EVAL_METRIC} is empty in {filename}\")\n",
    "        # list.index returns the earliest epoch among ties for the maximum.\n",
    "        best_epoch = data[EVAL_METRIC].index(max(data[EVAL_METRIC]))\n",
    "        best_vals.append(data[METRIC][best_epoch])\n",
    "    if not best_vals:\n",
    "        return None\n",
    "    return sum(best_vals) / len(best_vals)\n",
    "\n",
    "def load_single_dropout(dataset, config):\n",
    "    \"\"\"METRIC of the single run RESULTS_DIR/{dataset}_{config}.json at the epoch where EVAL_METRIC peaks.\n",
    "\n",
    "    Returns None (after printing a warning) if the file does not exist.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the file lacks METRIC or EVAL_METRIC, or EVAL_METRIC is empty.\n",
    "    \"\"\"\n",
    "    filename = f\"{dataset}_{config}.json\"\n",
    "    filepath = os.path.join(RESULTS_DIR, filename)\n",
    "    if not os.path.isfile(filepath):\n",
    "        print(f\"Warning: {filename} not found.\")\n",
    "        return None\n",
    "    with open(filepath, \"r\") as f:\n",
    "        data = json.load(f)\n",
    "    # Validate both keys up front; otherwise a missing EVAL_METRIC surfaces as a bare KeyError.\n",
    "    for key in (METRIC, EVAL_METRIC):\n",
    "        if key not in data:\n",
    "            raise ValueError(f\"{key} not found in {filename}\")\n",
    "    if not data[EVAL_METRIC]:\n",
    "        raise ValueError(f\"{EVAL_METRIC} is empty in {filename}\")\n",
    "    # Earliest epoch with the highest EVAL_METRIC\n",
    "    best_epoch = data[EVAL_METRIC].index(max(data[EVAL_METRIC]))\n",
    "    return data[METRIC][best_epoch]\n",
    "\n",
    "def plot_comparison(results):\n",
    "    \"\"\"Draw one group of bars per dataset, one bar per entry of ALL_CONFIGS.\n",
    "\n",
    "    `results` maps dataset -> {config: value}; a config without a value is drawn\n",
    "    with height 0.\n",
    "    \"\"\"\n",
    "    dataset_names = list(results)\n",
    "    config_keys = list(ALL_CONFIGS)\n",
    "    group_size = len(config_keys)\n",
    "    width = 0.1\n",
    "    centers = np.arange(len(dataset_names))\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "    for pos, key in enumerate(config_keys):\n",
    "        heights = [results[name].get(key, 0) for name in dataset_names]\n",
    "        # Shift each config so the whole group is centred on the dataset tick\n",
    "        offsets = centers + (pos - group_size / 2) * width + width / 2\n",
    "        ax.bar(offsets, heights, width=width, label=ALL_CONFIGS[key])\n",
    "\n",
    "    ax.set_xticks(centers)\n",
    "    ax.set_xticklabels(dataset_names, rotation=20)\n",
    "    ax.set_ylabel(\"Test Accuracy\")\n",
    "    ax.set_title(\"Baseline vs Dropout Comparison per Dataset\")\n",
    "    ax.set_ylim(0.8, 1)\n",
    "    ax.legend(loc='lower right')\n",
    "    ax.grid(True, axis=\"y\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    results = {}\n",
    "\n",
    "    for dataset in DATASETS:\n",
    "        results[dataset] = found = {}\n",
    "        for config in ALL_CONFIGS:\n",
    "            # Multi-seed configs are averaged; the rest are single runs\n",
    "            loader = load_baseline_average if config in BASELINES else load_single_dropout\n",
    "            val = loader(dataset, config)\n",
    "            if val is not None:\n",
    "                found[config] = val\n",
    "\n",
    "    plot_comparison(results)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d6f3fa0-e691-424b-9c1d-d4aa72e946dd",
   "metadata": {},
   "source": [
    "### Dataset - Different Tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3ed88e4-ea83-4a85-856a-b4d8a2f5aff4",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Different Tau for MNIST datasets\n",
    "import os\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "RESULTS_DIR = \"../results\"\n",
    "# Reported value: METRIC taken at the epoch where EVAL_METRIC peaks\n",
    "METRIC, EVAL_METRIC = \"test_bin_acc_eval_mode\", \"bin_acc_eval_mode\"\n",
    "NUM_RUNS = 3  # runs per multi-seed config (files suffixed _0 .. _{NUM_RUNS-1})\n",
    "\n",
    "DATASETS = \"mnist fmnist kmnist emnist_balanced emnist_letters qmnist\".split()\n",
    "\n",
    "# No multi-seed configs in this comparison (\"baseline_final\" could be added, e.g. as \"Tau 10\")\n",
    "BASELINES = {}\n",
    "\n",
    "# Single-run configs: result-file tag -> legend label\n",
    "TAUS = {f\"different_tau_final_{t}\": f\"Tau {t}\" for t in (1, 3, 10, 30, 100)}\n",
    "\n",
    "ALL_CONFIGS = dict(BASELINES)\n",
    "ALL_CONFIGS.update(TAUS)\n",
    "\n",
    "def load_baseline_average(dataset, config):\n",
    "    \"\"\"Mean over NUM_RUNS runs of METRIC taken at the epoch where EVAL_METRIC peaks.\n",
    "\n",
    "    Reads RESULTS_DIR/{dataset}_{config}_{run}.json for run in range(NUM_RUNS);\n",
    "    missing files are reported and skipped.\n",
    "\n",
    "    Returns:\n",
    "        The mean selected value, or None if no run file exists.\n",
    "    Raises:\n",
    "        ValueError: if a file lacks METRIC or EVAL_METRIC, or EVAL_METRIC is empty.\n",
    "    \"\"\"\n",
    "    best_vals = []\n",
    "    for run in range(NUM_RUNS):\n",
    "        filename = f\"{dataset}_{config}_{run}.json\"\n",
    "        filepath = os.path.join(RESULTS_DIR, filename)\n",
    "        if not os.path.isfile(filepath):\n",
    "            print(f\"Warning: {filename} not found. Skipping.\")\n",
    "            continue\n",
    "        with open(filepath, \"r\") as f:\n",
    "            data = json.load(f)\n",
    "        # Validate both keys up front; otherwise a missing EVAL_METRIC surfaces as a bare KeyError.\n",
    "        for key in (METRIC, EVAL_METRIC):\n",
    "            if key not in data:\n",
    "                raise ValueError(f\"{key} not found in {filename}\")\n",
    "        if not data[EVAL_METRIC]:\n",
    "            raise ValueError(f\"{EVAL_METRIC} is empty in {filename}\")\n",
    "        # list.index returns the earliest epoch among ties for the maximum.\n",
    "        best_epoch = data[EVAL_METRIC].index(max(data[EVAL_METRIC]))\n",
    "        best_vals.append(data[METRIC][best_epoch])\n",
    "\n",
    "    if not best_vals:\n",
    "        return None\n",
    "    return sum(best_vals) / len(best_vals)\n",
    "\n",
    "def load_single_dropout(dataset, config):\n",
    "    \"\"\"METRIC of the single run RESULTS_DIR/{dataset}_{config}.json at the epoch where EVAL_METRIC peaks.\n",
    "\n",
    "    Despite its name, this loader is used here for the single-run tau configs.\n",
    "    Returns None (after printing a warning) if the file does not exist.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the file lacks METRIC or EVAL_METRIC, or EVAL_METRIC is empty.\n",
    "    \"\"\"\n",
    "    filename = f\"{dataset}_{config}.json\"\n",
    "    filepath = os.path.join(RESULTS_DIR, filename)\n",
    "    if not os.path.isfile(filepath):\n",
    "        print(f\"Warning: {filename} not found.\")\n",
    "        return None\n",
    "    with open(filepath, \"r\") as f:\n",
    "        data = json.load(f)\n",
    "    # Validate both keys up front; otherwise a missing EVAL_METRIC surfaces as a bare KeyError.\n",
    "    for key in (METRIC, EVAL_METRIC):\n",
    "        if key not in data:\n",
    "            raise ValueError(f\"{key} not found in {filename}\")\n",
    "    if not data[EVAL_METRIC]:\n",
    "        raise ValueError(f\"{EVAL_METRIC} is empty in {filename}\")\n",
    "    # Earliest epoch with the highest EVAL_METRIC\n",
    "    best_epoch = data[EVAL_METRIC].index(max(data[EVAL_METRIC]))\n",
    "    return data[METRIC][best_epoch]\n",
    "\n",
    "def plot_comparison(results):\n",
    "    \"\"\"Grouped bar chart per dataset, one bar per tau config in ALL_CONFIGS.\n",
    "\n",
    "    Args:\n",
    "        results: {dataset: {config: value}}; configs missing for a dataset are\n",
    "            drawn with height 0.\n",
    "    \"\"\"\n",
    "    datasets = list(results.keys())\n",
    "    configs = list(ALL_CONFIGS.keys())\n",
    "    n_datasets = len(datasets)\n",
    "    n_configs = len(configs)\n",
    "\n",
    "    bar_width = 0.1\n",
    "    x = np.arange(n_datasets)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "    for i, config in enumerate(configs):\n",
    "        accs = [results[ds].get(config, 0) for ds in datasets]\n",
    "        # Centre the group of n_configs bars on each dataset tick\n",
    "        bar_pos = x + (i - n_configs / 2) * bar_width + bar_width / 2\n",
    "        ax.bar(bar_pos, accs, width=bar_width, label=ALL_CONFIGS[config])\n",
    "\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(datasets, rotation=20)\n",
    "    # The bars show METRIC (a test metric) at the best EVAL_METRIC epoch,\n",
    "    # not the best validation accuracy itself.\n",
    "    ax.set_ylabel(\"Test Accuracy\")\n",
    "    ax.set_title(\"Different Tau Comparison per Dataset\")\n",
    "    ax.set_ylim(0.6, 1)\n",
    "    ax.legend(loc='lower right')\n",
    "    ax.grid(True, axis=\"y\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    results = {}\n",
    "\n",
    "    for dataset in DATASETS:\n",
    "        results[dataset] = found = {}\n",
    "        for config in ALL_CONFIGS:\n",
    "            # Multi-seed configs are averaged; the rest are single runs\n",
    "            loader = load_baseline_average if config in BASELINES else load_single_dropout\n",
    "            val = loader(dataset, config)\n",
    "            if val is not None:\n",
    "                found[config] = val\n",
    "\n",
    "    plot_comparison(results)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca1a3995-62af-4e30-a828-fc65cd830ae3",
   "metadata": {},
   "source": [
    "### Per-Neuron Impact for different Dropout & Tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37403900-5c5b-499d-aecb-b74a2d322ef1",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Per-Neuron Impact for different dropout\n",
    "# Accuracy as output neurons are masked (ordered by `score`), for models trained with different dropout rates.\n",
    "\n",
    "# ['f1_score', 'random_score', 'precision', 'recall', 'sensitivity', 'specificity', 'accuracy', 'mcc']\n",
    "score = 'random_score'\n",
    "\n",
    "ds = 'fmnist'\n",
    "\n",
    "dropouts = [0.1, 0.3, 0.5, 0.7, 0.9]\n",
    "metrics = [\"test_accuracies_train\", \"test_accuracies_eval\", \"val_accuracies_train\", \"val_accuracies_eval\"]\n",
    "titles = [\"Test Accuracy (train mode)\", \"Test Accuracy (eval mode)\", \"Validation Accuracy (train mode)\", \"Validation Accuracy (eval mode)\"]\n",
    "ylabels = [\"Accuracy\", \"Accuracy\", \"Accuracy\", \"Accuracy\"]\n",
    "\n",
    "# Baseline metric drawn as reference in each panel (parallel to `metrics`)\n",
    "ds_metrics = [\"valid_acc_train_mode\", \"valid_acc_eval_mode\", \"valid_acc_train_mode\", \"valid_acc_eval_mode\"]\n",
    "\n",
    "y_range = [0.0, 1]  # curves in this figure are fractions, not percent\n",
    "n_neurons = 6400    # x-axis extent; presumably the number of maskable neurons -- TODO confirm\n",
    "\n",
    "# Reference \"Approx. Pruning\" curve: x at 0, 5, ..., 95 % of n_neurons; y converted from percent to fractions\n",
    "prune_x = (n_neurons / 100) * np.arange(0, 100, 5)\n",
    "prune_y = 0.01 * np.array([0.0, 0.8763020833333334, 1.8364583333333333, 2.8979166666666667, 4.039322916666666, 5.28046875, 6.61953125, 8.127083333333333, 9.761979166666666, 11.597916666666666, 13.59453125, 15.870052083333333, 18.434114583333333, 21.391666666666666, 24.85703125, 29.007552083333334, 34.17213541666667, 40.723697916666666, 50.251041666666666, 66.04192708333333])\n",
    "\n",
    "fig, axs = plt.subplots(2, 2, figsize=(14, 10))\n",
    "axs = axs.flatten()  # Flatten to loop easily\n",
    "\n",
    "# One colour per dropout rate (plt.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9)\n",
    "colors = plt.get_cmap('tab20', len(dropouts))\n",
    "\n",
    "for ax, metric, title, ylabel, ds_metric in zip(axs, metrics, titles, ylabels, ds_metrics):\n",
    "    baseline = get_baseline(ds, ds_metric)\n",
    "    if baseline is not None:\n",
    "        # get_baseline returns the value multiplied by 100, while this figure plots fractions\n",
    "        ax.axhline(baseline / 100, color='red', linestyle='--', label=f'Baseline: {baseline:.2f}')\n",
    "\n",
    "    for i, dp in enumerate(dropouts):\n",
    "        json_data = load_json(f\"../analysis/output_masking/{ds}_output_masking_02_{dp}_{score}.json\")\n",
    "        data = json_data.get(metric, [])\n",
    "        n_values = json_data.get('n_values', [])\n",
    "        ax.plot(n_values, data, label=f\"Dropout: {dp}\", color=colors(i))\n",
    "\n",
    "    ax.plot(prune_x, prune_y, color=\"red\", label=\"Approx. Pruning\")\n",
    "    ax.set_title(title)\n",
    "    # x values are `n_values` (presumably the number of masked neurons), not epochs\n",
    "    ax.set_xlabel(\"# Neurons\")\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_ylim(y_range)\n",
    "    ax.set_xlim(0, n_neurons)\n",
    "    ax.grid(True)\n",
    "    ax.legend(fontsize='small', loc='lower right')  # Change position if needed\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4254ea7e-1ecd-416f-ae0c-1c9cacd980b4",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Per-Neuron Impact for different tau\n",
    "# Validation accuracy (eval mode, %) as output neurons are masked, for models trained with different tau.\n",
    "\n",
    "# ['f1_score', 'random_score', 'precision', 'recall', 'sensitivity', 'specificity', 'accuracy', 'mcc']\n",
    "score = 'random_score'\n",
    "\n",
    "ds = 'mnist'\n",
    "\n",
    "taus = [0.05, 0.5, 3, 8, 20, 50, 100, 200]\n",
    "# taus = [10, 20, 50, 80, 100, 200]\n",
    "metric = \"val_accuracies_eval\"\n",
    "ylabel = \"Accuracy/Pruning (%)\"\n",
    "n_neurons = 6400  # x-axis extent; presumably the number of maskable neurons -- TODO confirm\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 5))\n",
    "\n",
    "# One colour per tau (plt.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9)\n",
    "colors = plt.get_cmap('viridis', len(taus))\n",
    "\n",
    "for i, tau in enumerate(taus):\n",
    "    json_data = load_json(f\"../analysis/output_masking/{ds}_output_masking_01_{tau}_{score}.json\")\n",
    "    data = json_data.get(metric, [])\n",
    "    n_values = json_data.get('n_values', [])\n",
    "    ax.plot(n_values, [100*d for d in data], label=r\"$\\tau=$\" + f\"{tau}\", color=colors(i))\n",
    "\n",
    "# Reference \"Pruning Possibility\" curve (in %): x at 0, 5, ..., 95 % of n_neurons\n",
    "prune_x = (n_neurons / 100) * np.arange(0, 100, 5)\n",
    "prune_y = np.array([0.0, 0.8763020833333334, 1.8364583333333333, 2.8979166666666667, 4.039322916666666, 5.28046875, 6.61953125, 8.127083333333333, 9.761979166666666, 11.597916666666666, 13.59453125, 15.870052083333333, 18.434114583333333, 21.391666666666666, 24.85703125, 29.007552083333334, 34.17213541666667, 40.723697916666666, 50.251041666666666, 66.04192708333333])\n",
    "ax.plot(prune_x, prune_y, color=\"red\", label=\"Pruning Possibility\")\n",
    "\n",
    "ax.set_xlabel(\"# Neurons\", fontsize=20)\n",
    "ax.set_ylabel(ylabel, fontsize=20)\n",
    "ax.set_ylim(0.0, 100)\n",
    "ax.set_xlim(0, n_neurons)\n",
    "ax.set_xticks(np.arange(0, n_neurons, 1000))\n",
    "ax.set_yticks(np.arange(0, 100+1, 20))\n",
    "ax.tick_params(axis='both', which='major', labelsize=17)\n",
    "ax.grid(True)\n",
    "ax.legend(fontsize=12, loc='lower left')  # Change position if needed\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"{ds}_pruning_tau.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fe48abf-8f12-4ded-a7bc-74a38065ef80",
   "metadata": {},
   "source": [
    "### Per-Class Accuracies MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54fc5229-96b9-479b-8086-3570bb5074de",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Per-class accuracy on MNIST for models trained with different tau (line plot)\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "taus = [0.05, 0.1, 0.5, 1, 3, 5, 8, 10, 20, 50]\n",
    "num_classes = 10\n",
    "# Only the last n_eval raw_values entries are scored -- presumably the final\n",
    "# evaluation pass over the MNIST test set; TODO confirm against the logger.\n",
    "n_eval = 10000\n",
    "\n",
    "class_accuracies = []\n",
    "for tau in taus:\n",
    "    json_data = load_json(f\"../results/mnist_different_tau_01_{tau}.json\")\n",
    "\n",
    "    # Each raw_values entry is unpacked as a (true label, logits) pair\n",
    "    last_eval = json_data[\"raw_values\"][-n_eval:]\n",
    "    true_labels = [label for label, _ in last_eval]\n",
    "    predicted_labels = [np.argmax(logits) for _, logits in last_eval]\n",
    "\n",
    "    conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=range(num_classes))\n",
    "\n",
    "    # Row-normalised diagonal; np.maximum(1, ...) keeps classes absent from the\n",
    "    # slice at 0 instead of dividing by zero.\n",
    "    per_class_acc = np.diag(conf_matrix) / np.maximum(1, conf_matrix.sum(axis=1))\n",
    "    class_accuracies.append(per_class_acc)\n",
    "\n",
    "# Shape (num_taus, num_classes)\n",
    "class_accuracies = np.array(class_accuracies)\n",
    "\n",
    "plt.figure(figsize=(12, 6))\n",
    "x = np.arange(num_classes)  # class labels\n",
    "\n",
    "# One colour per tau (plt.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9)\n",
    "colors = plt.get_cmap('inferno', len(taus))\n",
    "\n",
    "for idx, tau in enumerate(taus):\n",
    "    plt.plot(x, class_accuracies[idx], marker='o', label=f\"Tau: {tau}\", color=colors(idx))\n",
    "\n",
    "plt.xticks(x)\n",
    "plt.xlabel(\"Class Label\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.title(\"Per-Class Accuracy for Different Tau Values\")\n",
    "plt.ylim(0.6, 1)\n",
    "plt.grid(True)\n",
    "plt.legend(title=\"Tau\", bbox_to_anchor=(1.05, 1), loc='upper left')  # outside plot\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "084ed632-7acf-45fd-9d67-e0189fa17f96",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Per-Class accuracies: export to CSV, optionally also as a grouped bar chart\n",
    "import numpy as np\n",
    "import csv\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# taus = [0.05, 0.1, 0.5, 1, 3, 5, 8, 10, 20, 50]\n",
    "taus = [0.1, 1, 5, 10, 20, 50]\n",
    "num_classes = 10\n",
    "# Set to True to also draw the grouped bar chart and save mnist_class_specific_acc.pdf\n",
    "PLOT_BAR_CHART = False\n",
    "\n",
    "class_accuracies = []\n",
    "\n",
    "for tau in taus:\n",
    "    json_data = load_json(f\"../results/mnist_different_tau_01_{tau}.json\")\n",
    "    # json_data = load_json(f\"../results/mnist_different_tau_final_{tau}.json\")\n",
    "\n",
    "    # Score only the trailing 10000 (label, logits) entries -- presumably the final test-set pass; TODO confirm\n",
    "    last_eval = json_data[\"raw_values\"][-10000:]\n",
    "    true_labels = [label for label, _ in last_eval]\n",
    "    predicted_labels = [np.argmax(logits) for _, logits in last_eval]\n",
    "\n",
    "    conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=range(num_classes))\n",
    "\n",
    "    # Row-normalised diagonal; np.maximum(1, ...) avoids division by zero for absent classes\n",
    "    per_class_acc = np.diag(conf_matrix) / np.maximum(1, conf_matrix.sum(axis=1))\n",
    "    class_accuracies.append(per_class_acc)\n",
    "\n",
    "# Convert to NumPy array\n",
    "class_accuracies = np.array(class_accuracies)  # shape (num_taus, num_classes)\n",
    "\n",
    "# CSV: one row per class, one column per tau, values in percent\n",
    "csv_filename = \"mnist_class_specific_acc.csv\"\n",
    "with open(csv_filename, \"w\", newline=\"\") as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow([\"class\"] + [f\"tau_{tau}\" for tau in taus])\n",
    "    for class_idx in range(num_classes):\n",
    "        row = [class_idx] + [100*class_accuracies[i,class_idx] for i in range(len(taus))]\n",
    "        writer.writerow(row)\n",
    "\n",
    "print(f\"CSV saved to {csv_filename}\")\n",
    "\n",
    "# Real code behind a flag: a bare string literal at the end of a cell is echoed as output.\n",
    "if PLOT_BAR_CHART:\n",
    "    # Grouped bar chart: one group per class, one bar per tau\n",
    "    bar_width = 0.12\n",
    "    x = np.arange(num_classes)  # class labels\n",
    "\n",
    "    plt.figure(figsize=(14, 6))\n",
    "    # plt.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9\n",
    "    colors = plt.get_cmap('viridis', len(taus))\n",
    "\n",
    "    for i, tau in enumerate(taus):\n",
    "        offset = (i - len(taus)/2) * bar_width + bar_width/2\n",
    "        label = f\"$\\\\tau$={tau:.1f}\"\n",
    "        plt.bar(x + offset, [100*c for c in class_accuracies[i]], width=bar_width, label=label, color=colors(i))\n",
    "\n",
    "    plt.xticks(x, labels=[str(i) for i in range(num_classes)])\n",
    "    plt.xlabel(\"Class Label\", fontsize=24)\n",
    "    plt.ylabel(\"Accuracy (%)\", fontsize=24)\n",
    "    plt.yticks(np.arange(70, 101, 10))\n",
    "    plt.tick_params(axis='both', which='major', labelsize=20)\n",
    "    plt.ylim(70, 100)\n",
    "    plt.legend(loc='lower left', fontsize=20)\n",
    "    plt.grid(True, axis='y', linestyle='--', alpha=0.5)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"mnist_class_specific_acc.pdf\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c59412b-4071-4e24-9006-e4095ebd827a",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Per-Class accuracies with Deviation\n",
    "# For every tau, compute the per-class accuracy on the last `num_eval_samples`\n",
    "# (label, logits) pairs of the run and plot the mean over classes, with the\n",
    "# standard deviation over classes as error bar.\n",
    "# Relies on `load_json` being defined by an earlier cell.\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# taus = [0.05, 0.1, 0.5, 1, 3, 5, 8, 10, 20, 50, 200]\n",
    "taus = [0.1, 1, 5, 10, 20, 50]\n",
    "num_classes = 10\n",
    "# NOTE(review): assumes the last 10'000 raw_values are the MNIST test-set predictions - confirm\n",
    "num_eval_samples = 10000\n",
    "mean_accuracies = []\n",
    "std_accuracies = []\n",
    "\n",
    "for tau in taus:\n",
    "    json_data = load_json(f\"../results/mnist_different_tau_01_{tau}.json\")\n",
    "    # Do not print json_data: \"raw_values\" holds thousands of logit vectors and\n",
    "    # dumping it floods the notebook output.\n",
    "    eval_values = json_data[\"raw_values\"][-num_eval_samples:]\n",
    "\n",
    "    true_labels = [label for label, _ in eval_values]\n",
    "    predicted_labels = [np.argmax(logits) for _, logits in eval_values]\n",
    "\n",
    "    conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=range(num_classes))\n",
    "\n",
    "    # Diagonal = correct predictions per class, row sum = samples per class;\n",
    "    # np.maximum(1, ...) avoids 0/0 for a class that does not occur in the slice.\n",
    "    per_class_acc = np.diag(conf_matrix) / np.maximum(1, conf_matrix.sum(axis=1))\n",
    "\n",
    "    mean_accuracies.append(np.mean(per_class_acc))\n",
    "    std_accuracies.append(np.std(per_class_acc))\n",
    "\n",
    "# Plot\n",
    "plt.figure(figsize=(10, 6))\n",
    "x = np.arange(len(taus))\n",
    "bar_color = plt.cm.viridis(0.2)  # one color for all bars\n",
    "\n",
    "plt.bar(x, [m*100 for m in mean_accuracies], yerr=[s*100 for s in std_accuracies], capsize=5, color=bar_color, alpha=0.9)\n",
    "\n",
    "plt.xticks(x, labels=[f\"{tau:g}\" for tau in taus], rotation=45)\n",
    "plt.xlabel(r\"$\\tau$\", fontsize=24)\n",
    "plt.ylabel(\"Accuracy (%)\", fontsize=24)\n",
    "plt.yticks(np.arange(75, 100+1, 10))\n",
    "plt.tick_params(axis='both', which='major', labelsize=22)\n",
    "# plt.title(\"Mean Per-Class Accuracy ± StdDev vs Tau\")\n",
    "plt.ylim(75, 100)\n",
    "plt.grid(True, axis='y', linestyle='--', alpha=0.5)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"mnist_per_class_accuracy.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3caaccc-57c7-47e1-a06c-ac6e9b8333b4",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Per-Class accuracies heatmap/table\n",
    "# Builds a (class x tau) matrix of per-class accuracies in percent, draws it as an\n",
    "# annotated heatmap (saved to PDF) and prints the same numbers, with a mean row on\n",
    "# top, as a LaTeX tabular. Relies on `load_json` being defined by an earlier cell.\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# Define tau values\n",
    "taus = [0.1, 1, 5, 10, 20, 50]\n",
    "\n",
    "# Initialize matrix to store per-class accuracies\n",
    "per_class_acc_matrix = []\n",
    "\n",
    "for tau in taus:\n",
    "    # Load your JSON data\n",
    "    json_data = load_json(f\"../results/mnist_different_tau_01_{tau}.json\")\n",
    "\n",
    "    # Extract true and predicted labels\n",
    "    # NOTE(review): assumes the last 10'000 raw_values are the test-set (label, logits) pairs - confirm\n",
    "    true_labels = [label for label, _ in json_data[\"raw_values\"][-10000:]]\n",
    "    predicted_labels = [np.argmax(logits) for _, logits in json_data[\"raw_values\"][-10000:]]\n",
    "\n",
    "    # Compute confusion matrix\n",
    "    conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=range(10))\n",
    "\n",
    "    # Compute per-class accuracy\n",
    "    # Diagonal = correct predictions per class, row sum = samples per class;\n",
    "    # np.maximum(1, ...) avoids 0/0 for a class that does not occur in the slice.\n",
    "    with np.errstate(divide='ignore', invalid='ignore'):\n",
    "        per_class_acc = np.diag(conf_matrix) / np.maximum(1, conf_matrix.sum(axis=1))\n",
    "\n",
    "    # Append result\n",
    "    per_class_acc_matrix.append(per_class_acc)\n",
    "\n",
    "# Convert to numpy array and transpose to shape (10 classes × len(taus))\n",
    "acc_array = np.array(per_class_acc_matrix).T * 100  # Convert to percentage\n",
    "\n",
    "# Compute mean accuracy per tau\n",
    "# (unweighted mean over the classes, i.e. over axis 0 of acc_array)\n",
    "mean_accuracies = np.mean(acc_array, axis=0)\n",
    "\n",
    "# Optional: Plot heatmap\n",
    "plt.figure(figsize=(8, 6))\n",
    "heatmap = sns.heatmap(\n",
    "    acc_array,\n",
    "    annot=True,\n",
    "    fmt=\".1f\",\n",
    "    cmap=\"viridis\",\n",
    "    xticklabels=[f\"{tau:g}\" for tau in taus],\n",
    "    yticklabels=[f\"{i}\" for i in range(10)],\n",
    "    # NOTE(review): cbar_kws has no effect while cbar=False\n",
    "    cbar_kws={'label': 'Accuracy (%)'},\n",
    "    annot_kws={\"size\": 14},\n",
    "    cbar=False  # Set to True if you want to include the colorbar\n",
    ")\n",
    "\n",
    "# Style the plot\n",
    "plt.xlabel(r\"$\\tau$\", fontsize=24)\n",
    "plt.ylabel(\"Class\", fontsize=24)\n",
    "plt.tick_params(axis='both', which='major', labelsize=18)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"mnist_per_class_accuracy_heatmap.pdf\")\n",
    "plt.show()\n",
    "\n",
    "# Generate LaTeX table with mean row at the top\n",
    "# Column spec: one centered label column, a vertical rule, then one column per tau\n",
    "latex_table = \"\\\\begin{tabular}{c|\" + \"c\" * len(taus) + \"}\\n\"\n",
    "\n",
    "# Header\n",
    "# Every table row ends with the LaTeX row break (two backslashes) plus a newline\n",
    "latex_table += \"Class & \" + \" & \".join([f\"$\\\\tau={tau:g}$\" for tau in taus]) + \" \\\\\\\\\\n\"\n",
    "latex_table += \"\\\\hline\\n\"\n",
    "\n",
    "# Mean row\n",
    "latex_table += \"\\\\textbf{Mean} & \" + \" & \".join([f\"{mean_accuracies[j]:.1f}\" for j in range(len(taus))]) + \" \\\\\\\\\\n\"\n",
    "latex_table += \"\\\\hline\\n\"\n",
    "\n",
    "# Per-class rows\n",
    "for class_idx in range(acc_array.shape[0]):\n",
    "    row = f\"{class_idx} & \" + \" & \".join([f\"{acc_array[class_idx, j]:.1f}\" for j in range(acc_array.shape[1])])\n",
    "    latex_table += row + \" \\\\\\\\\\n\"\n",
    "\n",
    "latex_table += \"\\\\end{tabular}\"\n",
    "\n",
    "# Print LaTeX table to copy into LaTeX document\n",
    "print(\"\\nLaTeX Table:\\n\")\n",
    "print(latex_table)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14c206c6-a75f-4c2c-b630-718857a34659",
   "metadata": {},
   "source": [
    "### Synthetic Dataset - Increasing Number of Classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb6ac5c-9633-4b70-b08d-7bc787c36ffb",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Increasing number of classes, synthetic dataset\n",
    "# For more information about specific models, see the experiment configuration.\n",
    "# Plots the eval-mode test accuracy against the number of classes for several DLGN\n",
    "# configurations and an MLP baseline, and exports the plotted series to CSV.\n",
    "# Relies on `load_json` being defined by an earlier cell.\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "start = 2\n",
    "end = 2000\n",
    "n = 10\n",
    "# n log-spaced class counts from start to end (truncated to int)\n",
    "classes = [int(start * (end / start) ** (i / (n - 1))) for i in range(n)]\n",
    "metrics = [\"test_acc_eval_mode_\"]\n",
    "titles = [\"Binary Test Accuracy (eval mode)\"]\n",
    "\n",
    "# Function to load data into a dictionary\n",
    "def load_metric_dict(prefix):\n",
    "    \"\"\"Return {metric: [one value per entry of `classes`]} from ../results_temp/{prefix}_{num}.json.\n",
    "\n",
    "    A missing/unreadable file or a missing metric key is stored as None, so index i\n",
    "    always belongs to classes[i]. (Skipping such runs would shift every later value\n",
    "    onto the wrong class count and make the CSV columns unequal in length.)\n",
    "    \"\"\"\n",
    "    metric_dict = {metric: [] for metric in metrics}\n",
    "    for num in classes:\n",
    "        try:\n",
    "            json_data = load_json(f\"../results_temp/{prefix}_{num}.json\")\n",
    "        except Exception:  # run not available: keep a placeholder\n",
    "            json_data = {}\n",
    "        for metric in metrics:\n",
    "            metric_dict[metric].append(json_data.get(metric))\n",
    "    return metric_dict\n",
    "\n",
    "def to_percent(values):\n",
    "    \"\"\"Scale accuracies from [0, 1] to percent; None (missing run) becomes NaN.\"\"\"\n",
    "    return [v * 100 if v is not None else np.nan for v in values]\n",
    "\n",
    "# Load all runs\n",
    "# All DLGN have a last layer size of 64'000\n",
    "metric_dict_01 = load_metric_dict(\"synthetic_harder_difflogic_01\") # 64k, tau=10\n",
    "metric_dict_02 = load_metric_dict(\"synthetic_harder_difflogic_02\") # 128k, tau=10\n",
    "metric_dict_03 = load_metric_dict(\"synthetic_harder_difflogic_03\") # 256k, tau=10\n",
    "metric_dict_04 = load_metric_dict(\"synthetic_harder_difflogic_04\") # 64k, tau=1\n",
    "metric_dict_05 = load_metric_dict(\"synthetic_harder_difflogic_05\") # 128k, tau=1\n",
    "metric_dict_06 = load_metric_dict(\"synthetic_harder_difflogic_06\") # 256k, tau=1\n",
    "metric_dict_mlp = load_metric_dict(\"synthetic_harder_05\") # MLP: 05: 512, 07: 1024, 08: 256\n",
    "metric_dict_08 = load_metric_dict(\"synthetic_harder_difflogic_07\") # Codebook: 64k\n",
    "metric_dict_09 = load_metric_dict(\"synthetic_harder_difflogic_09\") # Codebook: 256k\n",
    "\n",
    "######Save to csv###########\n",
    "# Scale all values to percentage (0-100); missing runs become empty CSV cells\n",
    "def scale_percentage(metric_dict):\n",
    "    return to_percent(metric_dict[\"test_acc_eval_mode_\"])\n",
    "\n",
    "df = pd.DataFrame({\n",
    "    \"classes\": classes,\n",
    "    \"Small_tau1\": scale_percentage(metric_dict_04),\n",
    "    \"Small_tau10\": scale_percentage(metric_dict_01),\n",
    "    \"Big_tau1\": scale_percentage(metric_dict_06),\n",
    "    \"Big_tau10\": scale_percentage(metric_dict_03),\n",
    "    \"MLP_Medium\": scale_percentage(metric_dict_mlp)\n",
    "})\n",
    "\n",
    "# Save CSV\n",
    "out_path = \"synthetic_num_classes_multi.csv\"\n",
    "df.to_csv(out_path, index=False)\n",
    "#############################\n",
    "\n",
    "# Single axis; NaN entries (missing runs) show up as gaps in the curves.\n",
    "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n",
    "\n",
    "for metric in metrics:\n",
    "    # Plot experimental results\n",
    "    ax.plot(classes, to_percent(metric_dict_04[metric]), label=r\"Small, $\\tau=1$\", linestyle=':', color='tab:purple', marker='o', markersize=5)\n",
    "    ax.plot(classes, to_percent(metric_dict_01[metric]), label=r\"Small, $\\tau=10$\", linestyle='-', color='tab:purple', marker='o', markersize=5)\n",
    "    # ax.plot(classes, to_percent(metric_dict_02[metric]), label=\"128000, Tau: 10\", linestyle='--', color='tab:blue', marker='o', markersize=4)\n",
    "    ax.plot(classes, to_percent(metric_dict_06[metric]), label=r\"Big, $\\tau=1$\", linestyle=':', color='tab:blue', marker='o', markersize=5)\n",
    "    ax.plot(classes, to_percent(metric_dict_03[metric]), label=r\"Big, $\\tau=10$\", linestyle='-', color='tab:blue', marker='o', markersize=5)\n",
    "    # ax.plot(classes, to_percent(metric_dict_05[metric]), label=\"128000, Tau: 1\", linestyle='--', color='tab:purple', marker='o', markersize=4)\n",
    "    ax.plot(classes, to_percent(metric_dict_mlp[metric]), label=r\"MLP (Medium)\", linestyle='-', color='tab:red', marker='o', markersize=5)\n",
    "    # ax.plot(classes, to_percent(metric_dict_08[metric]), label=\"DistanceLayer (64k)\", linestyle=':', color='tab:orange', marker='o', markersize=4)\n",
    "    # ax.plot(classes, to_percent(metric_dict_09[metric]), label=\"DistanceLayer (256k)\", linestyle='-', color='tab:orange', marker='o', markersize=4)\n",
    "\n",
    "    ax.set_xlabel(\"# Classes\", fontsize=22)\n",
    "    ax.set_ylabel(\"Accuracy (%)\", fontsize=22)\n",
    "    ax.set_ylim(0, 100)  # values are percentages\n",
    "    ax.set_xscale('log')\n",
    "    # ax.set_xticks(np.arange(0, 2001, 500))\n",
    "    ax.set_yticks(np.arange(0, 100+1, 25))\n",
    "    ax.tick_params(axis='both', which='major', labelsize=18)\n",
    "    ax.grid(True, linestyle='--', alpha=0.5)\n",
    "    ax.legend(loc='lower left', fontsize=15)  # Change position if needed\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.suptitle(\"Increasing Number of Classes\", fontsize=16, y=1.02)\n",
    "# plt.savefig(\"synthetic_num_classes.pdf\")\n",
    "plt.savefig(\"synthetic_num_classes_2.png\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceccb0e3-c922-4913-8595-c064767fbcbe",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Increasing number of classes, different output layer sizes\n",
    "# Compares DLGN output-layer sizes (16k / 64k / 256k) against the MLP baseline\n",
    "# on the synthetic dataset as the number of classes grows.\n",
    "# Relies on `load_json` being defined by an earlier cell.\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "start = 2\n",
    "end = 2000\n",
    "n = 10\n",
    "# n log-spaced class counts from start to end (truncated to int)\n",
    "classes = [int(start * (end / start) ** (i / (n - 1))) for i in range(n)]\n",
    "metrics = [\"test_acc_eval_mode_\"]\n",
    "titles = [\"Binary Test Accuracy (eval mode)\"]\n",
    "\n",
    "# Function to load data into a dictionary\n",
    "def load_metric_dict(prefix):\n",
    "    \"\"\"Return {metric: [one value per entry of `classes`]} from ../results_temp/{prefix}_{num}.json.\n",
    "\n",
    "    A missing/unreadable file or a missing metric key is stored as None, so index i\n",
    "    always belongs to classes[i]. (Skipping such runs would shift every later value\n",
    "    onto the wrong class count.)\n",
    "    \"\"\"\n",
    "    metric_dict = {metric: [] for metric in metrics}\n",
    "    for num in classes:\n",
    "        try:\n",
    "            json_data = load_json(f\"../results_temp/{prefix}_{num}.json\")\n",
    "        except Exception:  # run not available: keep a placeholder\n",
    "            json_data = {}\n",
    "        for metric in metrics:\n",
    "            metric_dict[metric].append(json_data.get(metric))\n",
    "    return metric_dict\n",
    "\n",
    "def to_percent(values):\n",
    "    \"\"\"Scale accuracies from [0, 1] to percent; None (missing run) becomes NaN.\"\"\"\n",
    "    return [v * 100 if v is not None else np.nan for v in values]\n",
    "\n",
    "# Load all runs\n",
    "# Backbone kept at 64'000 neurons per layer\n",
    "metric_dict_01 = load_metric_dict(\"synthetic_harder_difflogic_01\") # 64k, tau=10\n",
    "metric_dict_04 = load_metric_dict(\"synthetic_harder_difflogic_04\") # 64k, tau=1\n",
    "metric_dict_07 = load_metric_dict(\"synthetic_harder_05\") # MLP: 05: 512, 07: 1024, 08: 256\n",
    "metric_dict_10 = load_metric_dict(\"synthetic_harder_difflogic_10\") # 16k, tau=1\n",
    "metric_dict_11 = load_metric_dict(\"synthetic_harder_difflogic_11\") # 16k, tau=10\n",
    "metric_dict_12 = load_metric_dict(\"synthetic_harder_difflogic_12\") # 16k, tau=100\n",
    "metric_dict_13 = load_metric_dict(\"synthetic_harder_difflogic_13\") # 256k, tau=1\n",
    "metric_dict_14 = load_metric_dict(\"synthetic_harder_difflogic_14\") # 256k, tau=10\n",
    "metric_dict_15 = load_metric_dict(\"synthetic_harder_difflogic_15\") # 256k, tau=100\n",
    "\n",
    "# Single axis; NaN entries (missing runs) show up as gaps in the curves.\n",
    "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n",
    "\n",
    "for metric in metrics:\n",
    "    # Plot experimental results\n",
    "\n",
    "    # 16k Output Dimension\n",
    "    # ax.plot(classes, to_percent(metric_dict_10[metric]), label=r\"16k ($\\tau=1$)\", linestyle=':', color='tab:red', marker='o')\n",
    "    ax.plot(classes, to_percent(metric_dict_11[metric]), label=r\"16k ($\\tau=10$)\", linestyle='--', color='tab:red', marker='o')\n",
    "    # ax.plot(classes, to_percent(metric_dict_12[metric]), label=r\"16k ($\\tau=100$)\", linestyle='-', color='tab:red', marker='o')\n",
    "\n",
    "    # 64k Output Dimension\n",
    "    # ax.plot(classes, to_percent(metric_dict_04[metric]), label=\"64000, Tau: 1\", linestyle='--', color='tab:green', marker='o')\n",
    "    ax.plot(classes, to_percent(metric_dict_01[metric]), label=r\"64k ($\\tau=10$)\", linestyle='-', color='tab:purple', marker='o')\n",
    "\n",
    "    # 256k Output Dimension\n",
    "    # ax.plot(classes, to_percent(metric_dict_13[metric]), label=r\"256k ($\\tau=1$)\", linestyle=':', color='tab:orange', marker='o')\n",
    "    ax.plot(classes, to_percent(metric_dict_14[metric]), label=r\"256k ($\\tau=10$)\", linestyle='--', color='tab:orange', marker='o')\n",
    "    # ax.plot(classes, to_percent(metric_dict_15[metric]), label=r\"256k ($\\tau=100$)\", linestyle='-', color='tab:orange', marker='o')\n",
    "\n",
    "    # MLP\n",
    "    ax.plot(classes, to_percent(metric_dict_07[metric]), label=\"MLP (Medium)\", linestyle='-', color='tab:blue', marker='o')\n",
    "\n",
    "    ax.set_xlabel(\"# Classes\", fontsize=22)\n",
    "    ax.set_ylabel(\"Accuracy (%)\", fontsize=22)\n",
    "    ax.set_ylim(0, 100)  # values are percentages\n",
    "    ax.set_xscale('log')\n",
    "    # ax.set_xticks(np.arange(0, 2001, 500))\n",
    "    ax.set_yticks(np.arange(0, 100+1, 25))\n",
    "    ax.tick_params(axis='both', which='major', labelsize=18)\n",
    "    ax.grid(True, linestyle='--', alpha=0.5)\n",
    "    ax.legend(loc='lower left', fontsize=15)  # Change position if needed\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.suptitle(\"Increasing Number of Classes\", fontsize=16, y=1.02)\n",
    "plt.savefig(\"synthetic_num_classes_output_dimension.pdf\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a27187f4-5b89-4ed6-b571-a6fa54fa05b0",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Intro plot, different number of classes\n",
    "# Group different tau values to one plot.\n",
    "# For each DLGN output-layer size, keep the best accuracy over the tau variants at\n",
    "# every class count, plot the three curves and export them to CSV.\n",
    "# Relies on `load_json` being defined by an earlier cell.\n",
    "import csv\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "start = 2\n",
    "end = 2000\n",
    "n = 10\n",
    "classes = [int(start * (end / start) ** (i / (n - 1))) for i in range(n)]\n",
    "metrics = [\"test_acc_eval_mode_\"]\n",
    "titles = [\"Binary Test Accuracy (eval mode)\"]\n",
    "\n",
    "# Function to load data into a dictionary\n",
    "def load_metric_dict(prefix, results_dir=\"../results\"):\n",
    "    \"\"\"Return {metric: [one value per entry of `classes`]} from {results_dir}/{prefix}_{num}.json.\n",
    "\n",
    "    A missing/unreadable file or a missing metric key is stored as None, so every\n",
    "    list has len(classes) entries and index i always belongs to classes[i].\n",
    "    \"\"\"\n",
    "    metric_dict = {metric: [] for metric in metrics}\n",
    "    for num in classes:\n",
    "        try:\n",
    "            json_data = load_json(f\"{results_dir}/{prefix}_{num}.json\")\n",
    "        except Exception:  # run not available: keep a placeholder\n",
    "            json_data = {}\n",
    "        for metric in metrics:\n",
    "            metric_dict[metric].append(json_data.get(metric))\n",
    "    return metric_dict\n",
    "\n",
    "def load_metric_dict_temp(prefix):\n",
    "    \"\"\"Like load_metric_dict, but reads the re-run results in ../results_temp.\n",
    "\n",
    "    The None placeholders matter here: get_best_difflogic_metric_dict indexes every\n",
    "    class position, so silently skipped runs would raise IndexError.\n",
    "    \"\"\"\n",
    "    return load_metric_dict(prefix, results_dir=\"../results_temp\")\n",
    "\n",
    "def get_best_difflogic_metric_dict(dicts):\n",
    "    \"\"\"Per class count and metric, the maximum over `dicts`, ignoring None (None if all are missing).\"\"\"\n",
    "    best_dict = {metric: [] for metric in metrics}\n",
    "    for i in range(len(classes)):\n",
    "        for metric in metrics:\n",
    "            vals = [d[metric][i] for d in dicts if d[metric][i] is not None]\n",
    "            best_dict[metric].append(max(vals) if vals else None)\n",
    "    return best_dict\n",
    "\n",
    "def to_percent_or_zero(values):\n",
    "    \"\"\"Scale accuracies to percent; missing entries (None) are reported as 0.\"\"\"\n",
    "    return [v * 100 if v is not None else 0 for v in values]\n",
    "\n",
    "# Load all difflogic variant datasets\n",
    "difflogic_dicts_16k = [\n",
    "    load_metric_dict_temp(\"synthetic_harder_difflogic_10\"),\n",
    "    load_metric_dict_temp(\"synthetic_harder_difflogic_11\"),\n",
    "    load_metric_dict_temp(\"synthetic_harder_difflogic_12\")\n",
    "]\n",
    "difflogic_dicts_256k = [\n",
    "    load_metric_dict_temp(\"synthetic_harder_difflogic_13\"),\n",
    "    load_metric_dict_temp(\"synthetic_harder_difflogic_14\"),\n",
    "    load_metric_dict_temp(\"synthetic_harder_difflogic_15\")\n",
    "]\n",
    "\n",
    "difflogic_dicts_64k = [\n",
    "    load_metric_dict_temp(\"synthetic_harder_difflogic_01\"),\n",
    "    load_metric_dict_temp(\"synthetic_harder_difflogic_04\")\n",
    "]\n",
    "\n",
    "# Load MLP (unmodified)\n",
    "# metric_dict_mlp = load_metric_dict(\"imagenet_num_classes_final_mlp_512\")\n",
    "\n",
    "best_difflogic_dict_16k = get_best_difflogic_metric_dict(difflogic_dicts_16k)\n",
    "best_difflogic_dict_64k = get_best_difflogic_metric_dict(difflogic_dicts_64k)\n",
    "best_difflogic_dict_256k = get_best_difflogic_metric_dict(difflogic_dicts_256k)\n",
    "\n",
    "# Plotting\n",
    "fig, ax = plt.subplots(1, 1, figsize=(7, 5))\n",
    "\n",
    "for metric in metrics:\n",
    "    # Plot best difflogic results (missing class counts are drawn at 0)\n",
    "    ax.plot(classes, to_percent_or_zero(best_difflogic_dict_16k[metric]),\n",
    "            label=\"DLGN (16k)\", linestyle='-', color='tab:green', marker='o', markersize=6, linewidth=3)\n",
    "    ax.plot(classes, to_percent_or_zero(best_difflogic_dict_64k[metric]),\n",
    "            label=\"DLGN (64k)\", linestyle='-', color='tab:orange', marker='o', markersize=6, linewidth=3)\n",
    "    ax.plot(classes, to_percent_or_zero(best_difflogic_dict_256k[metric]),\n",
    "            label=\"DLGN (256k)\", linestyle='-', color='tab:blue', marker='o', markersize=6, linewidth=3)\n",
    "\n",
    "    ax.set_xlabel(\"# Classes\", fontsize=22)\n",
    "    ax.set_ylabel(\"Accuracy (%)\", fontsize=22)\n",
    "    ax.set_ylim(0, 100)  # values are percentages\n",
    "    ax.set_xscale('log')\n",
    "    # ax.set_xticks(np.arange(0, 2001, 500))\n",
    "    ax.set_yticks(np.arange(0, 100+1, 25))\n",
    "    ax.tick_params(axis='both', which='major', labelsize=18)\n",
    "    ax.grid(True, linestyle='--', alpha=0.5)\n",
    "    ax.legend(loc='lower left', fontsize=15)  # Change position if needed\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"synthetic_num_classes_different_output_sizes.png\")\n",
    "plt.show()\n",
    "\n",
    "# Combine all data (same values as plotted: percent, missing runs as 0)\n",
    "csv_filename = \"synthetic_num_classes_different_output_sizes.csv\"\n",
    "csv_columns = [\n",
    "    to_percent_or_zero(best[\"test_acc_eval_mode_\"])\n",
    "    for best in (best_difflogic_dict_16k, best_difflogic_dict_64k, best_difflogic_dict_256k)\n",
    "]\n",
    "with open(csv_filename, mode='w', newline='') as f:\n",
    "    writer = csv.writer(f)\n",
    "    # Header row\n",
    "    writer.writerow([\"Classes\", \"DLGN_16k\", \"DLGN_64k\", \"DLGN_256k\"])\n",
    "    # One row per class count; every column has len(classes) entries\n",
    "    for i, num_classes in enumerate(classes):\n",
    "        writer.writerow([num_classes] + [column[i] for column in csv_columns])\n",
    "\n",
    "print(f\"Data saved to {csv_filename}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40a4876f-ee69-4aeb-b47a-a0328d6ecc4e",
   "metadata": {},
   "source": [
    "### Best Tau Heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea31e863-d8a7-4d13-827f-5f193ae5cb31",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Heatmap with best tau value of number of classes and outputs per class\n",
    "# For every (number of classes, output neurons per class) pair, pick the tau in\n",
    "# `tau_values` with the highest eval-mode test accuracy and show it as a heatmap.\n",
    "import json\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import LogNorm, rgb_to_hsv\n",
    "import numpy as np\n",
    "import os\n",
    "# Define ranges for num, k, tau\n",
    "start = 2\n",
    "end = 2000\n",
    "n = 10  # number of steps\n",
    "num_values = [int(start * (end / start) ** (i / (n - 1))) for i in range(n)]\n",
    "\n",
    "k_start = 2\n",
    "k_end = 2000\n",
    "k_steps = 10\n",
    "k_values = [int(k_start * (k_end / k_start) ** (i / (k_steps - 1))) for i in range(k_steps)]\n",
    "\n",
    "tau_values = [0.1, 1, 10, 100]\n",
    "selected_metric = \"test_acc_eval_mode_\"\n",
    "\n",
    "# NOTE: this redefines the notebook-wide `load_json`; cells run afterwards use this version.\n",
    "def load_json(filepath):\n",
    "    with open(filepath, 'r') as f:\n",
    "        return json.load(f)\n",
    "\n",
    "# Collect best tau for each (num, k); NaN where no result file could be read\n",
    "heatmap_data = []\n",
    "for num in num_values:\n",
    "    row = []\n",
    "    for k in k_values:\n",
    "        best_metric = -float('inf')\n",
    "        best_tau = float('nan')\n",
    "        for tau in tau_values:\n",
    "            filename = f\"synthetic_tau_analysis_02_{num}_{k}_{tau}.json\"\n",
    "            # The (2000, 2000) runs are read from ../results, all others from ../results_temp\n",
    "            if num == 2000 and k == 2000:\n",
    "                init_path = \"../results\"\n",
    "            else:\n",
    "                init_path = \"../results_temp\"\n",
    "            filepath = os.path.join(init_path, filename)\n",
    "            try:\n",
    "                data = load_json(filepath)\n",
    "            except (OSError, ValueError):  # missing file or invalid JSON: skip this tau\n",
    "                continue\n",
    "            metric_value = data.get(selected_metric, None)\n",
    "            # If the metric is stored as a list (presumably a history), use its last entry\n",
    "            if isinstance(metric_value, list):\n",
    "                metric_value = metric_value[-1] if metric_value else float('nan')\n",
    "            # NaN never compares greater, so runs without a value are never selected\n",
    "            if metric_value is not None and metric_value > best_metric:\n",
    "                best_metric = metric_value\n",
    "                best_tau = tau\n",
    "        row.append(best_tau)\n",
    "    heatmap_data.append(row)\n",
    "\n",
    "# Convert to numpy array (rows follow num_values, columns follow k_values)\n",
    "heatmap_array = np.array(heatmap_data)\n",
    "\n",
    "# Use STIXGeneral for Times-like font\n",
    "# NOTE: this sets the global rcParams, so every figure drawn later in this kernel also uses it.\n",
    "mpl.rcParams['font.family'] = 'STIXGeneral'\n",
    "\n",
    "# LogNorm needs at least one finite value; fail with a clear message instead of\n",
    "# numpy's zero-size reduction error when no result files were found.\n",
    "valid_vals = heatmap_array[~np.isnan(heatmap_array)]\n",
    "if valid_vals.size == 0:\n",
    "    raise RuntimeError(\"No synthetic_tau_analysis_02 results found; nothing to plot.\")\n",
    "norm = LogNorm(vmin=np.min(valid_vals), vmax=np.max(valid_vals))\n",
    "\n",
    "plt.figure(figsize=(14, 10))  # wider and taller\n",
    "\n",
    "cmap = plt.cm.viridis\n",
    "im = plt.imshow(\n",
    "    heatmap_array, aspect='auto', cmap=cmap,\n",
    "    interpolation='nearest', norm=norm\n",
    ")\n",
    "\n",
    "# Annotate tau values with adaptive text color\n",
    "for i in range(len(num_values)):\n",
    "    for j in range(len(k_values)):\n",
    "        tau_val = heatmap_array[i, j]\n",
    "        if np.isnan(tau_val):\n",
    "            continue  # no result for this cell\n",
    "        val_str = str(int(tau_val)) if float(tau_val).is_integer() else str(tau_val)\n",
    "\n",
    "        # Get background color for this cell\n",
    "        rgba = cmap(norm(tau_val))\n",
    "        r, g, b = rgba[:3]\n",
    "        # Compute luminance (perceived brightness)\n",
    "        luminance = 0.299*r + 0.587*g + 0.114*b\n",
    "        text_color = 'black' if luminance > 0.5 else 'white'\n",
    "\n",
    "        plt.text(j, i, val_str, ha='center', va='center',\n",
    "                 color=text_color, fontsize=48)\n",
    "\n",
    "# Axis ticks\n",
    "plt.xticks(ticks=range(len(k_values)), labels=k_values, fontsize=34, rotation=45)\n",
    "plt.yticks(ticks=range(len(num_values)), labels=num_values, fontsize=34)\n",
    "plt.xlabel(\"# Output Neurons per Class\", fontsize=70)\n",
    "plt.ylabel(\"# of Classes\", fontsize=70)\n",
    "\n",
    "plt.gca().invert_yaxis()\n",
    "# NOTE(review): labelsize=48 overrides the fontsize=34 passed to xticks/yticks above\n",
    "plt.tick_params(axis='both', which='major', labelsize=48)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Save high-quality PDF\n",
    "plt.savefig(\"synthetic_best_tau.pdf\", dpi=300)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ac92c76-3e4b-496d-9791-8a828d3fedbb",
   "metadata": {},
   "source": [
    "### IMAGENET-32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2cb7109-433d-43bb-bf3a-c87ec785195b",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Increasing number of classes on Imagennet-32\n",
    "import numpy as np\n",
    "import json\n",
    "\n",
    "def load_json(filepath):\n",
    "    with open(filepath, \"r\") as f:\n",
    "        return json.load(f)\n",
    "\n",
    "start = 5\n",
    "end = 1000\n",
    "n = 5\n",
    "classes = [int(start * (end / start) ** (i / (n - 1))) for i in range(n)]\n",
    "print(classes)\n",
    "metrics = [\"test_acc_eval_mode_\"]\n",
    "titles = [\"Binary Test Accuracy (eval mode)\"]\n",
    "\n",
    "# Function to load data into a dictionary\n",
    "def load_metric_dict(prefix):\n",
    "    metric_dict = {metric: [] for metric in metrics}\n",
    "    for num in classes:\n",
    "        try:\n",
    "            json_data = load_json(f\"../results/{prefix}_{num}.json\")\n",
    "            for metric in metrics:\n",
    "                data = json_data.get(metric, [])\n",
    "                metric_dict[metric].append(data)\n",
    "        except:\n",
    "            pass\n",
    "    return metric_dict\n",
    "def load_metric_dict_temp(prefix):\n",
    "    metric_dict = {metric: [] for metric in metrics}\n",
    "    for num in classes:\n",
    "        try:\n",
    "            json_data = load_json(f\"../results_temp/{prefix}_{num}.json\")\n",
    "            for metric in metrics:\n",
    "                data = json_data.get(metric, [])\n",
    "                metric_dict[metric].append(data)\n",
    "        except:\n",
    "            pass\n",
    "    return metric_dict\n",
    "\n",
    "# Load all runs (only some are plotted below; the others are kept for quick comparison)\n",
    "metric_dict_00 = load_metric_dict_temp(\"imagenet_num_classes_00\") # 256k, tau=0.1\n",
    "metric_dict_01 = load_metric_dict_temp(\"imagenet_num_classes_01\") # 256k, tau=1\n",
    "metric_dict_02 = load_metric_dict_temp(\"imagenet_num_classes_02\") # 256k, tau=10\n",
    "metric_dict_03 = load_metric_dict_temp(\"imagenet_num_classes_03\") # 256k, tau=100\n",
    "metric_dict_06 = load_metric_dict_temp(\"imagenet_num_classes_06\") # 512k, tau=1\n",
    "metric_dict_07 = load_metric_dict_temp(\"imagenet_num_classes_07\") # 512k, tau=10\n",
    "metric_dict_08 = load_metric_dict_temp(\"imagenet_num_classes_08\") # 512k, tau=100\n",
    "metric_dict_09 = load_metric_dict_temp(\"imagenet_num_classes_10\") # Codebook: 256k (note: loads run 10, not 09)\n",
    "# metric_dict_05 = load_metric_dict_temp(\"imagenet_num_classes_05\")\n",
    "\n",
    "metric_dict_05 = load_metric_dict(\"imagenet_num_classes_final_mlp_512\") # MLP: 512\n",
    "\n",
    "def _as_percent(values):\n",
    "    \"\"\"Multiply each value by 100; None entries become NaN so they are treated as missing.\"\"\"\n",
    "    return [np.nan if v is None else v * 100 for v in values]\n",
    "\n",
    "###########################\n",
    "def scale_percentage(metric_dict):\n",
    "    \"\"\"Test accuracy (eval mode) of every loaded run of this experiment, in percent.\"\"\"\n",
    "    return _as_percent(metric_dict[\"test_acc_eval_mode_\"])\n",
    "\n",
    "# Every column is a Series: a run may cover fewer class counts than `classes`\n",
    "# (the plot below truncates for the same reason). pandas pads shorter Series with\n",
    "# NaN, whereas plain lists must all have equal length or the constructor raises ValueError.\n",
    "df = pd.DataFrame({\n",
    "    \"classes\": pd.Series(classes),\n",
    "    \"Big_tau1\": pd.Series(scale_percentage(metric_dict_01), dtype=float),\n",
    "    \"Big_tau10\": pd.Series(scale_percentage(metric_dict_02), dtype=float),\n",
    "    \"Big_tau100\": pd.Series(scale_percentage(metric_dict_03), dtype=float),\n",
    "    \"MLP_Medium\": pd.Series(scale_percentage(metric_dict_05), dtype=float)\n",
    "})\n",
    "\n",
    "csv_path = \"imagenet_num_classes.csv\"\n",
    "df.to_csv(csv_path, index=False)\n",
    "print(f\"Saved CSV to {csv_path}\")\n",
    "print(df.head())\n",
    "###########################\n",
    "\n",
    "# Curves to draw: (run, legend label, line style, color).\n",
    "# Uncomment entries to add the other runs loaded above.\n",
    "curves = [\n",
    "    # (metric_dict_00, r\"Big, $\\tau=0.1$\", ':', 'tab:blue'),\n",
    "    (metric_dict_01, r\"Big, $\\tau=1$\", ':', 'tab:blue'),\n",
    "    (metric_dict_02, r\"Big, $\\tau=10$\", '--', 'tab:blue'),\n",
    "    (metric_dict_03, r\"Big, $\\tau=100$\", '-', 'tab:blue'),\n",
    "    # (metric_dict_06, \"512000 (Tau: 1)\", '-', 'tab:red'),\n",
    "    # (metric_dict_07, \"512000 (Tau: 10)\", '--', 'tab:red'),\n",
    "    # (metric_dict_08, \"512000 (Tau: 100)\", ':', 'tab:red'),\n",
    "    # (metric_dict_09, \"Distance 256k\", ':', 'tab:orange'),\n",
    "    (metric_dict_05, r\"MLP (Medium)\", '-', 'tab:red'),\n",
    "]\n",
    "\n",
    "# One subplot per metric (with a single metric this is the same 7x5 figure as a plain subplots(1, 1))\n",
    "n_panels = max(len(metrics), 1)\n",
    "fig, axs = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5), squeeze=False)\n",
    "\n",
    "for idx, metric in enumerate(metrics):\n",
    "    ax = axs[0, idx]\n",
    "\n",
    "    # Plot experimental results; each run is drawn only over the class counts it covers\n",
    "    for run, run_label, style, color in curves:\n",
    "        values = run[metric]\n",
    "        ax.plot(classes[:len(values)], _as_percent(values), label=run_label,\n",
    "                linestyle=style, color=color, marker='o', markersize=5)\n",
    "\n",
    "    ax.set_xlabel(\"# Classes\", fontsize=22)\n",
    "    ax.set_ylabel(\"Accuracy (%)\", fontsize=22)\n",
    "    ax.set_ylim(0.0, 75)\n",
    "    ax.set_xscale('log')\n",
    "    # ax.set_xticks(np.arange(0, 2001, 500))\n",
    "    ax.set_yticks(np.arange(0, 75+1, 25))\n",
    "    ax.tick_params(axis='both', which='major', labelsize=18)\n",
    "    ax.grid(True, linestyle='--', alpha=0.5)\n",
    "    ax.legend(loc='upper right', fontsize=15)  # Change position if needed\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.suptitle(\"Increasing Number of Classes\", fontsize=16, y=1.02)\n",
    "plt.savefig(\"imagenet_num_classes_presentation.pdf\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b665aaf-11c2-4068-afc4-5522665e6fa1",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Different Tau for Distance output\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "def get_baseline(dataset, metric):\n",
    "    \"\"\"Mean over the three baseline seeds ({dataset}_baseline_05_{i}.json, i = 0..2) of the\n",
    "    value stored under `metric` plus a trailing underscore, multiplied by 100.\"\"\"\n",
    "    baselines = []\n",
    "    for i in range(3):\n",
    "        baseline_data = load_json(f\"../results/{dataset}_baseline_05_{i}.json\")  # Adjust path as needed\n",
    "        baselines.append(baseline_data[metric + \"_\"])\n",
    "    return 100* sum(baselines)/len(baselines)\n",
    "\n",
    "extensions=[0.1, 1, 5, 20]\n",
    "label=r\"$\\tau$\"\n",
    "metric=\"valid_acc_eval_mode\"\n",
    "\n",
    "dataset=\"cifar10\"\n",
    "file=\"_distance_output_02_\"\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 5))\n",
    "\n",
    "# Unique color for each extension. plt.get_cmap is used because\n",
    "# matplotlib.cm.get_cmap (reached via plt.cm.get_cmap) was removed in Matplotlib 3.9.\n",
    "colors = plt.get_cmap('viridis', len(extensions))\n",
    "\n",
    "# Length of the longest curve actually plotted; sizes the x ticks instead of relying on\n",
    "# `data` from the last loop iteration (stale or undefined if that run failed to load).\n",
    "max_epochs = 0\n",
    "for i_idx, ext in enumerate(extensions):\n",
    "    try:\n",
    "        json_data = load_json(f\"../results/{dataset}{file}{ext}.json\")\n",
    "        data = json_data.get(metric, [])\n",
    "    except Exception as err:\n",
    "        # Missing/unreadable runs are skipped (but reported) so the other curves still plot\n",
    "        print(f\"Skipping {label}={ext}: {err}\")\n",
    "        continue\n",
    "    ax.plot(\n",
    "        range(len(data)),\n",
    "        [np.nan if d is None else d * 100 for d in data],\n",
    "        label=f\"{label}={ext}\",\n",
    "        color=colors(i_idx)\n",
    "    )\n",
    "    max_epochs = max(max_epochs, len(data))\n",
    "\n",
    "baseline = get_baseline(dataset, metric)\n",
    "ax.axhline(baseline, color='red', linestyle='--', label=f'Baseline: {baseline:.2f}')\n",
    "\n",
    "# ax.set_title(title, fontsize=16)\n",
    "ax.set_xlabel(\"Epoch\", fontsize=22)\n",
    "ax.set_ylabel(\"Accuracy (%)\", fontsize=22)\n",
    "ax.set_ylim(30, 60)\n",
    "ax.set_xticks(np.arange(0, max_epochs + 1, 50))\n",
    "ax.set_yticks(np.arange(30, 60+1, 10))\n",
    "ax.tick_params(axis='both', which='major', labelsize=17)\n",
    "ax.grid(True, linestyle='--', alpha=0.5)\n",
    "ax.legend(loc='lower right', fontsize=14)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"cifar10_distance.pdf\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "272a6883-87a6-4c02-952f-2b8782b52ba8",
   "metadata": {},
   "source": [
    "### Tau-Dropout Heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5423dbe9-9815-4a73-89a3-c8f2c5b14c64",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "mpl.rcParams['font.family'] = 'STIXGeneral'\n",
    "\n",
    "# Define tau and dropout values.\n",
    "# tau: 6 log-spaced values from 0.1 to 100. The result file names embed str(tau), so keep\n",
    "# this exact expression (an equivalent formula may round differently and miss the files).\n",
    "tau_start = 0.1\n",
    "tau_end = 100\n",
    "tau_values = [tau_start * (tau_end / tau_start) ** (i / 5) for i in range(6)]\n",
    "dropout_values = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9]\n",
    "\n",
    "# Metric to visualize\n",
    "selected_metric = \"test_acc_eval_mode_\"\n",
    "\n",
    "# Helper: read and parse a JSON file (raises on missing/invalid files; the loop below handles that)\n",
    "def load_json(filepath):\n",
    "    \"\"\"Read and parse the JSON file at `filepath`.\"\"\"\n",
    "    with open(filepath, 'r') as f:\n",
    "        return json.load(f)\n",
    "\n",
    "# Initialize grid for heatmap data: rows = dropout values, columns = tau values.\n",
    "# Cells whose result file is missing/unreadable or lacks the metric become NaN.\n",
    "heatmap_data = []\n",
    "for drop in dropout_values:\n",
    "    row = []\n",
    "    for tau in tau_values:\n",
    "        filename = f\"synthetic_tau_dropout_02_{tau}_{drop}.json\"\n",
    "        filepath = os.path.join(\"../results_temp\", filename)\n",
    "        try:\n",
    "            data = load_json(filepath)\n",
    "            metric_value = data.get(selected_metric, None)\n",
    "            if isinstance(metric_value, list):\n",
    "                # List-valued metric (presumably a per-epoch history): take the last entry\n",
    "                metric_value = metric_value[-1] if metric_value else float('nan')\n",
    "        except Exception:\n",
    "            metric_value = float('nan')\n",
    "        # A missing key (or a null entry) would leave None in the grid, turning the array\n",
    "        # into dtype=object so that np.isnan below raises TypeError.\n",
    "        if metric_value is None:\n",
    "            metric_value = float('nan')\n",
    "        row.append(metric_value)\n",
    "    heatmap_data.append(row)\n",
    "\n",
    "heatmap_data = np.array(heatmap_data, dtype=float)\n",
    "\n",
    "# Plot heatmap\n",
    "plt.figure(figsize=(14, 10))\n",
    "\n",
    "cmap = plt.cm.viridis\n",
    "im = plt.imshow(heatmap_data, aspect='auto', cmap=cmap, interpolation='nearest')\n",
    "\n",
    "# Annotate each cell with the corresponding value\n",
    "for i in range(len(dropout_values)):\n",
    "    for j in range(len(tau_values)):\n",
    "        value = heatmap_data[i, j]\n",
    "        if not np.isnan(value):\n",
    "            text_str = f\"{100*value:.1f} %\"\n",
    "\n",
    "            # Background color at this value. im.norm is the normalization imshow itself\n",
    "            # applied, so this matches the drawn cell and, unlike (value - vmin) / (vmax - vmin),\n",
    "            # does not divide by zero when every valid cell holds the same value.\n",
    "            r, g, b = cmap(im.norm(value))[:3]\n",
    "\n",
    "            # Compute luminance (brightness perception)\n",
    "            luminance = 0.299*r + 0.587*g + 0.114*b\n",
    "            text_color = 'black' if luminance > 0.5 else 'white'\n",
    "\n",
    "            plt.text(j, i, text_str, ha='center', va='center',\n",
    "                     color=text_color, fontsize=40)\n",
    "\n",
    "# Axis labels and ticks\n",
    "plt.xticks(ticks=range(len(tau_values)),\n",
    "           labels=[f\"{v:.2f}\" for v in tau_values],\n",
    "           rotation=45, fontsize=40)\n",
    "plt.yticks(ticks=range(len(dropout_values)),\n",
    "           labels=[f\"{100*v:.0f} %\" for v in dropout_values],\n",
    "           fontsize=40)\n",
    "plt.xlabel(r\"$\\tau$\", fontsize=50)\n",
    "plt.ylabel(\"Dropout\", fontsize=50)\n",
    "\n",
    "plt.tick_params(axis='both', which='major', labelsize=40)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"synthetic_tau_dropout.pdf\", dpi=300)\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
