{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from models import *\n",
    "import torch.backends.cudnn as cudnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "0 9561 10000 17.169633550569415\n",
      "\n",
      "Files already downloaded and verified\n",
      "1 9540 10000 17.46842012926936\n",
      "\n",
      "Files already downloaded and verified\n",
      "2 9540 10000 17.126558613032103\n",
      "\n",
      "Files already downloaded and verified\n",
      "3 9544 10000 17.507671582512558\n",
      "\n",
      "Files already downloaded and verified\n",
      "4 9565 10000 17.41930754855275\n",
      "\n",
      "Files already downloaded and verified\n",
      "5 9561 10000 16.525208700448275\n",
      "\n",
      "Files already downloaded and verified\n",
      "6 9529 10000 17.414313331246376\n",
      "\n",
      "Files already downloaded and verified\n",
      "7 9539 10000 18.271342071704566\n",
      "\n",
      "Files already downloaded and verified\n",
      "8 9504 10000 18.213089518249035\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_folder = \"baseline_resnet34\"\n",
    "\n",
    "net_dict = {\n",
    "    \"baseline_resnet18\": ResNet18(),\n",
    "    \"baseline_resnet34\": ResNet34(),\n",
    "}\n",
    "\n",
    "for seed in range(0, 9):\n",
    "    net = net_dict[model_folder]\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    if device == 'cuda':\n",
    "        net = torch.nn.DataParallel(net)\n",
    "        cudnn.benchmark = True\n",
    "    net.load_state_dict(torch.load(f'./checkpoint/{model_folder}/ckpt_{seed}.pth')[\"net\"])\n",
    "\n",
    "    transform_test = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "    ])\n",
    "    testset = torchvision.datasets.CIFAR10(\n",
    "        root='./data', train=False, download=True, transform=transform_test)\n",
    "    testloader = torch.utils.data.DataLoader(\n",
    "        testset, batch_size=100, shuffle=False, num_workers=2)\n",
    "\n",
    "    classes = ('plane', 'car', 'bird', 'cat', 'deer',\n",
    "            'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    def test(net, testloader):\n",
    "        net.eval()\n",
    "        test_loss = 0\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        predictions = np.zeros((0, 10))\n",
    "        y = np.array([])\n",
    "        with torch.no_grad():\n",
    "            for batch_idx, (inputs, targets) in enumerate(testloader):\n",
    "                inputs, targets = inputs.to(device), targets.to(device)\n",
    "                outputs = net(inputs)\n",
    "                predictions = np.row_stack((predictions, outputs.cpu().numpy()))\n",
    "                y = np.append(y, targets.cpu().numpy())\n",
    "                loss = criterion(outputs, targets)\n",
    "                test_loss += loss.item()\n",
    "                _, predicted = outputs.max(1)\n",
    "                total += targets.size(0)\n",
    "                correct += predicted.eq(targets).sum().item()\n",
    "        print(seed, correct, total, test_loss)\n",
    "        print()\n",
    "        return predictions, y\n",
    "\n",
    "    predictions, y = test(net, testloader)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Before temperature - NLL: 0.172, ECE: 0.027\n",
      "Optimal temperature: 1.438\n",
      "After temperature - NLL: 0.156, ECE: 0.013\n",
      "0 7651 8000 12.437924511730671\n",
      "\n",
      "Before temperature - NLL: 0.193, ECE: 0.027\n",
      "Optimal temperature: 1.449\n",
      "After temperature - NLL: 0.173, ECE: 0.013\n",
      "1 7636 8000 12.428090248256922\n",
      "\n",
      "Before temperature - NLL: 0.174, ECE: 0.026\n",
      "Optimal temperature: 1.437\n",
      "After temperature - NLL: 0.158, ECE: 0.014\n",
      "2 7633 8000 12.494446732103825\n",
      "\n",
      "Before temperature - NLL: 0.190, ECE: 0.030\n",
      "Optimal temperature: 1.453\n",
      "After temperature - NLL: 0.168, ECE: 0.015\n",
      "3 7641 8000 12.457159992307425\n",
      "\n",
      "Before temperature - NLL: 0.200, ECE: 0.029\n",
      "Optimal temperature: 1.458\n",
      "After temperature - NLL: 0.176, ECE: 0.010\n",
      "4 7664 8000 12.219030018895864\n",
      "\n",
      "Before temperature - NLL: 0.182, ECE: 0.026\n",
      "Optimal temperature: 1.439\n",
      "After temperature - NLL: 0.165, ECE: 0.012\n",
      "5 7650 8000 11.983344290405512\n",
      "\n",
      "Before temperature - NLL: 0.192, ECE: 0.028\n",
      "Optimal temperature: 1.447\n",
      "After temperature - NLL: 0.172, ECE: 0.015\n",
      "6 7631 8000 12.490618422627449\n",
      "\n",
      "Before temperature - NLL: 0.185, ECE: 0.027\n",
      "Optimal temperature: 1.447\n",
      "After temperature - NLL: 0.166, ECE: 0.017\n",
      "7 7631 8000 13.019607406109571\n",
      "\n",
      "Before temperature - NLL: 0.195, ECE: 0.029\n",
      "Optimal temperature: 1.452\n",
      "After temperature - NLL: 0.174, ECE: 0.012\n",
      "8 7604 8000 13.007608383893967\n",
      "\n",
      "Before temperature - NLL: 0.183, ECE: 0.029\n",
      "Optimal temperature: 1.438\n",
      "After temperature - NLL: 0.166, ECE: 0.015\n",
      "9 7617 8000 12.84865446574986\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from temperature_scaling import ModelWithTemperature\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "import random\n",
    "\n",
    "valid_indices = random.sample(range(10000), k=2000)\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n",
    "valid_loader = torch.utils.data.DataLoader(\n",
    "    testset, \n",
    "    pin_memory=True, \n",
    "    batch_size=100,\n",
    "    sampler=SubsetRandomSampler(valid_indices),\n",
    ")\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    testset, \n",
    "    pin_memory=True, \n",
    "    batch_size=100,\n",
    "    sampler=SubsetRandomSampler(list(set(range(10000)) - set(valid_indices))),\n",
    ")\n",
    "\n",
    "for seed in range(0, 10):\n",
    "    net = ResNet18()\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    if device == 'cuda':\n",
    "        net = torch.nn.DataParallel(net)\n",
    "        cudnn.benchmark = True\n",
    "    net.load_state_dict(torch.load(f'./checkpoint/baseline_resnet18/ckpt_{seed}.pth')[\"net\"])\n",
    "    net.eval()\n",
    "    model = ModelWithTemperature(net)\n",
    "    model.set_temperature(valid_loader)\n",
    "    # model = net\n",
    "    predictions, y = test(model, test_loader)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# validation_indices = random.sample(range(10000), k=2000)\n",
    "# np.savetxt('./validation_indices.txt', validation_indices, fmt='%i')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Before temperature - NLL: 0.164, ECE: 0.026\n",
      "Optimal temperature: 1.527\n",
      "After temperature - NLL: 0.135, ECE: 0.012\n",
      "Before temperature - NLL: 0.193, ECE: 0.031\n",
      "Optimal temperature: 1.554\n",
      "After temperature - NLL: 0.153, ECE: 0.018\n",
      "Before temperature - NLL: 0.187, ECE: 0.031\n",
      "Optimal temperature: 1.548\n",
      "After temperature - NLL: 0.148, ECE: 0.017\n",
      "Before temperature - NLL: 0.183, ECE: 0.030\n",
      "Optimal temperature: 1.543\n",
      "After temperature - NLL: 0.148, ECE: 0.014\n",
      "Before temperature - NLL: 0.156, ECE: 0.023\n",
      "Optimal temperature: 1.511\n",
      "After temperature - NLL: 0.131, ECE: 0.008\n",
      "Before temperature - NLL: 0.180, ECE: 0.030\n",
      "Optimal temperature: 1.540\n",
      "After temperature - NLL: 0.145, ECE: 0.012\n",
      "Before temperature - NLL: 0.201, ECE: 0.029\n",
      "Optimal temperature: 1.559\n",
      "After temperature - NLL: 0.159, ECE: 0.015\n",
      "Before temperature - NLL: 0.162, ECE: 0.025\n",
      "Optimal temperature: 1.530\n",
      "After temperature - NLL: 0.132, ECE: 0.012\n",
      "Before temperature - NLL: 0.160, ECE: 0.023\n",
      "Optimal temperature: 1.513\n",
      "After temperature - NLL: 0.135, ECE: 0.011\n",
      "Before temperature - NLL: 0.185, ECE: 0.029\n",
      "Optimal temperature: 1.544\n",
      "After temperature - NLL: 0.149, ECE: 0.015\n",
      "Before temperature - NLL: 0.203, ECE: 0.033\n",
      "Optimal temperature: 1.554\n",
      "After temperature - NLL: 0.161, ECE: 0.018\n",
      "Before temperature - NLL: 0.153, ECE: 0.022\n",
      "Optimal temperature: 1.513\n",
      "After temperature - NLL: 0.128, ECE: 0.006\n",
      "Before temperature - NLL: 0.202, ECE: 0.030\n",
      "Optimal temperature: 1.558\n",
      "After temperature - NLL: 0.158, ECE: 0.013\n",
      "Before temperature - NLL: 0.185, ECE: 0.030\n",
      "Optimal temperature: 1.548\n",
      "After temperature - NLL: 0.148, ECE: 0.014\n",
      "Before temperature - NLL: 0.170, ECE: 0.029\n",
      "Optimal temperature: 1.530\n",
      "After temperature - NLL: 0.139, ECE: 0.011\n",
      "Before temperature - NLL: 0.163, ECE: 0.025\n",
      "Optimal temperature: 1.531\n",
      "After temperature - NLL: 0.134, ECE: 0.015\n",
      "Before temperature - NLL: 0.180, ECE: 0.031\n",
      "Optimal temperature: 1.553\n",
      "After temperature - NLL: 0.142, ECE: 0.014\n",
      "Before temperature - NLL: 0.197, ECE: 0.033\n",
      "Optimal temperature: 1.561\n",
      "After temperature - NLL: 0.156, ECE: 0.017\n",
      "Before temperature - NLL: 0.183, ECE: 0.029\n",
      "Optimal temperature: 1.537\n",
      "After temperature - NLL: 0.149, ECE: 0.013\n",
      "Before temperature - NLL: 0.180, ECE: 0.028\n",
      "Optimal temperature: 1.536\n",
      "After temperature - NLL: 0.147, ECE: 0.012\n",
      "Before temperature - NLL: 0.178, ECE: 0.029\n",
      "Optimal temperature: 1.529\n",
      "After temperature - NLL: 0.147, ECE: 0.013\n",
      "Before temperature - NLL: 0.198, ECE: 0.033\n",
      "Optimal temperature: 1.545\n",
      "After temperature - NLL: 0.160, ECE: 0.014\n",
      "Before temperature - NLL: 0.179, ECE: 0.027\n",
      "Optimal temperature: 1.539\n",
      "After temperature - NLL: 0.145, ECE: 0.010\n",
      "Before temperature - NLL: 0.169, ECE: 0.025\n",
      "Optimal temperature: 1.525\n",
      "After temperature - NLL: 0.140, ECE: 0.010\n",
      "Before temperature - NLL: 0.178, ECE: 0.030\n",
      "Optimal temperature: 1.530\n",
      "After temperature - NLL: 0.146, ECE: 0.016\n",
      "Before temperature - NLL: 0.185, ECE: 0.026\n",
      "Optimal temperature: 1.545\n",
      "After temperature - NLL: 0.149, ECE: 0.009\n",
      "Before temperature - NLL: 0.179, ECE: 0.025\n",
      "Optimal temperature: 1.541\n",
      "After temperature - NLL: 0.145, ECE: 0.014\n",
      "Before temperature - NLL: 0.155, ECE: 0.023\n",
      "Optimal temperature: 1.501\n",
      "After temperature - NLL: 0.131, ECE: 0.010\n",
      "Before temperature - NLL: 0.177, ECE: 0.029\n",
      "Optimal temperature: 1.540\n",
      "After temperature - NLL: 0.141, ECE: 0.015\n",
      "Before temperature - NLL: 0.156, ECE: 0.023\n",
      "Optimal temperature: 1.524\n",
      "After temperature - NLL: 0.128, ECE: 0.012\n",
      "Before temperature - NLL: 0.175, ECE: 0.029\n",
      "Optimal temperature: 1.539\n",
      "After temperature - NLL: 0.142, ECE: 0.016\n",
      "Before temperature - NLL: 0.170, ECE: 0.026\n",
      "Optimal temperature: 1.519\n",
      "After temperature - NLL: 0.142, ECE: 0.011\n",
      "Before temperature - NLL: 0.197, ECE: 0.029\n",
      "Optimal temperature: 1.559\n",
      "After temperature - NLL: 0.154, ECE: 0.016\n",
      "Before temperature - NLL: 0.174, ECE: 0.029\n",
      "Optimal temperature: 1.533\n",
      "After temperature - NLL: 0.141, ECE: 0.013\n",
      "Before temperature - NLL: 0.151, ECE: 0.023\n",
      "Optimal temperature: 1.515\n",
      "After temperature - NLL: 0.126, ECE: 0.008\n",
      "Before temperature - NLL: 0.185, ECE: 0.029\n",
      "Optimal temperature: 1.556\n",
      "After temperature - NLL: 0.144, ECE: 0.018\n",
      "Before temperature - NLL: 0.174, ECE: 0.025\n",
      "Optimal temperature: 1.533\n",
      "After temperature - NLL: 0.142, ECE: 0.013\n",
      "Before temperature - NLL: 0.170, ECE: 0.029\n",
      "Optimal temperature: 1.534\n",
      "After temperature - NLL: 0.140, ECE: 0.014\n",
      "Before temperature - NLL: 0.179, ECE: 0.027\n",
      "Optimal temperature: 1.536\n",
      "After temperature - NLL: 0.146, ECE: 0.014\n",
      "Before temperature - NLL: 0.155, ECE: 0.022\n",
      "Optimal temperature: 1.513\n",
      "After temperature - NLL: 0.131, ECE: 0.010\n",
      "Before temperature - NLL: 0.179, ECE: 0.028\n",
      "Optimal temperature: 1.542\n",
      "After temperature - NLL: 0.143, ECE: 0.014\n",
      "Before temperature - NLL: 0.174, ECE: 0.028\n",
      "Optimal temperature: 1.542\n",
      "After temperature - NLL: 0.139, ECE: 0.012\n",
      "Before temperature - NLL: 0.176, ECE: 0.025\n",
      "Optimal temperature: 1.534\n",
      "After temperature - NLL: 0.144, ECE: 0.012\n",
      "Before temperature - NLL: 0.168, ECE: 0.032\n",
      "Optimal temperature: 1.527\n",
      "After temperature - NLL: 0.137, ECE: 0.016\n",
      "Before temperature - NLL: 0.205, ECE: 0.032\n",
      "Optimal temperature: 1.560\n",
      "After temperature - NLL: 0.162, ECE: 0.015\n",
      "Before temperature - NLL: 0.154, ECE: 0.023\n",
      "Optimal temperature: 1.513\n",
      "After temperature - NLL: 0.129, ECE: 0.010\n",
      "Before temperature - NLL: 0.162, ECE: 0.023\n",
      "Optimal temperature: 1.520\n",
      "After temperature - NLL: 0.135, ECE: 0.010\n",
      "Before temperature - NLL: 0.179, ECE: 0.027\n",
      "Optimal temperature: 1.541\n",
      "After temperature - NLL: 0.145, ECE: 0.013\n",
      "Before temperature - NLL: 0.167, ECE: 0.026\n",
      "Optimal temperature: 1.517\n",
      "After temperature - NLL: 0.140, ECE: 0.013\n",
      "Before temperature - NLL: 0.194, ECE: 0.033\n",
      "Optimal temperature: 1.548\n",
      "After temperature - NLL: 0.155, ECE: 0.015\n",
      "Before temperature - NLL: 0.144, ECE: 0.020\n",
      "Optimal temperature: 1.487\n",
      "After temperature - NLL: 0.125, ECE: 0.011\n",
      "Before temperature - NLL: 0.171, ECE: 0.027\n",
      "Optimal temperature: 1.519\n",
      "After temperature - NLL: 0.143, ECE: 0.011\n",
      "Before temperature - NLL: 0.186, ECE: 0.029\n",
      "Optimal temperature: 1.549\n",
      "After temperature - NLL: 0.149, ECE: 0.014\n",
      "Before temperature - NLL: 0.190, ECE: 0.030\n",
      "Optimal temperature: 1.527\n",
      "After temperature - NLL: 0.154, ECE: 0.013\n",
      "Before temperature - NLL: 0.169, ECE: 0.028\n",
      "Optimal temperature: 1.531\n",
      "After temperature - NLL: 0.138, ECE: 0.013\n",
      "Before temperature - NLL: 0.183, ECE: 0.031\n",
      "Optimal temperature: 1.544\n",
      "After temperature - NLL: 0.147, ECE: 0.015\n",
      "Before temperature - NLL: 0.185, ECE: 0.027\n",
      "Optimal temperature: 1.544\n",
      "After temperature - NLL: 0.149, ECE: 0.014\n",
      "Before temperature - NLL: 0.176, ECE: 0.028\n",
      "Optimal temperature: 1.530\n",
      "After temperature - NLL: 0.145, ECE: 0.014\n",
      "Before temperature - NLL: 0.170, ECE: 0.025\n",
      "Optimal temperature: 1.520\n",
      "After temperature - NLL: 0.143, ECE: 0.014\n",
      "Before temperature - NLL: 0.197, ECE: 0.029\n",
      "Optimal temperature: 1.566\n",
      "After temperature - NLL: 0.152, ECE: 0.016\n",
      "Files already downloaded and verified\n",
      "Before temperature - NLL: 0.180, ECE: 0.028\n",
      "Optimal temperature: 1.545\n",
      "After temperature - NLL: 0.144, ECE: 0.016\n",
      "Before temperature - NLL: 0.170, ECE: 0.027\n",
      "Optimal temperature: 1.537\n",
      "After temperature - NLL: 0.137, ECE: 0.013\n",
      "Before temperature - NLL: 0.176, ECE: 0.030\n",
      "Optimal temperature: 1.535\n",
      "After temperature - NLL: 0.142, ECE: 0.013\n",
      "Before temperature - NLL: 0.189, ECE: 0.028\n",
      "Optimal temperature: 1.552\n",
      "After temperature - NLL: 0.151, ECE: 0.014\n",
      "Before temperature - NLL: 0.208, ECE: 0.033\n",
      "Optimal temperature: 1.568\n",
      "After temperature - NLL: 0.162, ECE: 0.015\n",
      "Before temperature - NLL: 0.183, ECE: 0.028\n",
      "Optimal temperature: 1.550\n",
      "After temperature - NLL: 0.145, ECE: 0.013\n",
      "Before temperature - NLL: 0.178, ECE: 0.028\n",
      "Optimal temperature: 1.538\n",
      "After temperature - NLL: 0.144, ECE: 0.010\n",
      "Before temperature - NLL: 0.179, ECE: 0.030\n",
      "Optimal temperature: 1.545\n",
      "After temperature - NLL: 0.143, ECE: 0.013\n",
      "Before temperature - NLL: 0.166, ECE: 0.026\n",
      "Optimal temperature: 1.527\n",
      "After temperature - NLL: 0.135, ECE: 0.013\n",
      "Before temperature - NLL: 0.169, ECE: 0.029\n",
      "Optimal temperature: 1.533\n",
      "After temperature - NLL: 0.138, ECE: 0.016\n",
      "Before temperature - NLL: 0.211, ECE: 0.033\n",
      "Optimal temperature: 1.573\n",
      "After temperature - NLL: 0.162, ECE: 0.017\n",
      "Before temperature - NLL: 0.156, ECE: 0.023\n",
      "Optimal temperature: 1.524\n",
      "After temperature - NLL: 0.127, ECE: 0.009\n",
      "Before temperature - NLL: 0.155, ECE: 0.026\n",
      "Optimal temperature: 1.512\n",
      "After temperature - NLL: 0.129, ECE: 0.011\n",
      "Before temperature - NLL: 0.183, ECE: 0.026\n",
      "Optimal temperature: 1.544\n",
      "After temperature - NLL: 0.147, ECE: 0.014\n",
      "Before temperature - NLL: 0.199, ECE: 0.033\n",
      "Optimal temperature: 1.554\n",
      "After temperature - NLL: 0.158, ECE: 0.015\n",
      "Before temperature - NLL: 0.185, ECE: 0.034\n",
      "Optimal temperature: 1.547\n",
      "After temperature - NLL: 0.148, ECE: 0.018\n",
      "Before temperature - NLL: 0.179, ECE: 0.027\n",
      "Optimal temperature: 1.549\n",
      "After temperature - NLL: 0.143, ECE: 0.009\n",
      "Before temperature - NLL: 0.177, ECE: 0.026\n",
      "Optimal temperature: 1.544\n",
      "After temperature - NLL: 0.142, ECE: 0.014\n",
      "Before temperature - NLL: 0.201, ECE: 0.031\n",
      "Optimal temperature: 1.569\n",
      "After temperature - NLL: 0.155, ECE: 0.017\n",
      "Before temperature - NLL: 0.166, ECE: 0.026\n",
      "Optimal temperature: 1.527\n",
      "After temperature - NLL: 0.136, ECE: 0.011\n",
      "Before temperature - NLL: 0.201, ECE: 0.032\n",
      "Optimal temperature: 1.566\n",
      "After temperature - NLL: 0.156, ECE: 0.014\n",
      "Before temperature - NLL: 0.182, ECE: 0.029\n",
      "Optimal temperature: 1.549\n",
      "After temperature - NLL: 0.146, ECE: 0.014\n",
      "Before temperature - NLL: 0.168, ECE: 0.027\n",
      "Optimal temperature: 1.541\n",
      "After temperature - NLL: 0.135, ECE: 0.014\n",
      "Before temperature - NLL: 0.157, ECE: 0.026\n",
      "Optimal temperature: 1.516\n",
      "After temperature - NLL: 0.131, ECE: 0.012\n",
      "Before temperature - NLL: 0.177, ECE: 0.030\n",
      "Optimal temperature: 1.535\n",
      "After temperature - NLL: 0.143, ECE: 0.012\n",
      "Before temperature - NLL: 0.158, ECE: 0.024\n",
      "Optimal temperature: 1.519\n",
      "After temperature - NLL: 0.131, ECE: 0.012\n",
      "Before temperature - NLL: 0.190, ECE: 0.030\n",
      "Optimal temperature: 1.565\n",
      "After temperature - NLL: 0.147, ECE: 0.013\n",
      "Before temperature - NLL: 0.178, ECE: 0.029\n",
      "Optimal temperature: 1.537\n",
      "After temperature - NLL: 0.144, ECE: 0.011\n",
      "Before temperature - NLL: 0.173, ECE: 0.027\n",
      "Optimal temperature: 1.532\n",
      "After temperature - NLL: 0.142, ECE: 0.010\n",
      "Before temperature - NLL: 0.176, ECE: 0.023\n",
      "Optimal temperature: 1.539\n",
      "After temperature - NLL: 0.143, ECE: 0.013\n",
      "Before temperature - NLL: 0.220, ECE: 0.036\n",
      "Optimal temperature: 1.581\n",
      "After temperature - NLL: 0.169, ECE: 0.019\n",
      "Before temperature - NLL: 0.167, ECE: 0.026\n",
      "Optimal temperature: 1.531\n",
      "After temperature - NLL: 0.137, ECE: 0.012\n",
      "Before temperature - NLL: 0.181, ECE: 0.026\n",
      "Optimal temperature: 1.550\n",
      "After temperature - NLL: 0.144, ECE: 0.011\n",
      "Before temperature - NLL: 0.172, ECE: 0.027\n",
      "Optimal temperature: 1.541\n",
      "After temperature - NLL: 0.137, ECE: 0.011\n",
      "Before temperature - NLL: 0.161, ECE: 0.025\n",
      "Optimal temperature: 1.533\n",
      "After temperature - NLL: 0.130, ECE: 0.010\n",
      "Before temperature - NLL: 0.160, ECE: 0.023\n",
      "Optimal temperature: 1.516\n",
      "After temperature - NLL: 0.133, ECE: 0.010\n",
      "Before temperature - NLL: 0.194, ECE: 0.029\n",
      "Optimal temperature: 1.561\n",
      "After temperature - NLL: 0.153, ECE: 0.015\n",
      "Before temperature - NLL: 0.164, ECE: 0.026\n",
      "Optimal temperature: 1.524\n",
      "After temperature - NLL: 0.135, ECE: 0.011\n",
      "Before temperature - NLL: 0.154, ECE: 0.021\n",
      "Optimal temperature: 1.514\n",
      "After temperature - NLL: 0.128, ECE: 0.010\n",
      "Before temperature - NLL: 0.163, ECE: 0.027\n",
      "Optimal temperature: 1.527\n",
      "After temperature - NLL: 0.134, ECE: 0.011\n",
      "Before temperature - NLL: 0.172, ECE: 0.028\n",
      "Optimal temperature: 1.550\n",
      "After temperature - NLL: 0.136, ECE: 0.013\n",
      "Before temperature - NLL: 0.210, ECE: 0.037\n",
      "Optimal temperature: 1.563\n",
      "After temperature - NLL: 0.164, ECE: 0.020\n",
      "Before temperature - NLL: 0.160, ECE: 0.024\n",
      "Optimal temperature: 1.520\n",
      "After temperature - NLL: 0.133, ECE: 0.011\n",
      "Before temperature - NLL: 0.182, ECE: 0.028\n",
      "Optimal temperature: 1.553\n",
      "After temperature - NLL: 0.143, ECE: 0.012\n",
      "Before temperature - NLL: 0.161, ECE: 0.024\n",
      "Optimal temperature: 1.536\n",
      "After temperature - NLL: 0.130, ECE: 0.014\n",
      "Before temperature - NLL: 0.164, ECE: 0.024\n",
      "Optimal temperature: 1.527\n",
      "After temperature - NLL: 0.134, ECE: 0.013\n",
      "Before temperature - NLL: 0.180, ECE: 0.029\n",
      "Optimal temperature: 1.545\n",
      "After temperature - NLL: 0.144, ECE: 0.013\n",
      "Before temperature - NLL: 0.169, ECE: 0.024\n",
      "Optimal temperature: 1.530\n",
      "After temperature - NLL: 0.138, ECE: 0.012\n",
      "Before temperature - NLL: 0.156, ECE: 0.027\n",
      "Optimal temperature: 1.522\n",
      "After temperature - NLL: 0.129, ECE: 0.010\n",
      "Before temperature - NLL: 0.171, ECE: 0.026\n",
      "Optimal temperature: 1.529\n",
      "After temperature - NLL: 0.140, ECE: 0.012\n",
      "Before temperature - NLL: 0.184, ECE: 0.030\n",
      "Optimal temperature: 1.546\n",
      "After temperature - NLL: 0.146, ECE: 0.011\n",
      "Before temperature - NLL: 0.186, ECE: 0.029\n",
      "Optimal temperature: 1.552\n",
      "After temperature - NLL: 0.148, ECE: 0.014\n",
      "Before temperature - NLL: 0.183, ECE: 0.030\n",
      "Optimal temperature: 1.542\n",
      "After temperature - NLL: 0.147, ECE: 0.013\n",
      "Before temperature - NLL: 0.182, ECE: 0.033\n",
      "Optimal temperature: 1.546\n",
      "After temperature - NLL: 0.144, ECE: 0.015\n",
      "Before temperature - NLL: 0.185, ECE: 0.032\n",
      "Optimal temperature: 1.553\n",
      "After temperature - NLL: 0.147, ECE: 0.015\n",
      "Before temperature - NLL: 0.171, ECE: 0.028\n",
      "Optimal temperature: 1.530\n",
      "After temperature - NLL: 0.140, ECE: 0.016\n",
      "Before temperature - NLL: 0.162, ECE: 0.025\n",
      "Optimal temperature: 1.528\n",
      "After temperature - NLL: 0.133, ECE: 0.012\n",
      "Before temperature - NLL: 0.171, ECE: 0.030\n",
      "Optimal temperature: 1.536\n",
      "After temperature - NLL: 0.138, ECE: 0.013\n",
      "Before temperature - NLL: 0.185, ECE: 0.029\n",
      "Optimal temperature: 1.546\n",
      "After temperature - NLL: 0.147, ECE: 0.015\n",
      "Before temperature - NLL: 0.186, ECE: 0.028\n",
      "Optimal temperature: 1.556\n",
      "After temperature - NLL: 0.148, ECE: 0.012\n",
      "Files already downloaded and verified\n",
      "Before temperature - NLL: 0.182, ECE: 0.031\n",
      "Optimal temperature: 1.549\n",
      "After temperature - NLL: 0.145, ECE: 0.015\n",
      "Before temperature - NLL: 0.215, ECE: 0.038\n",
      "Optimal temperature: 1.577\n",
      "After temperature - NLL: 0.164, ECE: 0.018\n",
      "Before temperature - NLL: 0.151, ECE: 0.025\n",
      "Optimal temperature: 1.538\n",
      "After temperature - NLL: 0.123, ECE: 0.011\n",
      "Before temperature - NLL: 0.175, ECE: 0.028\n",
      "Optimal temperature: 1.551\n",
      "After temperature - NLL: 0.138, ECE: 0.014\n",
      "Before temperature - NLL: 0.169, ECE: 0.029\n",
      "Optimal temperature: 1.536\n",
      "After temperature - NLL: 0.136, ECE: 0.013\n",
      "Before temperature - NLL: 0.170, ECE: 0.027\n",
      "Optimal temperature: 1.536\n",
      "After temperature - NLL: 0.137, ECE: 0.015\n",
      "Before temperature - NLL: 0.150, ECE: 0.028\n",
      "Optimal temperature: 1.519\n",
      "After temperature - NLL: 0.124, ECE: 0.012\n",
      "Before temperature - NLL: 0.169, ECE: 0.025\n",
      "Optimal temperature: 1.539\n",
      "After temperature - NLL: 0.136, ECE: 0.011\n",
      "Before temperature - NLL: 0.169, ECE: 0.031\n",
      "Optimal temperature: 1.537\n",
      "After temperature - NLL: 0.136, ECE: 0.014\n",
      "Before temperature - NLL: 0.183, ECE: 0.032\n",
      "Optimal temperature: 1.562\n",
      "After temperature - NLL: 0.142, ECE: 0.016\n",
      "Before temperature - NLL: 0.170, ECE: 0.030\n",
      "Optimal temperature: 1.542\n",
      "After temperature - NLL: 0.136, ECE: 0.015\n",
      "Before temperature - NLL: 0.168, ECE: 0.028\n",
      "Optimal temperature: 1.541\n",
      "After temperature - NLL: 0.134, ECE: 0.012\n",
      "Before temperature - NLL: 0.182, ECE: 0.030\n",
      "Optimal temperature: 1.554\n",
      "After temperature - NLL: 0.143, ECE: 0.013\n",
      "Before temperature - NLL: 0.161, ECE: 0.027\n",
      "Optimal temperature: 1.537\n",
      "After temperature - NLL: 0.129, ECE: 0.013\n",
      "Before temperature - NLL: 0.175, ECE: 0.025\n",
      "Optimal temperature: 1.553\n",
      "After temperature - NLL: 0.139, ECE: 0.009\n",
      "Before temperature - NLL: 0.176, ECE: 0.033\n",
      "Optimal temperature: 1.546\n",
      "After temperature - NLL: 0.141, ECE: 0.016\n",
      "Before temperature - NLL: 0.165, ECE: 0.031\n",
      "Optimal temperature: 1.546\n",
      "After temperature - NLL: 0.131, ECE: 0.017\n",
      "Before temperature - NLL: 0.152, ECE: 0.026\n",
      "Optimal temperature: 1.515\n",
      "After temperature - NLL: 0.127, ECE: 0.008\n",
      "Before temperature - NLL: 0.165, ECE: 0.028\n",
      "Optimal temperature: 1.536\n",
      "After temperature - NLL: 0.134, ECE: 0.012\n",
      "Before temperature - NLL: 0.171, ECE: 0.029\n",
      "Optimal temperature: 1.547\n",
      "After temperature - NLL: 0.136, ECE: 0.013\n",
      "Before temperature - NLL: 0.194, ECE: 0.034\n",
      "Optimal temperature: 1.569\n",
      "After temperature - NLL: 0.150, ECE: 0.017\n",
      "Before temperature - NLL: 0.165, ECE: 0.026\n",
      "Optimal temperature: 1.538\n",
      "After temperature - NLL: 0.132, ECE: 0.012\n",
      "Before temperature - NLL: 0.161, ECE: 0.025\n",
      "Optimal temperature: 1.528\n",
      "After temperature - NLL: 0.130, ECE: 0.010\n",
      "Before temperature - NLL: 0.197, ECE: 0.031\n",
      "Optimal temperature: 1.576\n",
      "After temperature - NLL: 0.151, ECE: 0.017\n",
      "Before temperature - NLL: 0.179, ECE: 0.027\n",
      "Optimal temperature: 1.551\n",
      "After temperature - NLL: 0.143, ECE: 0.017\n",
      "Before temperature - NLL: 0.180, ECE: 0.029\n",
      "Optimal temperature: 1.549\n",
      "After temperature - NLL: 0.144, ECE: 0.015\n",
      "Before temperature - NLL: 0.168, ECE: 0.029\n",
      "Optimal temperature: 1.536\n",
      "After temperature - NLL: 0.136, ECE: 0.014\n",
      "Before temperature - NLL: 0.180, ECE: 0.030\n",
      "Optimal temperature: 1.550\n",
      "After temperature - NLL: 0.143, ECE: 0.015\n",
      "Before temperature - NLL: 0.186, ECE: 0.028\n",
      "Optimal temperature: 1.548\n",
      "After temperature - NLL: 0.148, ECE: 0.013\n",
      "Before temperature - NLL: 0.170, ECE: 0.029\n",
      "Optimal temperature: 1.544\n",
      "After temperature - NLL: 0.136, ECE: 0.015\n",
      "Before temperature - NLL: 0.202, ECE: 0.032\n",
      "Optimal temperature: 1.574\n",
      "After temperature - NLL: 0.155, ECE: 0.019\n",
      "Before temperature - NLL: 0.163, ECE: 0.026\n",
      "Optimal temperature: 1.543\n",
      "After temperature - NLL: 0.131, ECE: 0.012\n",
      "Before temperature - NLL: 0.185, ECE: 0.026\n",
      "Optimal temperature: 1.559\n",
      "After temperature - NLL: 0.145, ECE: 0.014\n",
      "Before temperature - NLL: 0.202, ECE: 0.033\n",
      "Optimal temperature: 1.573\n",
      "After temperature - NLL: 0.156, ECE: 0.018\n",
      "Before temperature - NLL: 0.183, ECE: 0.031\n",
      "Optimal temperature: 1.556\n",
      "After temperature - NLL: 0.144, ECE: 0.015\n",
      "Before temperature - NLL: 0.212, ECE: 0.033\n",
      "Optimal temperature: 1.583\n",
      "After temperature - NLL: 0.161, ECE: 0.019\n",
      "Before temperature - NLL: 0.178, ECE: 0.030\n",
      "Optimal temperature: 1.550\n",
      "After temperature - NLL: 0.142, ECE: 0.016\n",
      "Before temperature - NLL: 0.172, ECE: 0.028\n",
      "Optimal temperature: 1.537\n",
      "After temperature - NLL: 0.139, ECE: 0.015\n",
      "Before temperature - NLL: 0.183, ECE: 0.031\n",
      "Optimal temperature: 1.539\n",
      "After temperature - NLL: 0.148, ECE: 0.014\n",
      "Before temperature - NLL: 0.201, ECE: 0.031\n",
      "Optimal temperature: 1.560\n",
      "After temperature - NLL: 0.155, ECE: 0.018\n",
      "Before temperature - NLL: 0.175, ECE: 0.030\n",
      "Optimal temperature: 1.552\n",
      "After temperature - NLL: 0.139, ECE: 0.014\n",
      "Before temperature - NLL: 0.185, ECE: 0.029\n",
      "Optimal temperature: 1.565\n",
      "After temperature - NLL: 0.144, ECE: 0.015\n",
      "Before temperature - NLL: 0.166, ECE: 0.024\n",
      "Optimal temperature: 1.534\n",
      "After temperature - NLL: 0.133, ECE: 0.010\n",
      "Before temperature - NLL: 0.150, ECE: 0.026\n",
      "Optimal temperature: 1.509\n",
      "After temperature - NLL: 0.126, ECE: 0.009\n",
      "Before temperature - NLL: 0.160, ECE: 0.028\n",
      "Optimal temperature: 1.530\n",
      "After temperature - NLL: 0.130, ECE: 0.014\n",
      "Before temperature - NLL: 0.190, ECE: 0.031\n",
      "Optimal temperature: 1.553\n",
      "After temperature - NLL: 0.151, ECE: 0.014\n",
      "Before temperature - NLL: 0.190, ECE: 0.030\n",
      "Optimal temperature: 1.566\n",
      "After temperature - NLL: 0.147, ECE: 0.016\n",
      "Before temperature - NLL: 0.148, ECE: 0.024\n",
      "Optimal temperature: 1.518\n",
      "After temperature - NLL: 0.122, ECE: 0.013\n",
      "Before temperature - NLL: 0.174, ECE: 0.026\n",
      "Optimal temperature: 1.541\n",
      "After temperature - NLL: 0.140, ECE: 0.012\n",
      "Before temperature - NLL: 0.177, ECE: 0.029\n",
      "Optimal temperature: 1.544\n",
      "After temperature - NLL: 0.141, ECE: 0.013\n",
      "Before temperature - NLL: 0.168, ECE: 0.028\n",
      "Optimal temperature: 1.540\n",
      "After temperature - NLL: 0.134, ECE: 0.012\n",
      "Before temperature - NLL: 0.189, ECE: 0.030\n",
      "Optimal temperature: 1.556\n",
      "After temperature - NLL: 0.150, ECE: 0.014\n",
      "Before temperature - NLL: 0.168, ECE: 0.029\n",
      "Optimal temperature: 1.525\n",
      "After temperature - NLL: 0.137, ECE: 0.013\n",
      "Before temperature - NLL: 0.191, ECE: 0.030\n",
      "Optimal temperature: 1.567\n",
      "After temperature - NLL: 0.148, ECE: 0.015\n",
      "Before temperature - NLL: 0.181, ECE: 0.026\n",
      "Optimal temperature: 1.548\n",
      "After temperature - NLL: 0.144, ECE: 0.015\n",
      "Before temperature - NLL: 0.165, ECE: 0.028\n",
      "Optimal temperature: 1.533\n",
      "After temperature - NLL: 0.135, ECE: 0.013\n",
      "Before temperature - NLL: 0.171, ECE: 0.031\n",
      "Optimal temperature: 1.527\n",
      "After temperature - NLL: 0.140, ECE: 0.017\n",
      "Before temperature - NLL: 0.176, ECE: 0.030\n",
      "Optimal temperature: 1.550\n",
      "After temperature - NLL: 0.139, ECE: 0.015\n",
      "Before temperature - NLL: 0.187, ECE: 0.033\n",
      "Optimal temperature: 1.554\n",
      "After temperature - NLL: 0.148, ECE: 0.015\n",
      "Before temperature - NLL: 0.168, ECE: 0.026\n",
      "Optimal temperature: 1.545\n",
      "After temperature - NLL: 0.135, ECE: 0.012\n"
     ]
    }
   ],
   "source": [
    "from temperature_scaling import ModelWithTemperature\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "import random\n",
    "\n",
    "def calculate_logloss_and_cali_free_loss(key, num_models=60):\n",
    "    \"\"\"Score the saved checkpoints of one architecture on the CIFAR-10 test set.\n",
    "\n",
    "    For every seed in ``range(num_models)`` the checkpoint\n",
    "    ``./checkpoint/{key}/ckpt_{seed}.pth`` is evaluated twice:\n",
    "\n",
    "    * unchanged, on the full test set (the log-loss metrics), and\n",
    "    * wrapped in ``ModelWithTemperature``, with ``set_temperature`` run on the\n",
    "      test samples listed in ``./validation_indices.txt`` and the evaluation\n",
    "      done on the remaining test samples (the cali-free metrics).\n",
    "\n",
    "    Args:\n",
    "        key: Architecture name. It selects the model constructor and the\n",
    "            checkpoint sub-directory; an unknown name raises ``KeyError``.\n",
    "        num_models: Number of checkpoints to evaluate (seeds 0 .. num_models - 1).\n",
    "\n",
    "    Returns:\n",
    "        Four arrays of length ``num_models`` indexed by seed: log loss,\n",
    "        accuracy, cali-free log loss and cali-free accuracy.\n",
    "    \"\"\"\n",
    "    # Map names to constructors so that only the requested architecture is\n",
    "    # instantiated, instead of building all eleven networks on every call.\n",
    "    net_builders = {\n",
    "        \"baseline_resnet18\": ResNet18,\n",
    "        \"baseline_resnet34\": ResNet34,\n",
    "        \"baseline_VGG19\": lambda: VGG('VGG19'),\n",
    "        \"baseline_DenseNet121\": DenseNet121,\n",
    "        \"baseline_GoogLeNet\": GoogLeNet,\n",
    "        \"baseline_MobileNet\": MobileNet,\n",
    "        \"baseline_MobileNetV2\": MobileNetV2,\n",
    "        \"baseline_ResNeXt29_2x64d\": ResNeXt29_2x64d,\n",
    "        \"baseline_resnet50\": ResNet50,\n",
    "        \"baseline_resnet101\": ResNet101,\n",
    "        \"baseline_resnet152\": ResNet152,\n",
    "    }\n",
    "\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    # One network instance is shared by all seeds; only its weights are reloaded.\n",
    "    net = net_builders[key]()\n",
    "    if device == 'cuda':\n",
    "        net = torch.nn.DataParallel(net)\n",
    "        cudnn.benchmark = True\n",
    "\n",
    "    def test(net, loader):\n",
    "        \"\"\"Run ``net`` over ``loader`` and return (logits, labels, loss, accuracy).\"\"\"\n",
    "        net.eval()\n",
    "        test_loss = 0\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        logit_chunks = []\n",
    "        label_chunks = []\n",
    "        with torch.no_grad():\n",
    "            for inputs, targets in loader:\n",
    "                inputs, targets = inputs.to(device), targets.to(device)\n",
    "                outputs = net(inputs)\n",
    "                logit_chunks.append(outputs.cpu().numpy())\n",
    "                label_chunks.append(targets.cpu().numpy())\n",
    "                test_loss += criterion(outputs, targets).item()\n",
    "                _, predicted = outputs.max(1)\n",
    "                total += targets.size(0)\n",
    "                correct += predicted.eq(targets).sum().item()\n",
    "        # Stack once after the loop instead of re-copying the arrays per batch.\n",
    "        # The leading empty float arrays keep the float64 dtype and the\n",
    "        # empty-loader result of the previous per-batch accumulation.\n",
    "        predictions = np.vstack([np.zeros((0, 10))] + logit_chunks)\n",
    "        y = np.concatenate([np.array([])] + label_chunks)\n",
    "        # NOTE(review): test_loss is a sum of per-batch mean losses (the default\n",
    "        # CrossEntropyLoss reduction) but is divided by the number of samples,\n",
    "        # so the value is the mean loss scaled by roughly 1 / batch_size. Left\n",
    "        # unchanged so new values stay comparable with previously computed ones.\n",
    "        return predictions, y, test_loss / total, correct / total\n",
    "\n",
    "    def load_net(seed):\n",
    "        \"\"\"Load the weights of checkpoint ``seed`` into the shared network.\"\"\"\n",
    "        # map_location lets a checkpoint saved on another device be loaded here.\n",
    "        checkpoint = torch.load(f'./checkpoint/{key}/ckpt_{seed}.pth', map_location=device)\n",
    "        net.load_state_dict(checkpoint[\"net\"])\n",
    "        return net\n",
    "\n",
    "    def get_log_loss(seed):\n",
    "        \"\"\"Loss and accuracy of checkpoint ``seed`` on the full test set.\"\"\"\n",
    "        _, _, log_loss, log_loss_accuracy = test(load_net(seed), testloader)\n",
    "        return log_loss, log_loss_accuracy\n",
    "\n",
    "    def get_cali_free_log_loss(seed):\n",
    "        \"\"\"Loss and accuracy of checkpoint ``seed`` after temperature scaling.\"\"\"\n",
    "        # Weights are reloaded here so this evaluation does not depend on\n",
    "        # whatever get_log_loss did with the network.\n",
    "        seed_net = load_net(seed)\n",
    "        seed_net.eval()\n",
    "        model = ModelWithTemperature(seed_net)\n",
    "        model.set_temperature(validation_loader)\n",
    "        _, _, cali_free_log_loss, cali_free_log_loss_accuracy = test(model, cali_free_test_loader)\n",
    "        return cali_free_log_loss, cali_free_log_loss_accuracy\n",
    "\n",
    "    transform_test = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "    ])\n",
    "    testset = torchvision.datasets.CIFAR10(\n",
    "        root='./data',\n",
    "        train=False,\n",
    "        download=True,\n",
    "        transform=transform_test,\n",
    "    )\n",
    "    testloader = torch.utils.data.DataLoader(\n",
    "        testset,\n",
    "        batch_size=100,\n",
    "        shuffle=False,\n",
    "        num_workers=2,\n",
    "    )\n",
    "\n",
    "    # The test set is split in two: the listed indices are used to fit the\n",
    "    # temperature, every other index is used for the cali-free evaluation.\n",
    "    validation_indices = np.loadtxt('./validation_indices.txt', dtype='int')\n",
    "    validation_loader = torch.utils.data.DataLoader(\n",
    "        testset,\n",
    "        pin_memory=True,\n",
    "        batch_size=100,\n",
    "        sampler=SubsetRandomSampler(validation_indices),\n",
    "    )\n",
    "    held_out_indices = list(set(range(len(testset))) - set(validation_indices))\n",
    "    cali_free_test_loader = torch.utils.data.DataLoader(\n",
    "        testset,\n",
    "        pin_memory=True,\n",
    "        batch_size=100,\n",
    "        sampler=SubsetRandomSampler(held_out_indices),\n",
    "    )\n",
    "\n",
    "    log_loss_array = np.zeros(num_models)\n",
    "    log_loss_accuracy_array = np.zeros(num_models)\n",
    "    cali_free_log_loss_array = np.zeros(num_models)\n",
    "    cali_free_log_loss_accuracy_array = np.zeros(num_models)\n",
    "    for seed in range(num_models):\n",
    "        log_loss_array[seed], log_loss_accuracy_array[seed] = get_log_loss(seed)\n",
    "        cali_free_log_loss_array[seed], cali_free_log_loss_accuracy_array[seed] = get_cali_free_log_loss(seed)\n",
    "    return log_loss_array, log_loss_accuracy_array, cali_free_log_loss_array, cali_free_log_loss_accuracy_array\n",
    "\n",
    "# Architectures evaluated in this run; commented-out entries are skipped.\n",
    "keys = [\n",
    "    # \"baseline_resnet18\",\n",
    "    # \"baseline_resnet34\",\n",
    "    # \"baseline_VGG19\",\n",
    "    # \"baseline_DenseNet121\",\n",
    "    # \"baseline_GoogLeNet\",\n",
    "    # \"baseline_MobileNet\",\n",
    "    # \"baseline_MobileNetV2\",\n",
    "    # \"baseline_ResNeXt29_2x64d\",\n",
    "    \"baseline_resnet50\",\n",
    "    \"baseline_resnet101\",\n",
    "    \"baseline_resnet152\",\n",
    "]\n",
    "\n",
    "# One dict per architecture, filled with the four per-seed arrays returned by\n",
    "# calculate_logloss_and_cali_free_loss (names listed in return order).\n",
    "results = {key: {} for key in keys}\n",
    "for key in keys:\n",
    "    results[key].update(zip(\n",
    "        (\"log_loss\", \"log_loss_accuracy\", \"cali_free_log_loss\", \"cali_free_log_loss_accuracy\"),\n",
    "        calculate_logloss_and_cali_free_loss(key=key, num_models=60),\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import log_loss\n",
    "import json\n",
    "import numpy as np\n",
    "\n",
    "def save_loss_results(\n",
    "    results,\n",
    "    file_name=\"./results/loss_results_CV.txt\",\n",
    "    append=True,\n",
    "    metrics_to_include=(\"log_loss\", \"cali_free_log_loss\"),\n",
    "):\n",
    "    \"\"\"Write the selected metrics of ``results`` to ``file_name`` as JSON.\n",
    "\n",
    "    Args:\n",
    "        results: Mapping of model key -> mapping of metric name -> array-like\n",
    "            of values.\n",
    "        file_name: Path of the JSON file to write.\n",
    "        append: When true, entries already stored in ``file_name`` are kept for\n",
    "            every model key that is absent from ``results``; a key present in\n",
    "            ``results`` replaces its stored entry entirely. If the file does\n",
    "            not exist yet there is nothing to merge and it is simply created.\n",
    "        metrics_to_include: Metric names to store; a metric missing from an\n",
    "            entry is skipped for that entry.\n",
    "    \"\"\"\n",
    "    # np.asarray lets plain lists be saved as well as NumPy arrays.\n",
    "    result_json = {\n",
    "        key: {\n",
    "            metric: np.asarray(item[metric]).tolist()\n",
    "            for metric in metrics_to_include\n",
    "            if metric in item\n",
    "        }\n",
    "        for key, item in results.items()\n",
    "    }\n",
    "    if append:\n",
    "        try:\n",
    "            # The context manager closes the handle once the file has been read.\n",
    "            with open(file_name) as infile:\n",
    "                previous_results = json.load(infile)\n",
    "        except FileNotFoundError:\n",
    "            previous_results = {}\n",
    "        for key, item in previous_results.items():\n",
    "            if key not in result_json:\n",
    "                result_json[key] = {\n",
    "                    metric: item[metric]\n",
    "                    for metric in metrics_to_include\n",
    "                    if metric in item\n",
    "                }\n",
    "    # Serialise before opening for writing so that a serialisation error\n",
    "    # cannot leave the existing file truncated.\n",
    "    json_object = json.dumps(result_json)\n",
    "    with open(file_name, \"w\") as outfile:\n",
    "        outfile.write(json_object)\n",
    "\n",
    "# Persist all four metrics, merging with the entries already stored in the file.\n",
    "metrics = [\"log_loss\", \"log_loss_accuracy\", \"cali_free_log_loss\", \"cali_free_log_loss_accuracy\"]\n",
    "save_loss_results(\n",
    "    results=results,\n",
    "    metrics_to_include=metrics,\n",
    "    append=True,\n",
    "    file_name=\"./results/loss_results_CV.txt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "baseline_DenseNet121 baseline_resnet50\n",
      "accuracy ratio:  0.21388888888888888\n",
      "log loss p-value:  9.999000099990002e-05\n",
      "accuracy p-value:  9.999000099990002e-05\n",
      "log loss accuracy:  0.9559566666666669 0.954065\n",
      "cali free log loss accuracy:  0.9555291666666667 0.9534\n",
      "log_loss variance:  3.1987072471201214e-09 1.041979565700442e-08\n",
      "cali_free_log_loss variance:  1.864031822289755e-09 4.645542598384818e-09\n",
      "log loss:  0.9380555555555555\n",
      "cali-free log loss:  0.6388888888888888\n",
      "log loss final accuracy:  0.938\n",
      "cali-free log loss final accuracy:  0.639\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from mlxtend.evaluate import permutation_test\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "metrics = [\"log_loss\", \"log_loss_accuracy\", \"cali_free_log_loss\", \"cali_free_log_loss_accuracy\"]\n",
    "# Read the stored results; the context manager closes the file afterwards.\n",
    "with open(\"./results/loss_results_CV.txt\") as f:\n",
    "    results = json.load(f)\n",
    "\n",
    "# JSON holds plain lists; turn the known metrics back into NumPy arrays.\n",
    "for key, value in results.items():\n",
    "    for metric in metrics:\n",
    "        if metric in value.keys():\n",
    "            value[metric] = np.array(value[metric])\n",
    "\n",
    "keys = list(results.keys())\n",
    "\n",
    "def interpret_result(results, key1, key2, cali_free_version=\"cali_free_log_loss\"):\n",
    "    \"\"\"Print a pairwise comparison of two models when their accuracies differ.\n",
    "\n",
    "    A permutation test is run on the per-seed accuracies of ``key1`` and\n",
    "    ``key2``; nothing is printed unless its p-value is below 0.05. For such a\n",
    "    pair the function prints, for accuracy, log loss and the cali-free loss,\n",
    "    the fraction of cross-model run pairs in which ``key1`` has the smaller\n",
    "    value, together with means, variances and both p-values.\n",
    "\n",
    "    Args:\n",
    "        results: Mapping of model key -> mapping of metric name -> 1-D array\n",
    "            of per-seed values.\n",
    "        key1: Key of the first model.\n",
    "        key2: Key of the second model.\n",
    "        cali_free_version: Name of the metric used as the cali-free loss.\n",
    "\n",
    "    Returns:\n",
    "        None; the report is printed.\n",
    "    \"\"\"\n",
    "\n",
    "    def get_better_percentage(loss1, loss2):\n",
    "        \"\"\"Return the fraction of pairs (i, j) with loss1[i] < loss2[j].\"\"\"\n",
    "        num_model = loss1.shape[0]\n",
    "        assert num_model == loss2.shape[0]\n",
    "        # Broadcasting compares every entry of loss1 with every entry of loss2\n",
    "        # in one step, replacing the nested Python loops.\n",
    "        num_smaller = int(np.count_nonzero(loss1[:, None] < loss2[None, :]))\n",
    "        return num_smaller / (num_model ** 2)\n",
    "\n",
    "    accuracy_p_value = permutation_test(\n",
    "        results[key1][\"log_loss_accuracy\"],\n",
    "        results[key2][\"log_loss_accuracy\"],\n",
    "        method='approximate',\n",
    "        num_rounds=10000,\n",
    "        seed=0,\n",
    "    )\n",
    "\n",
    "    # Only pairs whose accuracies differ significantly are reported. Written as\n",
    "    # a negated '<' so that a NaN p-value is also treated as not significant.\n",
    "    if not accuracy_p_value < 0.05:\n",
    "        return\n",
    "\n",
    "    # The log-loss test is only needed for the report, so it runs after the gate.\n",
    "    log_loss_p_value = permutation_test(\n",
    "        results[key1][\"log_loss\"],\n",
    "        results[key2][\"log_loss\"],\n",
    "        method='approximate',\n",
    "        num_rounds=10000,\n",
    "        seed=0,\n",
    "    )\n",
    "\n",
    "    accuracy_ratio = get_better_percentage(results[key1][\"log_loss_accuracy\"], results[key2][\"log_loss_accuracy\"])\n",
    "    print(key1, key2)\n",
    "    print(\"accuracy ratio: \", accuracy_ratio)\n",
    "    print(\"log loss p-value: \", log_loss_p_value)\n",
    "    print(\"accuracy p-value: \", accuracy_p_value)\n",
    "    print(\"log loss accuracy: \", np.mean(results[key1][\"log_loss_accuracy\"]), np.mean(results[key2][\"log_loss_accuracy\"]))\n",
    "    print(\"cali free log loss accuracy: \", np.mean(results[key1][\"cali_free_log_loss_accuracy\"]), np.mean(results[key2][\"cali_free_log_loss_accuracy\"]))\n",
    "    log_loss_result = get_better_percentage(results[key1][\"log_loss\"], results[key2][\"log_loss\"])\n",
    "    print(\"log_loss variance: \", np.var(results[key1][\"log_loss\"]), np.var(results[key2][\"log_loss\"]))\n",
    "    cali_free_log_loss_result = get_better_percentage(results[key1][cali_free_version], results[key2][cali_free_version])\n",
    "    print(\"cali_free_log_loss variance: \", np.var(results[key1][cali_free_version]), np.var(results[key2][cali_free_version]))\n",
    "    print(\"log loss: \", log_loss_result)\n",
    "    print(\"cali-free log loss: \", cali_free_log_loss_result)\n",
    "\n",
    "    # accuracy_ratio is the fraction of pairs in which key1's accuracy is lower.\n",
    "    # Above 0.5, key1 is the less accurate model, so a loss agrees with the\n",
    "    # accuracy ranking when key1's loss is NOT the smaller one (1 - fraction).\n",
    "    if accuracy_ratio > 0.5:\n",
    "        print(\"accuracy final accuracy: \", round(accuracy_ratio, 3))\n",
    "        print(\"log loss final accuracy: \", round(1 - log_loss_result, 3))\n",
    "        print(\"cali-free log loss final accuracy: \", round(1 - cali_free_log_loss_result, 3))\n",
    "    else:\n",
    "        print(\"accuracy final accuracy: \", round(1 - accuracy_ratio, 3))\n",
    "        print(\"log loss final accuracy: \", round(log_loss_result, 3))\n",
    "        print(\"cali-free log loss final accuracy: \", round(cali_free_log_loss_result, 3))\n",
    "    print()\n",
    "\n",
    "# Shape check for the loaded arrays (disabled):\n",
    "# for key, value in results.items():\n",
    "#     print(key)\n",
    "#     print(value[\"log_loss\"].shape)\n",
    "#     print(value[\"cali_free_log_loss\"].shape)\n",
    "#     print(value[\"log_loss_accuracy\"].shape)\n",
    "#     print(value[\"cali_free_log_loss_accuracy\"].shape)\n",
    "\n",
    "\n",
    "# Visit every unordered pair of models once; the key that comes later in\n",
    "# `keys` is passed as key1, the earlier one as key2.\n",
    "for i, earlier_key in enumerate(keys):\n",
    "    for later_key in keys[i + 1:]:\n",
    "        interpret_result(results, later_key, earlier_key)\n",
    "\n",
    "# Single-pair example (disabled):\n",
    "# interpret_result(results, \"baseline_DenseNet121\", \"baseline_resnet50\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
