{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d700e119",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.ui.evaluation import parse_checkpoints_tree\n",
    "import logging\n",
    "from collections import defaultdict\n",
    "from enum import auto\n",
    "from pathlib import Path\n",
    "from typing import Callable, Dict, Optional, Tuple, Type, Union\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import rich\n",
    "import torch\n",
    "import typer\n",
    "from torchmetrics import (\n",
    "    ErrorRelativeGlobalDimensionlessSynthesis,\n",
    "    MeanSquaredError,\n",
    "    MetricCollection,\n",
    "    MultiScaleStructuralSimilarityIndexMeasure,\n",
    "    PeakSignalNoiseRatio,\n",
    "    StructuralSimilarityIndexMeasure,\n",
    ")\n",
    "\n",
    "from rae.modules.enumerations import Output\n",
    "from rae.pl_modules.pl_gautoencoder import LightningAutoencoder\n",
    "\n",
    "try:\n",
    "    # enum.StrEnum is part of the standard library from Python 3.11 onwards\n",
    "    from enum import StrEnum\n",
    "except ImportError:\n",
    "    # older interpreters: fall back to the backports.strenum package\n",
    "    from backports.strenum import StrEnum\n",
    "\n",
    "# Raise the root logger threshold so that only ERROR and above are emitted.\n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "\n",
    "# NOTE(review): not referenced by any later cell in this notebook.\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "\n",
    "# NOTE(review): Path(\".\").parent is still Path(\".\"), so the root is the current working\n",
    "# directory rather than its parent - confirm that this is the intended location.\n",
    "EXPERIMENT_ROOT = Path(\".\").parent\n",
    "EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / \"checkpoints\"\n",
    "PREDICTIONS_TSV = EXPERIMENT_ROOT / \"predictions.tsv\"\n",
    "PERFORMANCE_TSV = EXPERIMENT_ROOT / \"performance.tsv\"\n",
    "\n",
    "# Dataset key -> (dotted path, split name); presumably the expected dataset class per key.\n",
    "# NOTE(review): every key maps to FashionMNISTDataset, which looks like a copy-paste\n",
    "# placeholder - confirm the intended class for mnist / cifar10 / cifar100.\n",
    "DATASET_SANITY = {\n",
    "    \"mnist\": (\"rae.data.vision.fmnist.FashionMNISTDataset\", \"test\"),\n",
    "    \"fmnist\": (\"rae.data.vision.fmnist.FashionMNISTDataset\", \"test\"),\n",
    "    \"cifar10\": (\"rae.data.vision.fmnist.FashionMNISTDataset\", \"test\"),\n",
    "    \"cifar100\": (\"rae.data.vision.fmnist.FashionMNISTDataset\", \"test\"),\n",
    "}\n",
    "# Model key -> dotted path; presumably the expected model class per key.\n",
    "# NOTE(review): every key maps to VanillaAE - confirm the intended class for each model.\n",
    "MODEL_SANITY = {\n",
    "    \"vae\": \"rae.modules.ae.VanillaAE\",\n",
    "    \"ae\": \"rae.modules.ae.VanillaAE\",\n",
    "    \"relvae\": \"rae.modules.ae.VanillaAE\",\n",
    "    \"relae\": \"rae.modules.ae.VanillaAE\",\n",
    "}\n",
    "\n",
    "\n",
    "# Index the checkpoints directory; the result is unpacked as (checkpoints, RUNS).\n",
    "checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)\n",
    "\n",
    "\n",
    "# Sorted top-level keys of `checkpoints`.\n",
    "DATASETS = sorted(checkpoints.keys())\n",
    "# Model keys are read from the first dataset only; raises IndexError when `checkpoints`\n",
    "# is empty. NOTE(review): presumably every dataset holds the same models - confirm.\n",
    "MODELS = sorted(checkpoints[DATASETS[0]].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b05a0547",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Read the predictions table: tab-separated, first column used as the index.\n",
    "preds = pd.read_csv(\n",
    "    PREDICTIONS_TSV,\n",
    "    sep=\"\\t\",\n",
    "    index_col=0,\n",
    ")\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "281095ff",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Read the performance table: tab-separated, first column used as the index.\n",
    "perf = pd.read_csv(\n",
    "    PERFORMANCE_TSV,\n",
    "    sep=\"\\t\",\n",
    "    index_col=0,\n",
    ")\n",
    "perf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "485240fa",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Drop the run identifier so it is not aggregated as if it were a metric.\n",
    "# This frame is not aggregated yet, hence the stage-specific name.\n",
    "perf_without_run_id = perf.drop(columns=[\"run_id\"])\n",
    "perf_without_run_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "262bb6e4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Mean and standard deviation of every remaining column per (dataset, model) pair.\n",
    "# The string aliases select pandas' own groupby reductions (sample std, ddof=1), which is\n",
    "# what np.mean / np.std were silently translated to; passing the NumPy callables is\n",
    "# deprecated from pandas 2.1 and would later switch to np.std's ddof=0.\n",
    "aggregated_perfomance = perf_without_run_id.groupby(\n",
    "    [\"dataset_name\", \"model_type\"]\n",
    ").agg([\"mean\", \"std\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02421690",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Keep four decimal places in every aggregated statistic, then display the table.\n",
    "aggregated_perfomance = aggregated_perfomance.round(decimals=4)\n",
    "aggregated_perfomance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7627c5f6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Keep the four reported metrics and order the rows by the listed model and dataset labels.\n",
    "# NOTE(review): the model labels here are spelled rel_ae / rel_vae, while MODEL_SANITY in\n",
    "# the first cell uses relae / relvae - confirm which spelling performance.tsv contains.\n",
    "aggregated_perfomance = (\n",
    "    aggregated_perfomance[[\"mse\", \"ergas\", \"psnr\", \"ssim\"]]\n",
    "    .reindex([\"ae\", \"vae\", \"rel_ae\", \"rel_vae\"], level=\"model_type\")\n",
    "    .reindex([\"mnist\", \"fmnist\", \"cifar10\", \"cifar100\"], level=\"dataset_name\")\n",
    ")\n",
    "# Show the final table; without a trailing expression the cell renders nothing.\n",
    "aggregated_perfomance"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
