{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from collections import defaultdict\n",
    "import os.path as osp\n",
    "from typing import Union, Dict, List\n",
    "from itertools import combinations\n",
    "\n",
    "\n",
    "def get_results(filename_format, seeds:List[int], ensemble_size:int, inverse:bool, show_val:bool):\n",
    "\n",
    "    no_diversity=False\n",
    "    res= defaultdict(list)\n",
    "    \n",
    "    for seed in seeds:\n",
    "        filename= filename_format.format(seed=seed)\n",
    "        filename = osp.join(filename,\"summary.json\")\n",
    "\n",
    "        if not(osp.exists(filename)):\n",
    "            continue\n",
    "\n",
    "        with open(filename) as f:\n",
    "            logs=json.load(f)\n",
    "\n",
    "            res[\"test_acc_ensemble\"].append(logs[\"ensemble-test-acc\"])\n",
    "            #res[\"best_single_model_test_acc\"].append(max(logs[\"test-acc\"]))\n",
    "\n",
    "            for i in range(ensemble_size):\n",
    "                res[f\"test_m_{i+1}_acc\"].append(logs[\"test-acc\"][i])\n",
    "                print(f\"test_acc_{i}\",logs[\"test-acc\"][i])\n",
    "                #res[f\"val_m_{i+1}_acc\"].append(logs[f\"m{i+1}\"][\"valid-acc\"][-1][1])\n",
    "\n",
    "            pairwise_indexes = list(combinations(range(ensemble_size),2))\n",
    "            test_sim = []\n",
    "            unlabeled_sim = []\n",
    "            for pairwise_idx in pairwise_indexes:\n",
    "                i,j = pairwise_idx\n",
    "                test_sim.append(logs[\"test_similarity\"][i][j])\n",
    "                unlabeled_sim.append(logs[\"unlabeled_final_similarty\"][i][j])\n",
    "\n",
    "            res[\"test_similarity\"].append(np.array(test_sim).mean())\n",
    "            res[\"unlabeled_final_similarity\"].append(np.array(unlabeled_sim).mean())\n",
    "    \n",
    "    return res\n",
    "\n",
    "def display_results(filename_format:str,title:str, seeds:List[int], ensemble_size=2 , inverse = False, show_val = False):\n",
    "    \"\"\"Summarise per-seed results as a mean/std table.\n",
    "\n",
    "    Loads one row per available seed via ``get_results`` (forwarding\n",
    "    ``inverse`` and ``show_val``). ``title`` is accepted but not used here.\n",
    "\n",
    "    Returns:\n",
    "        ``(summary, True)`` where ``summary`` has rows ``\"mean\"`` and ``\"std\"``\n",
    "        when at least one seed was loaded; otherwise ``(empty_frame, False)``.\n",
    "    \"\"\"\n",
    "    per_seed = pd.DataFrame(\n",
    "        get_results(\n",
    "            filename_format=filename_format,\n",
    "            seeds=seeds,\n",
    "            ensemble_size=ensemble_size,\n",
    "            inverse=inverse,\n",
    "            show_val=show_val,\n",
    "        )\n",
    "    )\n",
    "    if not per_seed.empty:\n",
    "        return per_seed.aggregate([\"mean\", \"std\"]), True\n",
    "    return per_seed, False\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
