{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed66b75d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# A notebook kernel does not define __file__; set it by hand so that the\n",
    "# Path(__file__)-based constants defined in the next cell can be evaluated.\n",
    "__file__ = \".\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eef8fce8",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "from collections import defaultdict\n",
    "from pathlib import Path\n",
    "from typing import Callable, Dict, Optional, Tuple, Type, Union\n",
    "\n",
    "import hydra\n",
    "import pandas as pd\n",
    "import torch\n",
    "import typer\n",
    "from omegaconf import DictConfig, OmegaConf\n",
    "from pytorch_lightning import LightningModule\n",
    "from torch.utils.data import DataLoader\n",
    "from torchmetrics import Accuracy, F1Score, MetricCollection\n",
    "from tqdm import tqdm\n",
    "\n",
    "from nn_core.serialization import NNCheckpointIO\n",
    "\n",
    "from rae.data.vision.datamodule import MyDataModule\n",
    "from rae.modules.enumerations import Output\n",
    "from rae.pl_modules.vision.pl_gclassifier import LightningClassifier\n",
    "\n",
    "# Only let ERROR-and-above records through the root logger.\n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "\n",
    "# Inference runs on the GPU when one is available, otherwise on the CPU.\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "BATCH_SIZE = 32\n",
    "# Metric name -> factory building that metric for a given number of classes.\n",
    "# The averaging mode requested by each factory matches the suffix of its key.\n",
    "CONSIDERED_METRICS = {\n",
    "    \"acc/weighted\": lambda num_classes: Accuracy(average=\"weighted\", num_classes=num_classes),\n",
    "    \"acc/micro\": lambda num_classes: Accuracy(average=\"micro\", num_classes=num_classes),\n",
    "    \"acc/macro\": lambda num_classes: Accuracy(average=\"macro\", num_classes=num_classes),\n",
    "    \"f1/macro\": lambda num_classes: F1Score(average=\"macro\", num_classes=num_classes),\n",
    "    \"f1/micro\": lambda num_classes: F1Score(average=\"micro\", num_classes=num_classes),\n",
    "}\n",
    "\n",
    "# Checkpoints are read from, and the TSV caches written to, the directory\n",
    "# derived from __file__.\n",
    "EXPERIMENT_ROOT = Path(__file__).parent\n",
    "EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / \"checkpoints\"\n",
    "PREDICTIONS_TSV = EXPERIMENT_ROOT / \"predictions.tsv\"\n",
    "PERFORMANCE_TSV = EXPERIMENT_ROOT / \"performance.tsv\"\n",
    "\n",
    "# Dataset abbreviation -> (expected dataset class path, expected split).\n",
    "DATASET_SANITY = {\n",
    "    \"mnist\": (\"rae.data.vision.mnist.MNISTDataset\", \"test\"),\n",
    "    \"fmnist\": (\"rae.data.vision.fmnist.FashionMNISTDataset\", \"test\"),\n",
    "    \"cifar10\": (\"rae.data.vision.cifar10.CIFAR10Dataset\", \"test\"),\n",
    "    \"cifar100\": (\"rae.data.vision.cifar100.CIFAR100Dataset\", \"test\"),\n",
    "}\n",
    "# Model-type abbreviation -> expected class path of the wrapped network.\n",
    "MODEL_SANITY = {\n",
    "    \"abs\": \"rae.modules.vision.resnet.ResNet\",\n",
    "    \"rel\": \"rae.modules.vision.relresnet.RelResNet\",\n",
    "}\n",
    "\n",
    "\n",
    "def parse_checkpoint_id(ckpt: Path) -> str:\n",
    "    \"\"\"Return the file name of ``ckpt`` with up to two trailing suffixes removed.\n",
    "\n",
    "    For example ``<run_id>.ckpt.zip`` yields ``<run_id>``.\n",
    "    \"\"\"\n",
    "    stripped = ckpt\n",
    "    for _ in range(2):\n",
    "        stripped = stripped.with_suffix(\"\")\n",
    "    return stripped.name\n",
    "\n",
    "\n",
    "# Parse checkpoints tree: checkpoints/<dataset>/<model type>/<checkpoint file>.\n",
    "# Directory listings are sorted so that the order of datasets, model types and\n",
    "# runs does not depend on filesystem enumeration order, and stray files at the\n",
    "# dataset / model-type levels are skipped instead of failing on iterdir().\n",
    "checkpoints = defaultdict(dict)\n",
    "RUNS = defaultdict(dict)\n",
    "for dataset_abbrv in sorted(EXPERIMENT_CHECKPOINTS.iterdir()):\n",
    "    if not dataset_abbrv.is_dir():\n",
    "        continue\n",
    "    checkpoints[dataset_abbrv.name] = defaultdict(list)\n",
    "    RUNS[dataset_abbrv.name] = defaultdict(list)\n",
    "    for model_abbrv in sorted(dataset_abbrv.iterdir()):\n",
    "        if not model_abbrv.is_dir():\n",
    "            continue\n",
    "        for ckpt in sorted(model_abbrv.iterdir()):\n",
    "            checkpoints[dataset_abbrv.name][model_abbrv.name].append(ckpt)\n",
    "            RUNS[dataset_abbrv.name][model_abbrv.name].append(parse_checkpoint_id(ckpt))\n",
    "\n",
    "\n",
    "DATASETS = sorted(checkpoints.keys())\n",
    "# Number of target classes per dataset, used to size the metrics.\n",
    "DATASET_NUM_CLASSES = {\n",
    "    \"mnist\": 10,\n",
    "    \"fmnist\": 10,\n",
    "    \"cifar10\": 10,\n",
    "    \"cifar100\": 100,\n",
    "}\n",
    "# Model types are taken from the first dataset; an empty checkpoint tree yields\n",
    "# an empty list instead of an IndexError.\n",
    "MODELS = sorted(checkpoints[DATASETS[0]].keys()) if DATASETS else []\n",
    "\n",
    "\n",
    "def parse_checkpoint(\n",
    "    module_class: Type[LightningModule],\n",
    "    checkpoint_path: Path,\n",
    "    map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,\n",
    ") -> Tuple[LightningModule, DictConfig]:\n",
    "    \"\"\"Load a ``.ckpt.zip`` checkpoint and return the model together with its config.\n",
    "\n",
    "    Args:\n",
    "        module_class: Lightning module class used to rebuild the model.\n",
    "        checkpoint_path: path of the checkpoint; its name must end in ``.ckpt.zip``.\n",
    "        map_location: forwarded to ``NNCheckpointIO.load``.\n",
    "\n",
    "    Returns:\n",
    "        The model switched to eval mode and the ``cfg`` entry of the checkpoint\n",
    "        wrapped in an OmegaConf config.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the file name does not end in ``.ckpt.zip``.\n",
    "    \"\"\"\n",
    "    if not checkpoint_path.name.endswith(\".ckpt.zip\"):\n",
    "        raise ValueError(f\"Wrong checkpoint: {checkpoint_path}\")\n",
    "\n",
    "    state = NNCheckpointIO.load(path=checkpoint_path, map_location=map_location)\n",
    "    module = module_class._load_model_state(checkpoint=state, metadata=state.get(\"metadata\", None))\n",
    "    module.eval()\n",
    "    config = OmegaConf.create(state[\"cfg\"])\n",
    "    return module, config\n",
    "\n",
    "\n",
    "def compute_predictions(force_predict: bool) -> pd.DataFrame:\n",
    "    \"\"\"Predict every evaluation sample with every checkpoint, caching the table as TSV.\n",
    "\n",
    "    Args:\n",
    "        force_predict: when False and PREDICTIONS_TSV exists, the cached table is\n",
    "            returned as-is; otherwise the cache is removed and recomputed.\n",
    "\n",
    "    Returns:\n",
    "        A DataFrame with the columns ``run_id``, ``model_type``, ``dataset_name``,\n",
    "        ``sample_idx``, ``pred`` and ``target``, one row per predicted sample.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the instantiated dataset or a loaded model is not of the\n",
    "            class registered in DATASET_SANITY / MODEL_SANITY.\n",
    "        ValueError: propagated from parse_checkpoint for files not ending in ``.ckpt.zip``.\n",
    "    \"\"\"\n",
    "    if PREDICTIONS_TSV.exists() and not force_predict:\n",
    "        return pd.read_csv(PREDICTIONS_TSV, sep=\"\\t\", index_col=0)\n",
    "    PREDICTIONS_TSV.unlink(missing_ok=True)\n",
    "\n",
    "    predictions = {x: [] for x in (\"run_id\", \"model_type\", \"dataset_name\", \"sample_idx\", \"pred\", \"target\")}\n",
    "    for dataset_name in (dataset_tqdm := tqdm(DATASETS, leave=True)):\n",
    "        dataset_tqdm.set_description(f\"Dataset ({dataset_name})\")\n",
    "        # The data config stored in the first checkpoint of the first model type is\n",
    "        # used to rebuild the datamodule shared by all runs on this dataset.\n",
    "        _, data_cfg = parse_checkpoint(\n",
    "            module_class=LightningClassifier,\n",
    "            checkpoint_path=checkpoints[dataset_name][MODELS[0]][0],\n",
    "            map_location=\"cpu\",\n",
    "        )\n",
    "\n",
    "        datamodule: MyDataModule = hydra.utils.instantiate(data_cfg.nn.data, _recursive_=False)\n",
    "        datamodule.setup()\n",
    "        val_dataset = datamodule.val_datasets[0]\n",
    "        val_dataloader = DataLoader(\n",
    "            val_dataset,\n",
    "            batch_size=BATCH_SIZE,\n",
    "            shuffle=False,\n",
    "            num_workers=8,\n",
    "            persistent_workers=True,\n",
    "        )\n",
    "        # Guard against evaluating on a different dataset class or split than expected.\n",
    "        assert (\n",
    "            f\"{val_dataset.__module__}.{val_dataset.__class__.__name__}\" == DATASET_SANITY[dataset_name][0]\n",
    "        ), f\"{val_dataset.__module__}.{val_dataset.__class__.__name__}!={DATASET_SANITY[dataset_name][0]}\"\n",
    "        assert val_dataset.split == DATASET_SANITY[dataset_name][1]\n",
    "\n",
    "        for model_type in (model_type_tqdm := tqdm(MODELS, leave=True)):\n",
    "            model_type_tqdm.set_description(f\"Model type ({model_type})\")\n",
    "\n",
    "            for ckpt in (ckpt_tqdm := tqdm(checkpoints[dataset_name][model_type], leave=False)):\n",
    "                run_id = parse_checkpoint_id(ckpt)\n",
    "                ckpt_tqdm.set_description(f\"Run id ({run_id})\")\n",
    "\n",
    "                model, _ = parse_checkpoint(\n",
    "                    module_class=LightningClassifier,\n",
    "                    checkpoint_path=ckpt,\n",
    "                    map_location=\"cpu\",\n",
    "                )\n",
    "                assert (\n",
    "                    f\"{model.model.__module__}.{model.model.__class__.__name__}\" == MODEL_SANITY[model_type]\n",
    "                ), f\"{model.model.__module__}.{model.model.__class__.__name__}!={MODEL_SANITY[model_type]}\"\n",
    "                model = model.to(DEVICE)\n",
    "                # Inference only: disabling gradient tracking avoids building an\n",
    "                # autograd graph (and keeping its activations) for every batch.\n",
    "                with torch.no_grad():\n",
    "                    for batch in tqdm(val_dataloader, desc=\"Batch\", leave=False):\n",
    "                        model_out = model(batch[\"image\"].to(DEVICE))\n",
    "                        pred = model_out[Output.LOGITS].argmax(-1).cpu()\n",
    "\n",
    "                        sample_indices = batch[\"index\"].cpu().tolist()\n",
    "                        batch_size = len(sample_indices)\n",
    "                        predictions[\"run_id\"].extend([run_id] * batch_size)\n",
    "                        predictions[\"model_type\"].extend([model_type] * batch_size)\n",
    "                        predictions[\"dataset_name\"].extend([dataset_name] * batch_size)\n",
    "                        predictions[\"sample_idx\"].extend(sample_indices)\n",
    "                        predictions[\"pred\"].extend(pred.tolist())\n",
    "                        predictions[\"target\"].extend(batch[\"target\"].cpu().tolist())\n",
    "                        del model_out\n",
    "                        del batch\n",
    "                # Move the weights off DEVICE before dropping the reference.\n",
    "                model.cpu()\n",
    "                del model\n",
    "\n",
    "    predictions_df = pd.DataFrame(predictions)\n",
    "    predictions_df.to_csv(PREDICTIONS_TSV, sep=\"\\t\")\n",
    "    return predictions_df\n",
    "\n",
    "\n",
    "def measure_predictions(predictions_df: pd.DataFrame, force_measure: bool) -> pd.DataFrame:\n",
    "    \"\"\"Score the predictions of every run with CONSIDERED_METRICS, caching the table as TSV.\n",
    "\n",
    "    Args:\n",
    "        predictions_df: table with at least the ``run_id``, ``pred`` and ``target`` columns.\n",
    "        force_measure: when False and PERFORMANCE_TSV exists, the cached table is\n",
    "            returned as-is; otherwise the cache is removed and recomputed.\n",
    "\n",
    "    Returns:\n",
    "        A DataFrame with one row per run listed in RUNS and the columns\n",
    "        ``model_type``, ``dataset_name`` plus one column per considered metric.\n",
    "    \"\"\"\n",
    "    cache_is_usable = PERFORMANCE_TSV.exists() and not force_measure\n",
    "    if cache_is_usable:\n",
    "        return pd.read_csv(PERFORMANCE_TSV, sep=\"\\t\", index_col=0)\n",
    "    PERFORMANCE_TSV.unlink(missing_ok=True)\n",
    "\n",
    "    columns = (\"model_type\", \"dataset_name\", *CONSIDERED_METRICS.keys())\n",
    "    performance = {column: [] for column in columns}\n",
    "\n",
    "    for dataset_name, runs_by_model in RUNS.items():\n",
    "        for model_type, run_ids in runs_by_model.items():\n",
    "            for run_id in run_ids:\n",
    "                # Fresh metric objects per run so that no state is shared across runs.\n",
    "                collection = MetricCollection(\n",
    "                    {\n",
    "                        name: build(num_classes=DATASET_NUM_CLASSES[dataset_name])\n",
    "                        for name, build in CONSIDERED_METRICS.items()\n",
    "                    }\n",
    "                )\n",
    "                belongs_to_run = predictions_df[\"run_id\"] == run_id\n",
    "                run_rows = predictions_df[belongs_to_run]\n",
    "                collection.update(\n",
    "                    torch.as_tensor(run_rows[\"pred\"].values),\n",
    "                    torch.as_tensor(run_rows[\"target\"].values),\n",
    "                )\n",
    "\n",
    "                performance[\"dataset_name\"].append(dataset_name)\n",
    "                performance[\"model_type\"].append(model_type)\n",
    "                for metric_name, metric_value in collection.compute().items():\n",
    "                    performance[metric_name].append(metric_value.item())\n",
    "\n",
    "    performance_df = pd.DataFrame(performance)\n",
    "    performance_df.to_csv(PERFORMANCE_TSV, sep=\"\\t\")\n",
    "    return performance_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40deb428",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Leave both flags at False to reuse the cached TSV files when they exist.\n",
    "force_predict = False\n",
    "force_measure = False\n",
    "predictions_df = compute_predictions(force_predict=force_predict)\n",
    "performance_df = measure_predictions(predictions_df, force_measure=force_measure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeb89827",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Display the metrics table returned by measure_predictions.\n",
    "performance_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5d6bd26",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Mean and standard deviation of every metric over the runs of each\n",
    "# (dataset, model type) pair, scaled by 100 and rounded to two decimals.\n",
    "# The reducers are named by string: pandas deprecates passing np.mean / np.std\n",
    "# to agg, and the strings keep the current behaviour (sample standard deviation).\n",
    "aggregated_perfomance = performance_df.groupby(\n",
    "    [\n",
    "        \"dataset_name\",\n",
    "        \"model_type\",\n",
    "    ]\n",
    ").agg([\"mean\", \"std\"])\n",
    "aggregated_perfomance = (aggregated_perfomance * 100).round(2)\n",
    "aggregated_perfomance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06531537",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Order of the dataset columns in the LaTeX table.\n",
    "COLUMN_ORDER = [\"mnist\", \"cmnist\", \"fmnist\", \"cifar10\", \"cifar100\", \"shapenet\", \"faust\", \"coma\", \"amz\"]\n",
    "METRIC_CONSIDERED = \"f1/macro\"\n",
    "\n",
    "# One LaTeX table row per model type with one slot per COLUMN_ORDER entry.\n",
    "# Each slot receives an already formatted cell (\"$mean \\pm std$\" or \"?\"), so the\n",
    "# templates add no \\pm of their own. Raw strings keep the backslashes of the\n",
    "# trailing \\\\[1ex] row terminator intact.\n",
    "classification_rel = r\"Relative & {} & {} & {} & {} & {} & {} & {} & {} & {} \\\\[1ex]\"\n",
    "classification_abs = r\"Absolute & {} & {} & {} & {} & {} & {} & {} & {} & {} \\\\[1ex]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef5ad69e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Keep only the mean / std columns of the metric selected by METRIC_CONSIDERED.\n",
    "# NOTE(review): reindex() is called without arguments, so the\n",
    "# (dataset_name, model_type) index is presumably left unchanged - confirm.\n",
    "df = aggregated_perfomance[METRIC_CONSIDERED].reindex()\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43db67e1",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Debug check: show the type of the selected table.\n",
    "type(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f21a2c07",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a0892e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Display the per-(dataset, model type) mean / std of the selected metric.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa34c383",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def extract_mean_std(df, dataset_name, model_type):\n",
    "    \"\"\"Format the ``mean`` / ``std`` entries of one table row as a LaTeX math cell.\n",
    "\n",
    "    The row is looked up with ``df.loc[dataset_name, model_type]``. Returns\n",
    "    ``\"?\"`` when the lookup raises ``KeyError`` or ``AttributeError``, e.g.\n",
    "    because the pair is missing from the table.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        row = df.loc[dataset_name, model_type]\n",
    "        mean, std = row[\"mean\"], row[\"std\"]\n",
    "    except (AttributeError, KeyError):\n",
    "        return \"?\"\n",
    "    return rf\"${mean} \\pm {std}$\"\n",
    "\n",
    "\n",
    "# Fill the \"Relative\" row, one cell per dataset in COLUMN_ORDER.\n",
    "classification_rel.format(*(extract_mean_std(df, name, \"rel\") for name in COLUMN_ORDER))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b687f4d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Debug check of the fallback path: extract_mean_std returns \"?\" when the lookup\n",
    "# raises KeyError - presumably the case for \"2rel\"; confirm against the table.\n",
    "print(extract_mean_std(df, \"cifar10\", \"2rel\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "380c69fd",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20ef97d2",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2993e18",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
