{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from collections import defaultdict\n",
    "import os.path as osp\n",
    "from typing import Union, Dict, List\n",
    "from itertools import combinations\n",
    "\n",
    "from wilds.datasets.wilds_dataset import WILDSDataset\n",
    "\n",
    "def get_eval_meta_args(logs, y_index, m_idx, split):\n",
    "    num_batches = len(logs[f\"{split}-metas\"])\n",
    "    logits_model = logs[f\"{split}-logits\"][m_idx]\n",
    "    y_true = []\n",
    "    y_pred= []\n",
    "    metadata = []\n",
    "    for batch_num in range(num_batches):\n",
    "        curr_metadata= np.array(logs[f\"{split}-metas\"][batch_num])\n",
    "        curr_y_true = curr_metadata[:,y_index]\n",
    "        curr_logits = np.array(logits_model[batch_num])\n",
    "        curr_y_pred = np.argmax(curr_logits, axis=1)\n",
    "\n",
    "        y_true = y_true + curr_y_true.tolist()\n",
    "        y_pred = y_pred + curr_y_pred.tolist()\n",
    "        metadata = metadata + curr_metadata.tolist()\n",
    "        \n",
    "\n",
    "    y_true = torch.tensor(y_true)\n",
    "    y_pred= torch.tensor(y_pred)\n",
    "    metadata= torch.tensor(metadata)\n",
    "\n",
    "    return y_pred, y_true, metadata\n",
    "\n",
    "def _append_meta_metrics(res, prefix, dataset, y_pred, y_true, metadata, meta_metrics, inverse):\n",
    "    \"\"\"Run ``dataset.eval`` on one model's predictions and append the requested metrics to ``res``.\n",
    "\n",
    "    Values go to ``res[f\"{prefix}_{metric}\"]``; when ``inverse`` is set, the flipped\n",
    "    predictions ``1 - y_pred`` are evaluated too and stored under\n",
    "    ``res[f\"{prefix}_inverse_{metric}\"]``.\n",
    "    \"\"\"\n",
    "    # NOTE(review): eval_res[0] is presumably the metrics dict of WILDS' (results, results_str) pair.\n",
    "    eval_res = dataset.eval(y_pred=y_pred, y_true=y_true, metadata=metadata)\n",
    "    for meta_metric in meta_metrics:\n",
    "        res[f\"{prefix}_{meta_metric}\"].append(eval_res[0][meta_metric])\n",
    "\n",
    "    if inverse:\n",
    "        # 1 - y_pred swaps classes 0 and 1, so this is only meaningful for binary labels.\n",
    "        eval_res = dataset.eval(y_pred=1 - y_pred, y_true=y_true, metadata=metadata)\n",
    "        for meta_metric in meta_metrics:\n",
    "            res[f\"{prefix}_inverse_{meta_metric}\"].append(eval_res[0][meta_metric])\n",
    "\n",
    "\n",
    "def get_results(filename_format, dataset:WILDSDataset, seeds:List[int], ensemble_size:int, meta_metrics:Union[List[str],None], inverse:bool, show_val:bool):\n",
    "    \"\"\"Collect per-seed metrics of an ensemble run, one list entry per seed found.\n",
    "\n",
    "    For every seed, reads ``osp.join(filename_format.format(seed=seed), \"summary.json\")``;\n",
    "    seeds whose file does not exist are skipped.\n",
    "\n",
    "    Args:\n",
    "        filename_format: run-directory template containing a ``{seed}`` field.\n",
    "        dataset: dataset whose ``metadata_fields`` locate the label column and whose\n",
    "            ``eval`` computes the ``meta_metrics``.\n",
    "        seeds: seeds to look up.\n",
    "        ensemble_size: number of ensemble members stored in each log.\n",
    "        meta_metrics: keys of the metrics returned by ``dataset.eval`` to record per\n",
    "            model; ``None`` is treated as an empty list.\n",
    "        inverse: also record the metrics of the flipped predictions ``1 - y_pred``.\n",
    "        show_val: also record the metrics on the val split when the log has \"val-logits\".\n",
    "\n",
    "    Returns:\n",
    "        defaultdict(list) mapping column name -> list of per-seed values.\n",
    "        NOTE: val columns are only appended for seeds whose log has \"val-logits\";\n",
    "        if that differs between seeds the lists have different lengths and\n",
    "        ``pd.DataFrame(res)`` will raise.\n",
    "    \"\"\"\n",
    "    if meta_metrics is None:\n",
    "        meta_metrics = []\n",
    "    y_index = dataset.metadata_fields.index(\"y\")\n",
    "    pairwise_indexes = list(combinations(range(ensemble_size), 2))\n",
    "    res = defaultdict(list)\n",
    "\n",
    "    for seed in seeds:\n",
    "        filename = osp.join(filename_format.format(seed=seed), \"summary.json\")\n",
    "        if not osp.exists(filename):\n",
    "            continue\n",
    "        # Only parsing needs the open file; the evaluation below runs after it is closed.\n",
    "        with open(filename) as f:\n",
    "            logs = json.load(f)\n",
    "\n",
    "        res[\"test_acc_ensemble\"].append(logs[\"ensemble-test-acc\"])\n",
    "        for i in range(ensemble_size):\n",
    "            res[f\"test_m_{i+1}_acc\"].append(logs[\"test-acc\"][i])\n",
    "\n",
    "        # Mean similarity over all unordered model pairs (i < j); NaN when there are no pairs.\n",
    "        # NOTE(review): \"similarty\" must match the key spelling in summary.json; confirm before renaming.\n",
    "        test_sim = [logs[\"test_similarity\"][i][j] for i, j in pairwise_indexes]\n",
    "        unlabeled_sim = [logs[\"unlabeled_final_similarty\"][i][j] for i, j in pairwise_indexes]\n",
    "        res[\"test_similarity\"].append(np.mean(test_sim) if test_sim else np.nan)\n",
    "        res[\"unlabeled_final_similarity\"].append(np.mean(unlabeled_sim) if unlabeled_sim else np.nan)\n",
    "\n",
    "        # Per-group (worst-group) evaluation; skipped entirely when no metrics are requested.\n",
    "        if not meta_metrics:\n",
    "            continue\n",
    "        for m_idx in range(ensemble_size):\n",
    "            y_pred, y_true, test_metadata = get_eval_meta_args(logs=logs, y_index=y_index, m_idx=m_idx, split=\"test\")\n",
    "            _append_meta_metrics(res, f\"test_m_{m_idx+1}\", dataset, y_pred, y_true, test_metadata, meta_metrics, inverse)\n",
    "\n",
    "            if show_val and \"val-logits\" in logs:\n",
    "                y_pred, y_true, val_metadata = get_eval_meta_args(logs=logs, y_index=y_index, m_idx=m_idx, split=\"val\")\n",
    "                _append_meta_metrics(res, f\"val_m_{m_idx+1}\", dataset, y_pred, y_true, val_metadata, meta_metrics, inverse)\n",
    "\n",
    "    return res\n",
    "\n",
    "def display_results(filename_format:str,title:str, seeds:List[int], dataset:WILDSDataset, meta_metrics:Union[List[str],None]=None , ensemble_size=2 , inverse = False, show_val = False):\n",
    "    \"\"\"Print ``title`` and display the mean and std over seeds of each metric collected by ``get_results``.\n",
    "\n",
    "    Nothing is printed or displayed when ``get_results`` returns no rows (no\n",
    "    ``summary.json`` found for any seed). ``meta_metrics=None`` means no per-model\n",
    "    ``dataset.eval`` metrics are collected.\n",
    "    \"\"\"\n",
    "    res = get_results(\n",
    "        filename_format=filename_format,\n",
    "        dataset=dataset,\n",
    "        # get_results iterates over meta_metrics, so pass a list rather than None.\n",
    "        meta_metrics=[] if meta_metrics is None else meta_metrics,\n",
    "        seeds=seeds,\n",
    "        ensemble_size=ensemble_size,\n",
    "        inverse=inverse,\n",
    "        show_val=show_val,\n",
    "    )\n",
    "    df = pd.DataFrame(res)\n",
    "    if df.empty:\n",
    "        return\n",
    "    results = df.aggregate([\"mean\", \"std\"])\n",
    "\n",
    "    print(title)\n",
    "    # Scope the 3-decimal format to this table instead of overwriting the global pandas\n",
    "    # option, which would silently reformat every later output in the notebook.\n",
    "    # `display` comes from the IPython kernel namespace (it is not imported in this cell).\n",
    "    with pd.option_context(\"display.float_format\", \"{:,.3f}\".format):\n",
    "        display(results)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
