{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import DatasetDict\n",
    "from typing import Optional, Sequence, Iterable\n",
    "\n",
    "import torch\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from matplotlib import pyplot as plt\n",
    "from torchmetrics import Accuracy, CosineSimilarity, MeanAbsoluteError\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the classifier whose results are analysed; also the leaf folder of the results tree.\n",
    "CLASSIFIER_TYPE = \"linear\"\n",
    "\n",
    "# <PROJECT_ROOT>/results/stitching/<CLASSIFIER_TYPE>\n",
    "RESULTS_DIR = PROJECT_ROOT.joinpath(\"results\", \"stitching\", CLASSIFIER_TYPE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "from pathlib import Path\n",
    "from typing import Dict, Any\n",
    "\n",
    "\n",
    "def load_df(\n",
    "    folder: str,\n",
    "    modality: str,\n",
    "    root_dir: Path,\n",
    ") -> Optional[pd.DataFrame]:\n",
    "    \"\"\"Load the stitching results of one experiment folder.\n",
    "\n",
    "    Reads ``<root_dir>/<folder>/stitching.tsv`` (tab-separated) and appends two\n",
    "    columns: ``dataset``, derived from the folder name, and ``modality``.\n",
    "\n",
    "    Args:\n",
    "        folder: Experiment folder name. The text before the first ``_`` is the\n",
    "            dataset name. For ``cifar100`` the second ``_``-separated token\n",
    "            selects the suffix (``False`` -> ``coarse``, anything else ->\n",
    "            ``fine``); for ``n24news`` the second token is appended as is.\n",
    "        modality: Value written to the ``modality`` column of every row.\n",
    "        root_dir: Directory that contains the experiment folders.\n",
    "\n",
    "    Returns:\n",
    "        The annotated DataFrame, or ``None`` when the TSV file does not exist.\n",
    "    \"\"\"\n",
    "    path = root_dir / folder / \"stitching.tsv\"\n",
    "    if not path.exists():\n",
    "        # A missing run is reported and skipped instead of aborting the whole load.\n",
    "        print(f\"Missing results file, skipping: {path}\")\n",
    "        return None\n",
    "\n",
    "    folder_parts = folder.split(\"_\")\n",
    "    dataset = folder_parts[0]\n",
    "    if dataset == \"cifar100\":\n",
    "        dataset = f\"{dataset}_{'coarse' if folder_parts[1] == 'False' else 'fine'}\"\n",
    "    elif dataset == \"n24news\":\n",
    "        dataset = f\"{dataset}_{folder_parts[1]}\"\n",
    "\n",
    "    df = pd.read_csv(path, sep=\"\\t\")\n",
    "    # A scalar is broadcast to every row.\n",
    "    df[\"dataset\"] = dataset\n",
    "    df[\"modality\"] = modality\n",
    "\n",
    "    return df\n",
    "\n",
    "\n",
    "# (results folder, modality) for every experiment that enters the analysis.\n",
    "FILES_OPTIONS = [\n",
    "    (\"cifar100_False_train_test_1\", \"vision\"),\n",
    "    (\"cifar100_True_train_test_1\", \"vision\"),\n",
    "    (\"cifar10_train_test_1\", \"vision\"),\n",
    "    (\"fashion_mnist_train_test_1\", \"vision\"),\n",
    "    (\"mnist_train_test_1\", \"vision\"),\n",
    "    (\"n24news_image_False_train_test_1\", \"vision\"),\n",
    "    (\"trec_False_train_test_1\", \"text\"),\n",
    "    (\"dbpedia_14_train_test_1\", \"text\"),\n",
    "    (\"n24news_text_False_train_test_1\", \"text\"),\n",
    "]\n",
    "\n",
    "\n",
    "# Load one frame per experiment and stack them row-wise into a single table.\n",
    "data = pd.concat(\n",
    "    [\n",
    "        load_df(folder=run_folder, modality=run_modality, root_dir=RESULTS_DIR)\n",
    "        for run_folder, run_modality in FILES_OPTIONS\n",
    "    ]\n",
    ")\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and std of every remaining column per projection function, restricted to rows\n",
    "# whose encoding and decoding spaces differ.\n",
    "clean_data = (\n",
    "    data.loc[data.encoding_space != data.decoding_space]\n",
    "    # These columns are removed before aggregating.\n",
    "    .drop(\n",
    "        columns=[\n",
    "            \"dataset\", \"space_type\", \"encoding_space\", \"decoding_space\",\n",
    "            \"decoder_type\", \"modality\", \"num_anchors\", \"k\",\n",
    "        ]\n",
    "    )\n",
    "    .groupby([\"projection_func\"])\n",
    "    .agg([\"mean\", \"std\"])\n",
    "    # The aggregated seed statistics are discarded as well.\n",
    "    .drop(columns=[\"seed\"])\n",
    ")\n",
    "clean_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from latent_invariances.utils.latex import convert_to_latex_with_meanstd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise the listed metrics per projection function, using only the rows whose\n",
    "# encoding and decoding spaces differ, and print the returned table.\n",
    "result, table = convert_to_latex_with_meanstd(\n",
    "    data.loc[data.encoding_space != data.decoding_space],\n",
    "    rows=[\"projection_func\"],\n",
    "    metrics=[\"linear_cka\", \"l1\", \"spearman\", \"mse\", \"cosine_sim\", \"score\"],\n",
    ")\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The file is tab-separated even though its name ends in .csv.\n",
    "result.to_csv(\"stitching.csv\", index=True, sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Projection functions excluded from everything that follows.\n",
    "avoid_funcs = [\"Wasserstein\", \"L3\", \"Normalized Absolute\"]\n",
    "# .copy() makes `data` an independent frame, so adding columns to it later does not\n",
    "# trigger pandas' SettingWithCopyWarning.\n",
    "data = data[~data.projection_func.isin(avoid_funcs)].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean, std and row count of `score` for every (dataset, projection function) pair.\n",
    "df_plot = data.groupby([\"dataset\", \"projection_func\"])[[\"score\"]].agg([\"mean\", \"std\", \"count\"])\n",
    "df_plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this replaces the `df_plot` computed in the previous cell with a file that\n",
    "# no cell visible in this notebook writes; confirm how stitching_group_dataset.csv is produced.\n",
    "df_plot = pd.read_csv(\"stitching_group_dataset.csv\", sep=\"\\t\")\n",
    "df_plot[\"dataset\"].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pandas as pd\n",
    "# import matplotlib.pyplot as plt\n",
    "# import numpy as np\n",
    "\n",
    "# VISION = False\n",
    "\n",
    "# # df_plot = df_plot.sort_values(by=['dataset', 'projection_func'])\n",
    "\n",
    "# if VISION:\n",
    "#     type = \"vision\"\n",
    "#     dataset = [\"cifar10\", \"cifar100_coarse\", \"cifar100_fine\", \"fashion\", \"mnist\", \"n24news_image\"]\n",
    "# else:\n",
    "#     type = \"text\"\n",
    "#     dataset = [\"dbpedia\", \"n24news_text\", \"trec\"]\n",
    "\n",
    "# fig, axs = plt.subplots(1, len(dataset), figsize=(18, 6), sharey=True)\n",
    "\n",
    "# colors = plt.cm.tab10(np.linspace(0, 1, len(df_plot[\"projection_func\"].unique())))\n",
    "\n",
    "# for i, dataset in enumerate(dataset):\n",
    "#     dataset_data = df_plot[df_plot[\"dataset\"] == dataset]\n",
    "#     projection_funcs = dataset_data[\"projection_func\"]\n",
    "#     scores = dataset_data[\"score\"]\n",
    "#     x_pos = np.arange(len(projection_funcs))\n",
    "#     axs[i].bar(x_pos, scores, color=colors)\n",
    "\n",
    "#     axs[i].set_xlabel(\"\")\n",
    "#     axs[i].set_ylabel(\"\")\n",
    "\n",
    "#     axs[i].set_title(f\"{dataset}\")\n",
    "#     axs[i].set_xticks(x_pos)\n",
    "#     axs[i].set_xticklabels(\"\")\n",
    "#     axs[i].set_xticklabels(projection_funcs, rotation=45, ha=\"right\")\n",
    "\n",
    "# # unique_projection_funcs = df_plot['projection_func'].unique()\n",
    "# # legend_handles = [plt.Rectangle((0, 0), 1, 1, color=colors[i]) for i in range(len(unique_projection_funcs))]\n",
    "# # legend = fig.legend(legend_handles, unique_projection_funcs, loc='upper right', bbox_to_anchor=(1.1, 0.9))\n",
    "# # plt.tight_layout(rect=[0.2, 0.3, 0.97, 1])\n",
    "\n",
    "# plt.savefig(f\"plots/hist_{type}_stitching.png\", bbox_inches=\"tight\", transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Grouped bar chart: one group of bars per dataset, one bar per projection function,\n",
    "# bar height = score_mean, error bar = score_std.\n",
    "df_plot = pd.read_csv(\"stitching_group_dataset.csv\", sep=\"\\t\")\n",
    "\n",
    "# Projection functions left out of the figure.\n",
    "removed_proj = [\"Normalized Absolute\", \"Standardized Euclidean\", \"L1\", \"L3\", \"Wasserstein\"]\n",
    "df_plot = df_plot[~df_plot.projection_func.isin(removed_proj)]\n",
    "\n",
    "mode = \"all\"  # one of: \"text\", \"vision\", \"all\"\n",
    "text_datasets = [\"trec\", \"n24news_text\", \"dbpedia\"]\n",
    "\n",
    "if mode == \"text\":\n",
    "    df_plot = df_plot[df_plot.dataset.isin(text_datasets)]\n",
    "elif mode == \"vision\":\n",
    "    df_plot = df_plot[~df_plot.dataset.isin(text_datasets)]\n",
    "\n",
    "# Order of appearance in the file. To force an order, use an explicit list instead, e.g.\n",
    "# ['cifar10', 'cifar100_coarse', 'cifar100_fine', 'fashion', 'mnist', 'n24news_image', 'n24news_text', 'dbpedia', 'trec']\n",
    "datasets = df_plot[\"dataset\"].unique()\n",
    "projection_funcs = df_plot[\"projection_func\"].unique()\n",
    "# plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated and later removed.\n",
    "color_map = plt.get_cmap(\"tab10\", 7)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "dataset_spacing = 0.55\n",
    "bar_width = 0.4\n",
    "\n",
    "# Every dataset gets one slot per projection function, followed by a gap.\n",
    "group_width = len(projection_funcs) * bar_width + dataset_spacing\n",
    "x_pos = np.arange(len(datasets)) * group_width\n",
    "\n",
    "for i, dataset in enumerate(datasets):\n",
    "    data_subset = df_plot[df_plot[\"dataset\"] == dataset]\n",
    "\n",
    "    for j, projection_func in enumerate(projection_funcs):\n",
    "        subset = data_subset[data_subset[\"projection_func\"] == projection_func]\n",
    "        if subset.empty:\n",
    "            # No result for this (dataset, projection) pair: leave its slot empty.\n",
    "            continue\n",
    "\n",
    "        score = subset[\"score_mean\"].values[0]\n",
    "        std = subset[\"score_std\"].values[0]\n",
    "        bar_x_pos = x_pos[i] + j * bar_width\n",
    "\n",
    "        ax.bar(bar_x_pos, score, width=bar_width, color=color_map(j), zorder=2)\n",
    "        ax.errorbar(bar_x_pos, score, yerr=std, color=\"black\", capsize=3, zorder=3)\n",
    "\n",
    "# Put each tick at the centre of its group, whatever the number of bars per group.\n",
    "ax.set_xticks(x_pos + (len(projection_funcs) - 1) * bar_width / 2)\n",
    "ax.set_xticklabels(datasets, rotation=45, ha=\"right\")\n",
    "ax.set_ylabel(\"Classification Accuracy\")\n",
    "ax.set_xlabel(\"\")\n",
    "ax.grid(axis=\"y\", color=\"gray\", linestyle=\"-\", alpha=0.3, zorder=1)\n",
    "\n",
    "legend_elements = [\n",
    "    plt.Rectangle((0, 0), 1, 1, color=color_map(j), label=projection_func)\n",
    "    for j, projection_func in enumerate(projection_funcs)\n",
    "]\n",
    "ax.legend(\n",
    "    handles=legend_elements,\n",
    "    bbox_to_anchor=(0.5, 1.1),\n",
    "    loc=\"upper center\",\n",
    "    ncol=len(projection_funcs),\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing directories.\n",
    "Path(\"plots\").mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(f\"plots/hist_{mode}_stitching.svg\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the figure saved for the current `mode` to PDF, then delete the intermediate SVG.\n",
    "# IPython expands {mode} from the notebook namespace before running the command.\n",
    "!rsvg-convert -f pdf -o 'plots/hist_{mode}_stitching.pdf' 'plots/hist_{mode}_stitching.svg'\n",
    "!rm 'plots/hist_{mode}_stitching.svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average score per dataset, one bar per projection function, restricted to rows whose\n",
    "# encoding and decoding spaces differ.\n",
    "px.histogram(\n",
    "    data.loc[data.encoding_space != data.decoding_space],\n",
    "    x=\"dataset\",\n",
    "    y=\"score\",\n",
    "    color=\"projection_func\",\n",
    "    histfunc=\"avg\",\n",
    "    barmode=\"group\",\n",
    "    width=1200,\n",
    "    color_discrete_sequence=px.colors.qualitative.Plotly,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct encoder names present in the results.\n",
    "set(data[\"encoding_space\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-letter type code for every encoder name.\n",
    "# NOTE(review): presumably \"T\" = transformer-based and \"N\" = non-transformer (convolutional) — confirm.\n",
    "encoder2type = {\n",
    "    \"roberta-base\": \"T\",\n",
    "    \"vit_base_patch16_224\": \"T\",\n",
    "    \"vit_base_patch16_384\": \"T\",\n",
    "    \"vit_base_resnet50_384\": \"T\",\n",
    "    \"vit_small_patch16_224\": \"T\",\n",
    "    \"xlm-roberta-base\": \"T\",\n",
    "    \"albert-base-v2\": \"T\",\n",
    "    \"bert-base-cased\": \"T\",\n",
    "    \"bert-base-uncased\": \"T\",\n",
    "    \"cspdarknet53\": \"N\",\n",
    "    \"google/electra-base-discriminator\": \"T\",\n",
    "    \"openai/clip-vit-base-patch32\": \"T\",\n",
    "    \"rexnet_100\": \"N\",\n",
    "    \"efficientnet_b1_pruned\": \"N\",\n",
    "    \"regnety_002\": \"N\",\n",
    "    \"cspresnext50\": \"N\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-letter label per row: encoder type followed by decoder type (e.g. \"TN\").\n",
    "# Adding the two mapped Series concatenates the letters element-wise, so no tuple wrapping\n",
    "# is needed; .to_numpy() keeps the assignment positional.\n",
    "data[\"stitching_type\"] = (data.encoding_space.map(encoder2type) + data.decoding_space.map(encoder2type)).to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse every entry of `stitching_type` into a single string.\n",
    "data[\"stitching_type\"] = data[\"stitching_type\"].map(\"\".join)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# stitching_arch =(\n",
    "#     data[\n",
    "#         [\n",
    "#             \"dataset\",\n",
    "#             \"projection_func\",\n",
    "#             \"score\",\n",
    "#             \"stitching_type\"\n",
    "#         ]\n",
    "#     ]\n",
    "#     .groupby([\"dataset\", \"projection_func\", \"stitching_type\"])\n",
    "#     .agg([\"mean\", \"std\"])\n",
    "# )\n",
    "# stitching_arch.to_csv(\"stitching_arch.csv\", index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Grouped bar chart for a single dataset: one group of bars per projection function,\n",
    "# one bar per stitching type, bar height = score_mean, error bar = score_std.\n",
    "filter_dataset = \"cifar100_fine\"\n",
    "\n",
    "stitching_arch = pd.read_csv(\"stitching_arch.csv\")\n",
    "stitching_arch = stitching_arch[stitching_arch[\"dataset\"] == filter_dataset][\n",
    "    [\"projection_func\", \"stitching_type\", \"score_mean\", \"score_std\"]\n",
    "]\n",
    "\n",
    "# Projection functions left out of the figure.\n",
    "removed_proj = [\"Normalized Absolute\", \"Standardized Euclidean\", \"L1\", \"L3\", \"Wasserstein\"]\n",
    "stitching_arch = stitching_arch[~stitching_arch.projection_func.isin(removed_proj)]\n",
    "\n",
    "proj_funcs = stitching_arch[\"projection_func\"].unique()\n",
    "stitching_types = stitching_arch[\"stitching_type\"].unique()\n",
    "# plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated and later removed.\n",
    "color_map = plt.get_cmap(\"tab10\", 7)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "proj_spacing = 0.55\n",
    "bar_width = 0.4\n",
    "\n",
    "# Every projection function gets one slot per stitching type, followed by a gap.\n",
    "group_width = len(stitching_types) * bar_width + proj_spacing\n",
    "x_pos = np.arange(len(proj_funcs)) * group_width\n",
    "\n",
    "for i, proj in enumerate(proj_funcs):\n",
    "    data_subset = stitching_arch[stitching_arch[\"projection_func\"] == proj]\n",
    "\n",
    "    for j, stitching_type in enumerate(stitching_types):\n",
    "        matching_rows = data_subset[data_subset[\"stitching_type\"] == stitching_type]\n",
    "        if matching_rows.empty:\n",
    "            # No result for this (projection, stitching type) pair: leave its slot empty.\n",
    "            continue\n",
    "\n",
    "        score = matching_rows[\"score_mean\"].values[0]\n",
    "        std = matching_rows[\"score_std\"].values[0]\n",
    "        bar_x_pos = x_pos[i] + j * bar_width\n",
    "\n",
    "        ax.bar(bar_x_pos, score, width=bar_width, color=color_map(j), zorder=2)\n",
    "        ax.errorbar(bar_x_pos, score, yerr=std, color=\"black\", capsize=3, zorder=3)\n",
    "\n",
    "# Put each tick at the centre of its group, whatever the number of bars per group.\n",
    "ax.set_xticks(x_pos + (len(stitching_types) - 1) * bar_width / 2)\n",
    "ax.set_xticklabels(proj_funcs, rotation=45, ha=\"right\")\n",
    "ax.set_ylabel(\"Classification Accuracy\")\n",
    "ax.set_xlabel(\"\")\n",
    "ax.grid(axis=\"y\", color=\"gray\", linestyle=\"-\", alpha=0.3, zorder=1)\n",
    "\n",
    "legend_elements = [\n",
    "    plt.Rectangle((0, 0), 1, 1, color=color_map(j), label=stitching_type)\n",
    "    for j, stitching_type in enumerate(stitching_types)\n",
    "]\n",
    "legend_ncol = min(len(stitching_types), 6)\n",
    "ax.legend(\n",
    "    handles=legend_elements,\n",
    "    bbox_to_anchor=(0.5, 1.15),\n",
    "    loc=\"upper center\",\n",
    "    ncol=legend_ncol,\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing directories.\n",
    "Path(\"plots\").mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(f\"plots/hist_arch_{filter_dataset}_stitching.svg\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the figure saved for `filter_dataset` to PDF, then delete the intermediate SVG.\n",
    "# IPython expands {filter_dataset} from the notebook namespace before running the command.\n",
    "!rsvg-convert -f pdf -o 'plots/hist_arch_{filter_dataset}_stitching.pdf' 'plots/hist_arch_{filter_dataset}_stitching.svg'\n",
    "!rm 'plots/hist_arch_{filter_dataset}_stitching.svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average score per projection function, one bar per stitching type and one facet row per\n",
    "# dataset, restricted to rows whose encoding and decoding spaces differ.\n",
    "px.histogram(\n",
    "    data.loc[data.encoding_space != data.decoding_space],\n",
    "    x=\"projection_func\",\n",
    "    y=\"score\",\n",
    "    color=\"stitching_type\",\n",
    "    facet_row=\"dataset\",\n",
    "    histfunc=\"avg\",\n",
    "    barmode=\"group\",\n",
    "    height=1200,\n",
    "    color_discrete_sequence=px.colors.qualitative.Plotly,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows for cifar10 / \"Absolute\" projection / seed 42 whose encoding and decoding spaces differ.\n",
    "data.loc[\n",
    "    (data.dataset == \"cifar10\")\n",
    "    & (data.projection_func == \"Absolute\")\n",
    "    & (data.seed == 42)\n",
    "    & (data.encoding_space != data.decoding_space)\n",
    "]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "latent-invariances",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
