{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a211b542-1420-4fb8-9039-5c3a05d164d7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import math\n",
    "import numpy as np\n",
    "import pickle as pkl\n",
    "import os\n",
    "import time\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import RandomSampler\n",
    "\n",
    "from data_processor import DataProcessor\n",
    "from trainer import Trainer\n",
    "import argparse\n",
    "import torchvision.models as models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b915dfe0-dbc4-41d0-8b53-d97ef95a7dd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# === DATA LOADING HELPERS =============================================================================================\n",
    "# find the dataset filepaths\n",
    "def get_dataset_paths(data_dir):\n",
    "    paths = sorted([os.path.join(data_dir, d) for d in os.listdir(data_dir) if 'dataset' in d], reverse=True)\n",
    "    return paths\n",
    "\n",
    "# load the dataset metadata from json\n",
    "def load_dataset_metadata(dataset_path):\n",
    "    with open(os.path.join(dataset_path, 'metadata'), \"r\") as f:\n",
    "        metadata = json.load(f)\n",
    "    return metadata\n",
    "\n",
    "# load dataset from file\n",
    "def load_datasets(data_path, truncate):\n",
    "    #data_path = '../../../datasets_val/'+data_path\n",
    "    data_path = '/home/woody/iwb3/iwb3021h/datasets_val1/'+data_path\n",
    "    \n",
    "    train_x = np.load(os.path.join(data_path,'train_x.npy'))\n",
    "    train_y = np.load(os.path.join(data_path,'train_y.npy'))\n",
    "    valid_x = np.load(os.path.join(data_path,'valid_x.npy'))\n",
    "    valid_y = np.load(os.path.join(data_path,'valid_y.npy'))\n",
    "    test_x = np.load(os.path.join(data_path,'test_x.npy'))\n",
    "    metadata = load_dataset_metadata(data_path)\n",
    "\n",
    "    if truncate:\n",
    "        train_x = train_x[:64]\n",
    "        train_y = train_y[:64]\n",
    "        valid_x = valid_x[:64]\n",
    "        valid_y = valid_y[:64]\n",
    "        test_x = test_x[:64]\n",
    "\n",
    "    return (train_x, train_y), \\\n",
    "           (valid_x, valid_y), \\\n",
    "           (test_x), metadata\n",
    "\n",
    "\n",
     "# === TIME COUNTERS ====================================================================================================\n",
    "def div_remainder(n, interval):\n",
    "    # finds divisor and remainder given some n/interval\n",
    "    factor = math.floor(n / interval)\n",
    "    remainder = int(n - (factor * interval))\n",
    "    return factor, remainder\n",
    "\n",
    "\n",
    "def show_time(seconds):\n",
    "    # show amount of time as human readable\n",
    "    if seconds < 60:\n",
    "        return \"{:.2f}s\".format(seconds)\n",
    "    elif seconds < (60 * 60):\n",
    "        minutes, seconds = div_remainder(seconds, 60)\n",
    "        return \"{}m,{}s\".format(minutes, seconds)\n",
    "    else:\n",
    "        hours, seconds = div_remainder(seconds, 60 * 60)\n",
    "        minutes, seconds = div_remainder(seconds, 60)\n",
    "        return \"{}h,{}m,{}s\".format(hours, minutes, seconds)\n",
    "\n",
    "\n",
    "# keep a counter of available time\n",
    "class Clock:\n",
    "    def __init__(self, time_available):\n",
    "        self.start_time =  time.time()\n",
    "        self.total_time = time_available\n",
    "\n",
    "    def check(self):\n",
    "        return self.total_time + self.start_time - time.time()\n",
    "\n",
    "\n",
    "# === MODEL ANALYSIS ===================================================================================================\n",
    "def general_num_params(model):\n",
    "    # return number of differential parameters of input model\n",
    "    return sum([np.prod(p.size()) for p in filter(lambda p: p.requires_grad, model.parameters())])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a48c1acd-d3ba-4e38-b515-f80b02d9b688",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0c180a8f-6140-4d56-91be-459ed3e2984c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4f6e16cb-fe76-43a0-9b6d-86b132958a57",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ic| self.select_augment: 'Proxy'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metadata:\n",
      "   - input_shape         : [50000, 20, 20, 20]\n",
      "   - codename            : Volga\n",
      "   - benchmark           : 71.35\n",
      "   - num_classes         : 7\n",
      "   - time_remaining      : 107980.68109369278\n",
      "\n",
      "=== Processing Data ===\n",
      "  Allotted compute time remaining: ~29h,59m,40s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ic| self.x.shape: torch.Size([50000, 20, 20, 20])\n",
      "ic| unique_values: array([0., 1.], dtype=float32)\n",
      "ic| C: 20\n",
      "ic| H: 20\n",
      "ic| PH: 2\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torchvision/transforms/v2/_deprecated.py:41: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.\n",
      "  warnings.warn(\n",
      "ic| '#############'\n",
      "ic| aug: 0\n",
      "ic| train_loader.dataset.transform.transforms: [Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 1\n",
      "ic| train_loader.dataset.transform.transforms: [RandAugment(interpolation=InterpolationMode.NEAREST, num_ops=2, magnitude=9, num_magnitude_bins=31),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "ic| '#############'\n",
      "ic| aug: 2\n",
      "ic| train_loader.dataset.transform.transforms: [RandAugment(interpolation=InterpolationMode.NEAREST, num_ops=2, magnitude=5, num_magnitude_bins=31),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "ic| '#############'\n",
      "ic| aug: 3\n",
      "ic| train_loader.dataset.transform.transforms: [RandAugment(interpolation=InterpolationMode.NEAREST, num_ops=2, magnitude=1, num_magnitude_bins=31),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "ic| '#############'\n",
      "ic| aug: 4\n",
      "ic| train_loader.dataset.transform.transforms: [TrivialAugmentWide(interpolation=InterpolationMode.NEAREST, num_magnitude_bins=31),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "ic| '#############'\n",
      "ic| aug: 5\n",
      "ic| train_loader.dataset.transform.transforms: [TrivialAugmentWide(interpolation=InterpolationMode.NEAREST, num_magnitude_bins=15),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "ic| '#############'\n",
      "ic| aug: 6\n",
      "ic| train_loader.dataset.transform.transforms: [AugMix(interpolation=InterpolationMode.BILINEAR, severity=3, mixture_width=3, chain_depth=-1, alpha=1.0, all_ops=True),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "ic| '#############'\n",
      "ic| aug: 7\n",
      "ic| train_loader.dataset.transform.transforms: [AugMix(interpolation=InterpolationMode.BILINEAR, severity=1, mixture_width=3, chain_depth=-1, alpha=1.0, all_ops=True),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "ic| '#############'\n",
      "ic| aug: 8\n",
      "ic| train_loader.dataset.transform.transforms: [RandomHorizontalFlip(p=0.5),\n",
      "                                                RandomVerticalFlip(p=0.5),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 9\n",
      "ic| train_loader.dataset.transform.transforms: [RandomErasing(p=0.2, scale=(0.05, 0.2), ratio=(0.3, 3.3), value=[0.0], inplace=False),\n",
      "                                                RandomHorizontalFlip(p=0.5),\n",
      "                                                RandomVerticalFlip(p=0.5),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 10\n",
      "ic| train_loader.dataset.transform.transforms: [RandomErasing(p=0.2, scale=(0.05, 0.2), ratio=(0.3, 3.3), value=[0.0], inplace=False),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 11\n",
      "ic| train_loader.dataset.transform.transforms: [RandomErasing(p=0.2, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=[0.0], inplace=False),\n",
      "                                                RandomCrop(size=(20, 20), padding=[2, 2, 2, 2], pad_if_needed=False, fill=0, padding_mode=constant),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 12\n",
      "ic| train_loader.dataset.transform.transforms: [RandomCrop(size=(20, 20), padding=[2, 2, 2, 2], pad_if_needed=False, fill=0, padding_mode=constant),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 13\n",
      "ic| train_loader.dataset.transform.transforms: [RandomCrop(size=(20, 20), padding=[2, 2, 2, 2], pad_if_needed=False, fill=0, padding_mode=constant),\n",
      "                                                RandomHorizontalFlip(p=0.5),\n",
      "                                                RandomVerticalFlip(p=0.5),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 14\n",
      "ic| train_loader.dataset.transform.transforms: [RandomErasing(p=0.2, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=[0.0], inplace=False),\n",
      "                                                RandomCrop(size=(20, 20), padding=[2, 2, 2, 2], pad_if_needed=False, fill=0, padding_mode=constant),\n",
      "                                                RandomHorizontalFlip(p=0.5),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 15\n",
      "ic| train_loader.dataset.transform.transforms: [<data_processor.RandomPixelChange object at 0x7efcbe5cf9e0>,\n",
      "                                                ToTensor(),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 16\n",
      "ic| train_loader.dataset.transform.transforms: [<data_processor.RandomPixelChange object at 0x7efc7df4b410>,\n",
      "                                                ToTensor(),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 17\n",
      "ic| train_loader.dataset.transform.transforms: [<data_processor.RandomPixelChange object at 0x7efcbe1e5ee0>,\n",
      "                                                ToTensor(),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 18\n",
      "ic| train_loader.dataset.transform.transforms: [<data_processor.RandomPixelChange object at 0x7efcbe1e63c0>,\n",
      "                                                ToTensor(),\n",
      "                                                RandomHorizontalFlip(p=0.5),\n",
      "                                                RandomVerticalFlip(p=0.5),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 19\n",
      "ic| train_loader.dataset.transform.transforms: [<data_processor.RandomPixelChange object at 0x7efc5dae6de0>,\n",
      "                                                ToTensor(),\n",
      "                                                RandomErasing(p=0.2, scale=(0.05, 0.2), ratio=(0.3, 3.3), value=[0.0], inplace=False),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 20\n",
      "ic| train_loader.dataset.transform.transforms: [<data_processor.RandomPixelChange object at 0x7efc5dae5b20>,\n",
      "                                                ToTensor(),\n",
      "                                                RandomCrop(size=(20, 20), padding=[2, 2, 2, 2], pad_if_needed=False, fill=0, padding_mode=constant),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 21\n",
      "ic| train_loader.dataset.transform.transforms: [<data_processor.RandomPixelChange object at 0x7efc92235220>,\n",
      "                                                ToTensor(),\n",
      "                                                RandomHorizontalFlip(p=0.5),\n",
      "                                                RandomVerticalFlip(p=0.5),\n",
      "                                                RandomErasing(p=0.2, scale=(0.05, 0.2), ratio=(0.3, 3.3), value=[0.0], inplace=False),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "/apps/jupyterhub/jh3.1.1-py3.11/envs/pytorch-2.2.0/lib/python3.12/site-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n",
      "ic| '#############'\n",
      "ic| aug: 22\n",
      "ic| train_loader.dataset.transform.transforms: [AutoAugment(interpolation=InterpolationMode.NEAREST, policy=AutoAugmentPolicy.IMAGENET),\n",
      "                                                Normalize(mean=tensor([0.0724, 0.0832, 0.0932, 0.1020, 0.1097, 0.1162, 0.1213, 0.1250, 0.1275,\n",
      "                                                       0.1288, 0.1289, 0.1274, 0.1250, 0.1211, 0.1157, 0.1091, 0.1012, 0.0924,\n",
      "                                                       0.0824, 0.0716]), std=tensor([0.2592, 0.2762, 0.2907, 0.3026, 0.3125, 0.3204, 0.3264, 0.3307, 0.3336,\n",
      "                                                       0.3350, 0.3351, 0.3334, 0.3308, 0.3263, 0.3198, 0.3117, 0.3016, 0.2896,\n",
      "                                                       0.2750, 0.2578]))]\n",
      "ic| f\"best_augmentation: {best_aug}\": 'best_augmentation: 1'\n",
      "ic| f\"selected transform {train_transform}\": ('selected transform [RandAugment(interpolation=InterpolationMode.NEAREST, '\n",
      "                                              'num_ops=2, magnitude=9, num_magnitude_bins=31)]')\n",
      "ic| self.x.shape: torch.Size([50000, 20, 20, 20])\n",
      "ic| self.x.shape: torch.Size([10000, 20, 20, 20])\n",
      "ic| self.x.shape: torch.Size([10000, 20, 20, 20])\n"
     ]
    }
   ],
   "source": [
    "# === MAIN =============================================================================================================\n",
    "# The available runtime changes at various stages of the competition; feel free to change it for local tests.\n",
    "# Note: this is approximate, the real runtime is controlled externally by the competition server.\n",
    "total_runtime_hours = 30\n",
    "total_runtime_seconds = total_runtime_hours * 60 * 60\n",
    "\n",
    "# Dataset folder name, resolved against the datasets root inside load_datasets.\n",
    "# Options: MultNIST, Language, Gutenberg, CIFARTile, Chesseract, AddNIST,\n",
    "#          CIFAR10, ImageNet16-120, Sudoku, Voxel\n",
    "dataset = \"Voxel\"\n",
    "# Augmentation-selection strategy handed to DataProcessor: \"Proxy\" or \"Model\".\n",
    "select_augment = \"Proxy\"\n",
    "# Options: ResNet18, MobileNetV3_large, RegNetY_400MF, EfficientNet_b0\n",
    "model_name = \"EfficientNet_b0\"\n",
    "seed = 1\n",
    "\n",
    "# Seed the numpy and torch RNGs so that `seed` actually makes runs reproducible.\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "runclock = Clock(total_runtime_seconds)\n",
    "\n",
    "# load and display data info\n",
    "(train_x, train_y), (valid_x, valid_y), test_x, metadata = load_datasets(dataset, truncate=False)\n",
    "metadata['time_remaining'] = runclock.check()\n",
    "start_time = time.time()\n",
    "\n",
    "print(\"Metadata:\")\n",
    "for meta_key, meta_value in metadata.items():\n",
    "    print(\"   - {:<20}: {}\".format(meta_key, meta_value))\n",
    "\n",
    "# perform data processing/augmentation/etc using your DataProcessor\n",
    "print(\"\\n=== Processing Data ===\")\n",
    "print(\"  Allotted compute time remaining: ~{}\".format(show_time(runclock.check())))\n",
    "\n",
    "#### Select model\n",
    "import torch.nn as nn\n",
    "metadata[\"model_name\"] = model_name\n",
    "\n",
    "\n",
    "def _replace_first_conv(old_conv, in_channels):\n",
    "    \"\"\"Return a new Conv2d like `old_conv` but taking `in_channels` input channels.\n",
    "\n",
    "    out_channels, kernel_size, stride and padding are copied from `old_conv`.\n",
    "    Conv2d's `bias` argument is a bool, so pass whether `old_conv` has a bias\n",
    "    rather than the bias tensor itself (a multi-element tensor raises inside\n",
    "    Conv2d's `if bias:` check; None only worked by accident).\n",
    "    \"\"\"\n",
    "    return nn.Conv2d(\n",
    "        in_channels=in_channels,\n",
    "        out_channels=old_conv.out_channels,\n",
    "        kernel_size=old_conv.kernel_size,\n",
    "        stride=old_conv.stride,\n",
    "        padding=old_conv.padding,\n",
    "        bias=old_conv.bias is not None,\n",
    "    )\n",
    "\n",
    "\n",
    "# Build an untrained torchvision backbone whose first conv accepts metadata[\"input_shape\"][1]\n",
    "# channels and whose classifier outputs metadata['num_classes'] logits.\n",
    "# An unknown model_name raises instead of silently reusing a `model` left over from an earlier run.\n",
    "if model_name == \"ResNet18\":\n",
    "    model = models.resnet18(weights=None)\n",
    "    model.conv1 = _replace_first_conv(model.conv1, metadata[\"input_shape\"][1])\n",
    "    model.fc = nn.Linear(model.fc.in_features, metadata['num_classes'])\n",
    "elif model_name == \"RegNetY_400MF\":\n",
    "    model = models.regnet_y_400mf(weights=None)\n",
    "    model.stem[0] = _replace_first_conv(model.stem[0], metadata[\"input_shape\"][1])\n",
    "    model.fc = nn.Linear(model.fc.in_features, metadata['num_classes'])\n",
    "elif model_name == \"VGG19\":\n",
    "    model = models.vgg19(weights=None)\n",
    "    model.features[0] = _replace_first_conv(model.features[0], metadata[\"input_shape\"][1])\n",
    "    # Global average pooling leaves one value per channel of the last conv, so the first\n",
    "    # classifier layer must take exactly that many inputs (torchvision's expects 512 * 7 * 7).\n",
    "    last_conv = [m for m in model.features if isinstance(m, nn.Conv2d)][-1]\n",
    "    model.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "    model.classifier[0] = nn.Linear(last_conv.out_channels, model.classifier[0].out_features)\n",
    "    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, metadata['num_classes'])\n",
    "elif model_name in (\"MobileNetV3_large\", \"EfficientNet_b0\"):\n",
    "    # Both architectures keep the stem conv at features[0][0] and the head at classifier[-1].\n",
    "    factory = models.mobilenet_v3_large if model_name == \"MobileNetV3_large\" else models.efficientnet_b0\n",
    "    model = factory(weights=None)\n",
    "    model.features[0][0] = _replace_first_conv(model.features[0][0], metadata[\"input_shape\"][1])\n",
    "    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, metadata['num_classes'])\n",
    "else:\n",
    "    raise ValueError(f\"Unknown model_name: {model_name!r}\")\n",
    "\n",
    "# NOTE(review): `device` is not defined in this cell; assumed to be set by an earlier cell - confirm.\n",
    "model.to(device)\n",
    "\n",
    "class ModifiedResNet18(models.ResNet):\n",
    "    \"\"\"ResNet-18 adapted to `metadata`, exposing the feature map before global pooling.\n",
    "\n",
    "    Args:\n",
    "        metadata: dict providing \"input_shape\" (index 1 is used as the number of\n",
    "            input channels) and \"num_classes\" (number of output logits).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, metadata):\n",
    "        # BasicBlock with [2, 2, 2, 2] layers is the ResNet-18 configuration.\n",
    "        super().__init__(block=models.resnet.BasicBlock, layers=[2, 2, 2, 2])\n",
    "\n",
    "        # Swap the stem convolution for one accepting the dataset's channel count.\n",
    "        # Conv2d's `bias` is a bool, so pass whether the old layer had a bias\n",
    "        # instead of the (possibly None) bias tensor itself.\n",
    "        old_conv = self.conv1\n",
    "        self.conv1 = nn.Conv2d(\n",
    "            in_channels=metadata[\"input_shape\"][1],\n",
    "            out_channels=old_conv.out_channels,\n",
    "            kernel_size=old_conv.kernel_size,\n",
    "            stride=old_conv.stride,\n",
    "            padding=old_conv.padding,\n",
    "            bias=old_conv.bias is not None,\n",
    "        )\n",
    "\n",
    "        # Classifier head sized for the dataset's number of classes.\n",
    "        self.fc = nn.Linear(self.fc.in_features, metadata['num_classes'])\n",
    "\n",
    "    def forward_before_global_avg_pool(self, x):\n",
    "        \"\"\"Run the stem and the four residual stages; return the last feature map (no pooling, no fc).\"\"\"\n",
    "        x = self.conv1(x)\n",
    "        x = self.bn1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.maxpool(x)\n",
    "\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.layer3(x)\n",
    "        x = self.layer4(x)\n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
    "# Alternative model for experiments that need the pre-pooling feature map (not used in this run):\n",
    "#model = ModifiedResNet18(metadata)\n",
    "#model.to(device)\n",
    "\n",
    "#####\n",
    "\n",
    "# Build the train/valid/test loaders; DataProcessor also receives the model and the\n",
    "# augmentation-selection strategy.\n",
    "data_processor = DataProcessor(train_x, train_y, valid_x, valid_y, test_x, metadata, select_augment, model)\n",
    "(train_loader, valid_loader, test_loader) = data_processor.process()\n",
    "metadata['time_remaining'] = runclock.check()\n",
    "\n",
    "# The remaining competition pipeline (loader checks, NAS, training, prediction export)\n",
    "# is disabled in this augmentation-analysis run; uncomment to restore it.\n",
    "\n",
    "# check that the test_loader is configured correctly\n",
    "#assert_string = \"Test Dataloader is {}, this will break evaluation. Please fix this in your DataProcessor init.\"\n",
    "#assert not isinstance(test_loader.sampler, RandomSampler), assert_string.format(\"shuffling\")\n",
    "#assert not test_loader.drop_last, assert_string.format(\"dropping last batch\")\n",
    "\n",
    "# search for best model using your NAS algorithm\n",
    "#print(\"\\n=== Performing NAS ===\")\n",
    "#print(\"  Allotted compute time remaining: ~{}\".format(show_time(runclock.check())))\n",
    "\n",
    "#model_params = int(general_num_params(model))\n",
    "#metadata['time_remaining'] = runclock.check()\n",
    "\n",
    "# train model using your Trainer\n",
    "#print(\"\\n=== Training ===\")\n",
    "#print(\"  Allotted compute time remaining: ~{}\".format(show_time(runclock.check())))\n",
    "#device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device('cpu')\n",
    "#trainer = Trainer(model, device, train_loader, valid_loader, metadata)\n",
    "#trained_model = trainer.train()\n",
    "\n",
    "# submit predictions to file\n",
    "#print(\"\\n=== Predicting ===\")\n",
    "#print(\"  Allotted compute time remaining: ~{}\".format(show_time(runclock.check())))\n",
    "#predictions = trainer.predict(test_loader)\n",
    "#run_data = {'Runtime': float(np.round(time.time()-start_time, 2)), 'Params': model_params}\n",
    "#with open(\"predictions/{}_stats.pkl\".format(metadata['codename']), \"wb\") as f:\n",
    "#    pkl.dump(run_data, f)\n",
    "#np.save('predictions/{}.npy'.format(metadata['codename']), predictions)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2e9af1c-41c8-451e-98c6-85cefc762895",
   "metadata": {},
   "source": [
    "# Read augmentation training results and correlate them with zero-cost proxy scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b3b2b44d-c166-4e2f-83fd-10d7f21e7a4b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def results_to_df(path):\n",
    "    \"\"\"Parse a training worklog into a DataFrame with one row per logged block.\n",
    "\n",
    "    The log consists of `key: value` lines grouped into blocks separated by\n",
    "    lines containing '-------------------------'. Blank lines and lines\n",
    "    mentioning 'best_acc' are ignored; a final block without a trailing\n",
    "    separator is kept as well.\n",
    "\n",
    "    Args:\n",
    "        path: path of the worklog text file.\n",
    "\n",
    "    Returns:\n",
    "        pandas.DataFrame with one column per key; 'epoch' is cast to int and\n",
    "        'lr', 'train_acc', 'train_loss', 'test_acc', 'test_acc_top5',\n",
    "        'test_loss' and 'epoch_time' to float.\n",
    "\n",
    "    Raises:\n",
    "        OSError: if the file cannot be read.\n",
    "        ValueError: if a line is not of the form `key: value`, or a value cannot be cast.\n",
    "        KeyError: if one of the expected columns never appears in the log.\n",
    "    \"\"\"\n",
    "    # Imported here because pandas is only imported in a later cell of this notebook.\n",
    "    import pandas as pd\n",
    "\n",
    "    records = []\n",
    "    block = {}\n",
    "    with open(path, 'r') as file:\n",
    "        for line_number, raw_line in enumerate(file, start=1):\n",
    "            line = raw_line.strip()\n",
    "            if '-------------------------' in line:\n",
    "                # Separator: close the current block if it collected anything.\n",
    "                if block:\n",
    "                    records.append(block)\n",
    "                    block = {}\n",
    "            elif not line or 'best_acc' in line:\n",
    "                continue\n",
    "            else:\n",
    "                # Split on the first ': ' only, so a value may itself contain ': '.\n",
    "                key, sep, value = line.partition(': ')\n",
    "                if not sep:\n",
    "                    raise ValueError(f\"{path}:{line_number}: expected 'key: value', got {line!r}\")\n",
    "                block[key] = value\n",
    "    # Keep the last block even when the file does not end with a separator.\n",
    "    if block:\n",
    "        records.append(block)\n",
    "\n",
    "    df = pd.DataFrame(records)\n",
    "    return df.astype({\n",
    "        'epoch': int,\n",
    "        'lr': float,\n",
    "        'train_acc': float,\n",
    "        'train_loss': float,\n",
    "        'test_acc': float,\n",
    "        'test_acc_top5': float,\n",
    "        'test_loss': float,\n",
    "        'epoch_time': float,\n",
    "    })\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "800b3513-ff95-4781-beef-8d1d1ae71c61",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One Spearman-correlation frame per (model, dataset) is appended here by the analysis loop;\n",
    "# it is not reset there, so re-running that loop with another model_name accumulates results.\n",
    "correlations = list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "7f9a5d14-ba0c-4567-99a8-64b6fa0d1d9a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:472: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmin(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:489: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmax(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:472: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmin(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:489: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmax(X, axis=axis))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "22\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:472: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmin(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:489: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmax(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:472: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmin(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:489: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmax(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:472: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmin(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:489: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmax(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:472: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmin(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:489: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmax(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:472: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmin(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:489: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmax(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:472: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmin(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:489: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmax(X, axis=axis))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "22\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:472: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmin(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:489: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmax(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:472: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmin(X, axis=axis))\n",
      "/home/hpc/iwb3/iwb3021h/.local/lib/python3.12/site-packages/sklearn/utils/_array_api.py:489: RuntimeWarning: All-NaN slice encountered\n",
      "  return xp.asarray(numpy.nanmax(X, axis=axis))\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "# Root of the experiment outputs: <model>/<dataset>/aug_<k>/worklog.txt plus per-dataset CSVs.\n",
    "RESULTS_ROOT = \"/home/woody/iwb3/iwb3021h/augmentations_test\"\n",
    "# Augmentation policies aug_0 ... aug_22 are evaluated per dataset.\n",
    "N_AUGS = 23\n",
    "# Zero-cost proxy columns read from ranks_<dataset>.csv; some files lack \"grasp\".\n",
    "ALL_MEASURES = [\"epe_nas\", \"nwot\", \"plain\", \"snip\", \"fisher\", \"jacob_cov\", \"grad_norm\", \"grasp\"]\n",
    "# Columns kept from the best-test-accuracy epoch of every worklog.\n",
    "SCORE_COLUMNS = [\"train_acc\", \"test_acc\", \"test_loss\", \"aug\"]\n",
    "\n",
    "# Options: MobileNetV3_large, EfficientNet_b0, RegNetY_400MF\n",
    "model_name = \"EfficientNet_b0\"\n",
    "DATASETS = [\"Adaline\", \"Caitie\", \"Chester\", \"CIFAR10\", \"Gutenberg\", \"in16\", \"LaMelo\", \"Mateo\", \"Sokoto\", \"Volga\"]\n",
    "\n",
    "\n",
    "def _to_real(value):\n",
    "    \"\"\"Return the real part of a value that may be a complex-number string such as '(0.5+0j)'.\n",
    "\n",
    "    Missing or unparsable values become NaN.\n",
    "    \"\"\"\n",
    "    if pd.isna(value):\n",
    "        return np.nan\n",
    "    try:\n",
    "        return complex(str(value).strip()).real\n",
    "    except ValueError:\n",
    "        return np.nan\n",
    "\n",
    "\n",
    "def _best_epoch_scores(model_name, dataset):\n",
    "    \"\"\"Collect, for every augmentation, the worklog row with the highest test_acc.\n",
    "\n",
    "    Augmentations whose worklog is missing or cannot be parsed are reported and skipped.\n",
    "    Returns a DataFrame with SCORE_COLUMNS, or None when no worklog was usable.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for aug in range(N_AUGS):\n",
    "        path = f\"{RESULTS_ROOT}/{model_name}/{dataset}/aug_{aug}/worklog.txt\"\n",
    "        try:\n",
    "            df_scores = results_to_df(path)\n",
    "            best_row = df_scores.loc[[df_scores[\"test_acc\"].idxmax()]].assign(aug=aug)\n",
    "            rows.append(best_row[SCORE_COLUMNS])\n",
    "        except Exception as exc:  # report and skip; unlike a bare except, KeyboardInterrupt still stops the loop\n",
    "            print(f\"{dataset}: skipping aug_{aug} ({type(exc).__name__}: {exc})\")\n",
    "    # Concatenate once after the loop; pd.concat([]) would raise, hence the None return.\n",
    "    return pd.concat(rows) if rows else None\n",
    "\n",
    "\n",
    "for dataset in DATASETS:\n",
    "    total_scores_df = _best_epoch_scores(model_name, dataset)\n",
    "    if total_scores_df is None:\n",
    "        print(f\"{dataset}: no usable worklogs, skipping dataset\")\n",
    "        continue\n",
    "    dataset_dir = f\"{RESULTS_ROOT}/{model_name}/{dataset}\"\n",
    "\n",
    "    # Zero-cost proxy scores per augmentation.\n",
    "    zcosts = pd.read_csv(f\"{dataset_dir}/ranks_{dataset}.csv\")\n",
    "    if \"grasp\" in zcosts.columns:\n",
    "        measures = list(ALL_MEASURES)\n",
    "    else:\n",
    "        measures = [m for m in ALL_MEASURES if m != \"grasp\"]\n",
    "    zcosts = zcosts[measures + [\"aug\"]]\n",
    "\n",
    "    zcosts = pd.merge(total_scores_df, zcosts, on=\"aug\", how=\"left\").sort_values(by=\"test_acc\", ascending=False)\n",
    "\n",
    "    # Object columns may hold complex-number strings such as '(0.5+0j)'; keep only the real part.\n",
    "    # pd.to_numeric cannot parse such strings and would turn them into NaN.\n",
    "    for col in zcosts.select_dtypes(include=[\"O\"]).columns:\n",
    "        zcosts[col] = zcosts[col].map(_to_real)\n",
    "\n",
    "    # Treat +/-inf as missing, then fill missing values with the column mean.\n",
    "    zcosts = zcosts.replace([np.inf, -np.inf], np.nan)\n",
    "    zcosts = zcosts.fillna(zcosts.mean())\n",
    "\n",
    "    # Spearman correlation between accuracy/loss and each proxy.\n",
    "    spearman_corr = zcosts.corr(method=\"spearman\")\n",
    "    spearman_corr.to_csv(f\"{dataset_dir}/corr_{dataset}.csv\")\n",
    "\n",
    "    # Keep only the proxy rows, selected by name rather than by position.\n",
    "    spearman_corr = spearman_corr.loc[measures, [\"test_acc\", \"aug\", \"test_loss\"]].assign(\n",
    "        model=model_name, dataset=dataset\n",
    "    )\n",
    "    correlations.append(spearman_corr)\n",
    "\n",
    "    # Min-max scale every proxy to [0, 1] so they can be compared and summed.\n",
    "    scaler = MinMaxScaler()\n",
    "    zcosts[measures] = scaler.fit_transform(zcosts[measures])\n",
    "    for measure in measures:\n",
    "        zcosts[f\"rank_{measure}\"] = zcosts[measure].rank(ascending=False, method=\"dense\")\n",
    "    zcosts[\"fisher_jacob\"] = zcosts[\"fisher\"] + zcosts[\"jacob_cov\"]\n",
    "    zcosts[\"rank_fisher_jacob\"] = zcosts[\"fisher_jacob\"].rank(ascending=False, method=\"dense\")\n",
    "    zcosts[\"model\"] = model_name\n",
    "    zcosts.to_csv(f\"{dataset_dir}/zcosts_{dataset}.csv\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "9be226e1-a1c2-4792-bf8c-ec8ba119432d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_acc</th>\n",
       "      <th>fisher_jacob</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>72.78</td>\n",
       "      <td>0.507096</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>72.65</td>\n",
       "      <td>0.968526</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>72.33</td>\n",
       "      <td>1.218252</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>72.14</td>\n",
       "      <td>1.325703</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>71.97</td>\n",
       "      <td>0.823424</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>71.96</td>\n",
       "      <td>1.153773</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>70.86</td>\n",
       "      <td>1.165185</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>70.63</td>\n",
       "      <td>0.966442</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>70.40</td>\n",
       "      <td>0.790863</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>69.24</td>\n",
       "      <td>1.389937</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>67.02</td>\n",
       "      <td>1.032906</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>66.74</td>\n",
       "      <td>1.012603</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>66.25</td>\n",
       "      <td>0.963571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>66.14</td>\n",
       "      <td>1.083881</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>60.00</td>\n",
       "      <td>1.143043</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    test_acc  fisher_jacob\n",
       "1      72.78      0.507096\n",
       "4      72.65      0.968526\n",
       "5      72.33      1.218252\n",
       "2      72.14      1.325703\n",
       "6      71.97      0.823424\n",
       "7      71.96      1.153773\n",
       "13     70.86      1.165185\n",
       "11     70.63      0.966442\n",
       "14     70.40      0.790863\n",
       "3      69.24      1.389937\n",
       "12     67.02      1.032906\n",
       "0      66.74      1.012603\n",
       "9      66.25      0.963571\n",
       "8      66.14      1.083881\n",
       "10     60.00      1.143043"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Last processed dataset: test accuracy next to the combined fisher + jacob_cov proxy score\n",
    "zcosts.loc[:, [\"test_acc\", \"fisher_jacob\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "d74435fb-7528-4621-87be-9833629f9880",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>train_acc</th>\n",
       "      <th>test_acc</th>\n",
       "      <th>test_loss</th>\n",
       "      <th>aug</th>\n",
       "      <th>epe_nas</th>\n",
       "      <th>nwot</th>\n",
       "      <th>plain</th>\n",
       "      <th>snip</th>\n",
       "      <th>fisher</th>\n",
       "      <th>jacob_cov</th>\n",
       "      <th>grad_norm</th>\n",
       "      <th>dataset</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_acc</td>\n",
       "      <td>0.290514</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.952736</td>\n",
       "      <td>-0.631423</td>\n",
       "      <td>0.236166</td>\n",
       "      <td>NaN</td>\n",
       "      <td>-0.037549</td>\n",
       "      <td>0.268775</td>\n",
       "      <td>0.245059</td>\n",
       "      <td>-0.339921</td>\n",
       "      <td>0.333992</td>\n",
       "      <td>Adaline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_acc</td>\n",
       "      <td>0.363636</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.972813</td>\n",
       "      <td>-0.503953</td>\n",
       "      <td>-0.055336</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.377470</td>\n",
       "      <td>0.156126</td>\n",
       "      <td>-0.206522</td>\n",
       "      <td>0.009881</td>\n",
       "      <td>0.128458</td>\n",
       "      <td>Caitie</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_acc</td>\n",
       "      <td>0.675000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.956926</td>\n",
       "      <td>-0.335714</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.214286</td>\n",
       "      <td>0.542857</td>\n",
       "      <td>0.246429</td>\n",
       "      <td>-0.457143</td>\n",
       "      <td>-0.060714</td>\n",
       "      <td>Chester</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_acc</td>\n",
       "      <td>-0.627470</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.989125</td>\n",
       "      <td>-0.142292</td>\n",
       "      <td>-0.215415</td>\n",
       "      <td>NaN</td>\n",
       "      <td>-0.063241</td>\n",
       "      <td>0.354743</td>\n",
       "      <td>0.204545</td>\n",
       "      <td>-0.195652</td>\n",
       "      <td>0.346838</td>\n",
       "      <td>CIFAR10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_acc</td>\n",
       "      <td>0.874228</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.851510</td>\n",
       "      <td>-0.127996</td>\n",
       "      <td>-0.177910</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.187299</td>\n",
       "      <td>0.078577</td>\n",
       "      <td>0.071658</td>\n",
       "      <td>-0.438349</td>\n",
       "      <td>-0.083519</td>\n",
       "      <td>Gutenberg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_acc</td>\n",
       "      <td>0.727947</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.987887</td>\n",
       "      <td>-0.473320</td>\n",
       "      <td>0.110672</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.092885</td>\n",
       "      <td>0.313241</td>\n",
       "      <td>0.115613</td>\n",
       "      <td>-0.296443</td>\n",
       "      <td>0.388340</td>\n",
       "      <td>in16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_acc</td>\n",
       "      <td>0.316284</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.872618</td>\n",
       "      <td>-0.092414</td>\n",
       "      <td>-0.324685</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.225846</td>\n",
       "      <td>0.422041</td>\n",
       "      <td>0.307388</td>\n",
       "      <td>-0.410180</td>\n",
       "      <td>0.364220</td>\n",
       "      <td>LaMelo</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_acc</td>\n",
       "      <td>0.517787</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.965159</td>\n",
       "      <td>-0.679842</td>\n",
       "      <td>-0.196640</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.011858</td>\n",
       "      <td>0.427866</td>\n",
       "      <td>0.318182</td>\n",
       "      <td>-0.615613</td>\n",
       "      <td>0.591897</td>\n",
       "      <td>Mateo</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_acc</td>\n",
       "      <td>0.207734</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.240911</td>\n",
       "      <td>-0.048539</td>\n",
       "      <td>0.142150</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.038138</td>\n",
       "      <td>0.437842</td>\n",
       "      <td>0.206043</td>\n",
       "      <td>-0.013868</td>\n",
       "      <td>-0.174840</td>\n",
       "      <td>Sokoto</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test_acc</td>\n",
       "      <td>-0.021429</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.985626</td>\n",
       "      <td>-0.392857</td>\n",
       "      <td>-0.582143</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.125000</td>\n",
       "      <td>0.217857</td>\n",
       "      <td>0.260714</td>\n",
       "      <td>-0.614286</td>\n",
       "      <td>0.525000</td>\n",
       "      <td>Volga</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      index  train_acc  test_acc  test_loss       aug   epe_nas  nwot  \\\n",
       "1  test_acc   0.290514       1.0  -0.952736 -0.631423  0.236166   NaN   \n",
       "1  test_acc   0.363636       1.0  -0.972813 -0.503953 -0.055336   NaN   \n",
       "1  test_acc   0.675000       1.0  -0.956926 -0.335714  0.200000   NaN   \n",
       "1  test_acc  -0.627470       1.0  -0.989125 -0.142292 -0.215415   NaN   \n",
       "1  test_acc   0.874228       1.0  -0.851510 -0.127996 -0.177910   NaN   \n",
       "1  test_acc   0.727947       1.0  -0.987887 -0.473320  0.110672   NaN   \n",
       "1  test_acc   0.316284       1.0  -0.872618 -0.092414 -0.324685   NaN   \n",
       "1  test_acc   0.517787       1.0  -0.965159 -0.679842 -0.196640   NaN   \n",
       "1  test_acc   0.207734       1.0   0.240911 -0.048539  0.142150   NaN   \n",
       "1  test_acc  -0.021429       1.0  -0.985626 -0.392857 -0.582143   NaN   \n",
       "\n",
       "      plain      snip    fisher  jacob_cov  grad_norm    dataset  \n",
       "1 -0.037549  0.268775  0.245059  -0.339921   0.333992    Adaline  \n",
       "1  0.377470  0.156126 -0.206522   0.009881   0.128458     Caitie  \n",
       "1  0.214286  0.542857  0.246429  -0.457143  -0.060714    Chester  \n",
       "1 -0.063241  0.354743  0.204545  -0.195652   0.346838    CIFAR10  \n",
       "1  0.187299  0.078577  0.071658  -0.438349  -0.083519  Gutenberg  \n",
       "1  0.092885  0.313241  0.115613  -0.296443   0.388340       in16  \n",
       "1  0.225846  0.422041  0.307388  -0.410180   0.364220     LaMelo  \n",
       "1  0.011858  0.427866  0.318182  -0.615613   0.591897      Mateo  \n",
       "1  0.038138  0.437842  0.206043  -0.013868  -0.174840     Sokoto  \n",
       "1  0.125000  0.217857  0.260714  -0.614286   0.525000      Volga  "
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(correlations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "90466465-e78b-45db-8132-db8891ed49ec",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_acc</th>\n",
       "      <th>aug</th>\n",
       "      <th>test_loss</th>\n",
       "      <th>model</th>\n",
       "      <th>dataset</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>epe_nas</th>\n",
       "      <td>-0.057312</td>\n",
       "      <td>0.195652</td>\n",
       "      <td>0.030161</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Adaline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nwot</th>\n",
       "      <td>-0.315217</td>\n",
       "      <td>0.595850</td>\n",
       "      <td>0.282324</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Adaline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>plain</th>\n",
       "      <td>0.207510</td>\n",
       "      <td>-0.054348</td>\n",
       "      <td>-0.289246</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Adaline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>snip</th>\n",
       "      <td>-0.455534</td>\n",
       "      <td>0.288538</td>\n",
       "      <td>0.456860</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Adaline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher</th>\n",
       "      <td>-0.629447</td>\n",
       "      <td>0.354743</td>\n",
       "      <td>0.636836</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Adaline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>jacob_cov</th>\n",
       "      <td>-0.223320</td>\n",
       "      <td>0.652174</td>\n",
       "      <td>0.208653</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Adaline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>grad_norm</th>\n",
       "      <td>-0.442688</td>\n",
       "      <td>0.257905</td>\n",
       "      <td>0.439061</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Adaline</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           test_acc       aug  test_loss          model  dataset\n",
       "epe_nas   -0.057312  0.195652   0.030161  RegNetY_400MF  Adaline\n",
       "nwot      -0.315217  0.595850   0.282324  RegNetY_400MF  Adaline\n",
       "plain      0.207510 -0.054348  -0.289246  RegNetY_400MF  Adaline\n",
       "snip      -0.455534  0.288538   0.456860  RegNetY_400MF  Adaline\n",
       "fisher    -0.629447  0.354743   0.636836  RegNetY_400MF  Adaline\n",
       "jacob_cov -0.223320  0.652174   0.208653  RegNetY_400MF  Adaline\n",
       "grad_norm -0.442688  0.257905   0.439061  RegNetY_400MF  Adaline"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(correlations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dad954b6-0067-4beb-a83e-6a58d8147742",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>test_acc</th>\n",
       "      <th>aug</th>\n",
       "      <th>test_loss</th>\n",
       "      <th>model</th>\n",
       "      <th>dataset</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>epe_nas</th>\n",
       "      <td>-0.009881</td>\n",
       "      <td>0.077075</td>\n",
       "      <td>0.033284</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Sokoto</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nwot</th>\n",
       "      <td>0.261428</td>\n",
       "      <td>0.236719</td>\n",
       "      <td>-0.113542</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Sokoto</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>plain</th>\n",
       "      <td>0.001976</td>\n",
       "      <td>-0.529644</td>\n",
       "      <td>0.120220</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Sokoto</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>snip</th>\n",
       "      <td>0.051383</td>\n",
       "      <td>0.117589</td>\n",
       "      <td>-0.351719</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Sokoto</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher</th>\n",
       "      <td>0.055336</td>\n",
       "      <td>0.172925</td>\n",
       "      <td>-0.356687</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Sokoto</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>jacob_cov</th>\n",
       "      <td>-0.325099</td>\n",
       "      <td>-0.121542</td>\n",
       "      <td>0.349732</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Sokoto</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>grad_norm</th>\n",
       "      <td>0.072134</td>\n",
       "      <td>0.161067</td>\n",
       "      <td>-0.368113</td>\n",
       "      <td>RegNetY_400MF</td>\n",
       "      <td>Sokoto</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           test_acc       aug  test_loss          model dataset\n",
       "epe_nas   -0.009881  0.077075   0.033284  RegNetY_400MF  Sokoto\n",
       "nwot       0.261428  0.236719  -0.113542  RegNetY_400MF  Sokoto\n",
       "plain      0.001976 -0.529644   0.120220  RegNetY_400MF  Sokoto\n",
       "snip       0.051383  0.117589  -0.351719  RegNetY_400MF  Sokoto\n",
       "fisher     0.055336  0.172925  -0.356687  RegNetY_400MF  Sokoto\n",
       "jacob_cov -0.325099 -0.121542   0.349732  RegNetY_400MF  Sokoto\n",
       "grad_norm  0.072134  0.161067  -0.368113  RegNetY_400MF  Sokoto"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(correlations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7e722212-c43c-4927-bf91-c7c3d8de568e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from scipy.stats import spearmanr\n",
    "import itertools\n",
    "\n",
    "column_combinations = list(itertools.combinations(measures, 2))\n",
    "\n",
    "\n",
    "for dataset in [\"Adaline\", \"Caitie\", \"Chester\", \"CIFAR10\", \"Gutenberg\", \"in16\", \"LaMelo\", \"Mateo\", \"Sokoto\", \"Volga\"]:\n",
    "    results = []\n",
    "\n",
    "    # Iterate and compute the Spearman correlation for each combination\n",
    "    for col1, col2 in column_combinations:\n",
    "        combined_values = zcosts[col1] - zcosts[col2]  # Combine columns by subtracting their values\n",
    "        corr, _ = spearmanr(zcosts[\"test_acc\"], combined_values)\n",
    "\n",
    "        # Append results to the list\n",
    "        results.append({\"col1\": col1, \"col2\": col2, \"corr\": corr})\n",
    "\n",
    "    # Convert results to a DataFrame and save as CSV\n",
    "    results_df_neg = pd.DataFrame(results)\n",
    "    results_df_neg.to_csv(f\"/home/woody/iwb3/iwb3021h/augmentations_test/{model_name}/{dataset}/comb_corr_{dataset}_neg.csv\", index=False)\n",
    "\n",
    "\n",
    "    results = []\n",
    "\n",
    "    # Iterate and compute the Spearman correlation for each combination\n",
    "    for col1, col2 in column_combinations:\n",
    "        combined_values = zcosts[col1] + zcosts[col2]  # Combine columns by subtracting their values\n",
    "        corr, _ = spearmanr(zcosts[\"test_acc\"], combined_values)\n",
    "\n",
    "        # Append results to the list\n",
    "        results.append({\"col1\": col1, \"col2\": col2, \"corr\": corr})\n",
    "\n",
    "    # Convert results to a DataFrame and save as CSV\n",
    "    results_df = pd.DataFrame(results)\n",
    "    results_df.to_csv(f\"/home/woody/iwb3/iwb3021h/augmentations_test/{model_name}/{dataset}/comb_corr_{dataset}_pos.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 244,
   "id": "ac8f9164-f579-4b6d-bec4-c788f744a57f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>train_acc</th>\n",
       "      <th>test_acc</th>\n",
       "      <th>test_loss</th>\n",
       "      <th>aug</th>\n",
       "      <th>epe_nas</th>\n",
       "      <th>nwot</th>\n",
       "      <th>plain</th>\n",
       "      <th>snip</th>\n",
       "      <th>fisher</th>\n",
       "      <th>...</th>\n",
       "      <th>grad_norm</th>\n",
       "      <th>rank_epe_nas</th>\n",
       "      <th>rank_nwot</th>\n",
       "      <th>rank_plain</th>\n",
       "      <th>rank_snip</th>\n",
       "      <th>rank_fisher</th>\n",
       "      <th>rank_jacob_cov</th>\n",
       "      <th>rank_grad_norm</th>\n",
       "      <th>model</th>\n",
       "      <th>pos</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>66.74</td>\n",
       "      <td>67.91</td>\n",
       "      <td>0.74</td>\n",
       "      <td>9</td>\n",
       "      <td>1.373022e+06</td>\n",
       "      <td>5227.063889</td>\n",
       "      <td>-0.010614</td>\n",
       "      <td>21.345419</td>\n",
       "      <td>0.000279</td>\n",
       "      <td>...</td>\n",
       "      <td>6.842886</td>\n",
       "      <td>12.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>5</td>\n",
       "      <td>61.83</td>\n",
       "      <td>64.46</td>\n",
       "      <td>0.82</td>\n",
       "      <td>12</td>\n",
       "      <td>1.380295e+06</td>\n",
       "      <td>5120.348988</td>\n",
       "      <td>-0.030193</td>\n",
       "      <td>22.489124</td>\n",
       "      <td>0.000258</td>\n",
       "      <td>...</td>\n",
       "      <td>7.516641</td>\n",
       "      <td>9.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>64.04</td>\n",
       "      <td>63.94</td>\n",
       "      <td>0.81</td>\n",
       "      <td>8</td>\n",
       "      <td>1.365543e+06</td>\n",
       "      <td>5256.341822</td>\n",
       "      <td>-0.021642</td>\n",
       "      <td>22.050383</td>\n",
       "      <td>0.000321</td>\n",
       "      <td>...</td>\n",
       "      <td>6.945489</td>\n",
       "      <td>14.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>14</td>\n",
       "      <td>60.76</td>\n",
       "      <td>63.45</td>\n",
       "      <td>0.86</td>\n",
       "      <td>21</td>\n",
       "      <td>1.486250e+06</td>\n",
       "      <td>5628.782053</td>\n",
       "      <td>-0.017527</td>\n",
       "      <td>21.932199</td>\n",
       "      <td>0.000332</td>\n",
       "      <td>...</td>\n",
       "      <td>7.192066</td>\n",
       "      <td>6.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>67.67</td>\n",
       "      <td>63.03</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0</td>\n",
       "      <td>1.375757e+06</td>\n",
       "      <td>5200.191685</td>\n",
       "      <td>-0.006991</td>\n",
       "      <td>21.727465</td>\n",
       "      <td>0.000333</td>\n",
       "      <td>...</td>\n",
       "      <td>7.003474</td>\n",
       "      <td>10.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>9</td>\n",
       "      <td>64.30</td>\n",
       "      <td>61.78</td>\n",
       "      <td>0.87</td>\n",
       "      <td>16</td>\n",
       "      <td>1.512495e+06</td>\n",
       "      <td>5692.449709</td>\n",
       "      <td>-0.007924</td>\n",
       "      <td>22.023966</td>\n",
       "      <td>0.000386</td>\n",
       "      <td>...</td>\n",
       "      <td>7.500123</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>6</td>\n",
       "      <td>58.27</td>\n",
       "      <td>61.42</td>\n",
       "      <td>0.88</td>\n",
       "      <td>13</td>\n",
       "      <td>1.371341e+06</td>\n",
       "      <td>4856.640451</td>\n",
       "      <td>-0.025797</td>\n",
       "      <td>22.767122</td>\n",
       "      <td>0.000325</td>\n",
       "      <td>...</td>\n",
       "      <td>7.814985</td>\n",
       "      <td>13.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>7.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>11</td>\n",
       "      <td>62.71</td>\n",
       "      <td>60.87</td>\n",
       "      <td>0.88</td>\n",
       "      <td>18</td>\n",
       "      <td>1.489730e+06</td>\n",
       "      <td>5627.411557</td>\n",
       "      <td>-0.017340</td>\n",
       "      <td>21.251942</td>\n",
       "      <td>0.000258</td>\n",
       "      <td>...</td>\n",
       "      <td>7.054242</td>\n",
       "      <td>3.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>3</td>\n",
       "      <td>59.15</td>\n",
       "      <td>60.25</td>\n",
       "      <td>0.92</td>\n",
       "      <td>10</td>\n",
       "      <td>1.363266e+06</td>\n",
       "      <td>5263.432441</td>\n",
       "      <td>0.018137</td>\n",
       "      <td>22.156715</td>\n",
       "      <td>0.000274</td>\n",
       "      <td>...</td>\n",
       "      <td>6.657001</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>9.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>13</td>\n",
       "      <td>57.42</td>\n",
       "      <td>59.63</td>\n",
       "      <td>0.92</td>\n",
       "      <td>20</td>\n",
       "      <td>1.488481e+06</td>\n",
       "      <td>5623.401090</td>\n",
       "      <td>0.001243</td>\n",
       "      <td>22.331259</td>\n",
       "      <td>0.000289</td>\n",
       "      <td>...</td>\n",
       "      <td>7.694211</td>\n",
       "      <td>4.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>12</td>\n",
       "      <td>59.87</td>\n",
       "      <td>59.37</td>\n",
       "      <td>0.93</td>\n",
       "      <td>19</td>\n",
       "      <td>1.486710e+06</td>\n",
       "      <td>5631.860975</td>\n",
       "      <td>-0.013066</td>\n",
       "      <td>22.326313</td>\n",
       "      <td>0.000325</td>\n",
       "      <td>...</td>\n",
       "      <td>7.165279</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>11.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>8</td>\n",
       "      <td>60.32</td>\n",
       "      <td>58.44</td>\n",
       "      <td>0.95</td>\n",
       "      <td>15</td>\n",
       "      <td>1.484411e+06</td>\n",
       "      <td>5629.858944</td>\n",
       "      <td>-0.019462</td>\n",
       "      <td>22.341305</td>\n",
       "      <td>0.000365</td>\n",
       "      <td>...</td>\n",
       "      <td>7.645125</td>\n",
       "      <td>7.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>4</td>\n",
       "      <td>55.60</td>\n",
       "      <td>58.20</td>\n",
       "      <td>0.96</td>\n",
       "      <td>11</td>\n",
       "      <td>1.385417e+06</td>\n",
       "      <td>5414.866612</td>\n",
       "      <td>-0.006670</td>\n",
       "      <td>22.910683</td>\n",
       "      <td>0.000292</td>\n",
       "      <td>...</td>\n",
       "      <td>7.462081</td>\n",
       "      <td>8.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>13.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>7</td>\n",
       "      <td>53.64</td>\n",
       "      <td>57.16</td>\n",
       "      <td>0.99</td>\n",
       "      <td>14</td>\n",
       "      <td>1.375202e+06</td>\n",
       "      <td>5321.755661</td>\n",
       "      <td>-0.025493</td>\n",
       "      <td>23.912573</td>\n",
       "      <td>0.000365</td>\n",
       "      <td>...</td>\n",
       "      <td>7.737163</td>\n",
       "      <td>11.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>14.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>10</td>\n",
       "      <td>52.82</td>\n",
       "      <td>51.42</td>\n",
       "      <td>1.08</td>\n",
       "      <td>17</td>\n",
       "      <td>1.528889e+06</td>\n",
       "      <td>5728.593305</td>\n",
       "      <td>-0.011147</td>\n",
       "      <td>22.071995</td>\n",
       "      <td>0.000336</td>\n",
       "      <td>...</td>\n",
       "      <td>7.113778</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>MobileNetV3_large</td>\n",
       "      <td>15.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>15 rows × 21 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    index  train_acc  test_acc  test_loss  aug       epe_nas         nwot  \\\n",
       "0       2      66.74     67.91       0.74    9  1.373022e+06  5227.063889   \n",
       "1       5      61.83     64.46       0.82   12  1.380295e+06  5120.348988   \n",
       "2       1      64.04     63.94       0.81    8  1.365543e+06  5256.341822   \n",
       "3      14      60.76     63.45       0.86   21  1.486250e+06  5628.782053   \n",
       "4       0      67.67     63.03       0.85    0  1.375757e+06  5200.191685   \n",
       "5       9      64.30     61.78       0.87   16  1.512495e+06  5692.449709   \n",
       "6       6      58.27     61.42       0.88   13  1.371341e+06  4856.640451   \n",
       "7      11      62.71     60.87       0.88   18  1.489730e+06  5627.411557   \n",
       "8       3      59.15     60.25       0.92   10  1.363266e+06  5263.432441   \n",
       "9      13      57.42     59.63       0.92   20  1.488481e+06  5623.401090   \n",
       "10     12      59.87     59.37       0.93   19  1.486710e+06  5631.860975   \n",
       "11      8      60.32     58.44       0.95   15  1.484411e+06  5629.858944   \n",
       "12      4      55.60     58.20       0.96   11  1.385417e+06  5414.866612   \n",
       "13      7      53.64     57.16       0.99   14  1.375202e+06  5321.755661   \n",
       "14     10      52.82     51.42       1.08   17  1.528889e+06  5728.593305   \n",
       "\n",
       "       plain       snip    fisher  ...  grad_norm  rank_epe_nas  rank_nwot  \\\n",
       "0  -0.010614  21.345419  0.000279  ...   6.842886          12.0       12.0   \n",
       "1  -0.030193  22.489124  0.000258  ...   7.516641           9.0       14.0   \n",
       "2  -0.021642  22.050383  0.000321  ...   6.945489          14.0       11.0   \n",
       "3  -0.017527  21.932199  0.000332  ...   7.192066           6.0        5.0   \n",
       "4  -0.006991  21.727465  0.000333  ...   7.003474          10.0       13.0   \n",
       "5  -0.007924  22.023966  0.000386  ...   7.500123           2.0        2.0   \n",
       "6  -0.025797  22.767122  0.000325  ...   7.814985          13.0       15.0   \n",
       "7  -0.017340  21.251942  0.000258  ...   7.054242           3.0        6.0   \n",
       "8   0.018137  22.156715  0.000274  ...   6.657001          15.0       10.0   \n",
       "9   0.001243  22.331259  0.000289  ...   7.694211           4.0        7.0   \n",
       "10 -0.013066  22.326313  0.000325  ...   7.165279           5.0        3.0   \n",
       "11 -0.019462  22.341305  0.000365  ...   7.645125           7.0        4.0   \n",
       "12 -0.006670  22.910683  0.000292  ...   7.462081           8.0        8.0   \n",
       "13 -0.025493  23.912573  0.000365  ...   7.737163          11.0        9.0   \n",
       "14 -0.011147  22.071995  0.000336  ...   7.113778           1.0        1.0   \n",
       "\n",
       "    rank_plain  rank_snip  rank_fisher  rank_jacob_cov  rank_grad_norm  \\\n",
       "0          6.0       14.0         12.0            14.0            14.0   \n",
       "1         15.0        4.0         14.0            11.0             5.0   \n",
       "2         12.0       10.0          9.0            13.0            13.0   \n",
       "3         10.0       12.0          6.0             5.0             8.0   \n",
       "4          4.0       13.0          5.0            12.0            12.0   \n",
       "5          5.0       11.0          1.0             2.0             6.0   \n",
       "6         14.0        3.0          8.0             9.0             1.0   \n",
       "7          9.0       15.0         15.0             7.0            11.0   \n",
       "8          1.0        8.0         13.0            15.0            15.0   \n",
       "9          2.0        6.0         11.0             6.0             3.0   \n",
       "10         8.0        7.0          7.0             3.0             9.0   \n",
       "11        11.0        5.0          3.0             4.0             4.0   \n",
       "12         3.0        2.0         10.0            10.0             7.0   \n",
       "13        13.0        1.0          2.0             8.0             2.0   \n",
       "14         7.0        9.0          4.0             1.0            10.0   \n",
       "\n",
       "                model   pos  \n",
       "0   MobileNetV3_large   1.0  \n",
       "1   MobileNetV3_large   2.0  \n",
       "2   MobileNetV3_large   3.0  \n",
       "3   MobileNetV3_large   4.0  \n",
       "4   MobileNetV3_large   5.0  \n",
       "5   MobileNetV3_large   6.0  \n",
       "6   MobileNetV3_large   7.0  \n",
       "7   MobileNetV3_large   8.0  \n",
       "8   MobileNetV3_large   9.0  \n",
       "9   MobileNetV3_large  10.0  \n",
       "10  MobileNetV3_large  11.0  \n",
       "11  MobileNetV3_large  12.0  \n",
       "12  MobileNetV3_large  13.0  \n",
       "13  MobileNetV3_large  14.0  \n",
       "14  MobileNetV3_large  15.0  \n",
       "\n",
       "[15 rows x 21 columns]"
      ]
     },
     "execution_count": 244,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "bcfa6331-c53c-421c-bd88-fc4ba4cec39f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Collect, per dataset, the test-accuracy position (`pos`, 1 = best test_acc) of each\n",
    "# zero-cost proxy's pick and of the baseline augmentations\n",
    "# (aug 22 = AutoAugment, aug 0 = no augmentation, aug 1 = RandAugment).\n",
    "# Datasets: Adaline, Caitie, Chester, CIFAR10, Gutenberg, in16, LaMelo, Mateo, Sokoto, Volga\n",
    "# Models:   MobileNetV3_large, EfficientNet_b0, ResNet18, RegNetY_400MF\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "model_name = \"EfficientNet_b0\"\n",
    "RESULTS_DIR = f\"/home/woody/iwb3/iwb3021h/augmentations_test/{model_name}\"\n",
    "DATASETS = [\"Adaline\", \"Caitie\", \"Chester\", \"CIFAR10\", \"Gutenberg\", \"in16\", \"LaMelo\", \"Mateo\", \"Sokoto\", \"Volga\"]\n",
    "METRICS = [\"epe_nas\", \"nwot\", \"plain\", \"snip\", \"fisher\", \"jacob_cov\", \"grad_norm\", \"fisher_jacob\"]  # \"grasp\" excluded\n",
    "# NOTE(review): a proxy's pick is the row with the LOWEST metric value, but in the saved\n",
    "# zcosts output rank_<metric> == 1 marks the HIGHEST value -- confirm this direction.\n",
    "PICK_ASCENDING = True\n",
    "\n",
    "\n",
    "def position_of_aug(frame, aug_id):\n",
    "    \"\"\"Return `pos` of the first row whose `aug` equals `aug_id`, or NaN if there is none.\"\"\"\n",
    "    matches = frame.loc[frame[\"aug\"] == aug_id, \"pos\"]\n",
    "    # np.nan rather than np.NaN: the latter alias was removed in NumPy 2.0\n",
    "    return np.nan if matches.empty else matches.iloc[0]\n",
    "\n",
    "\n",
    "dataset_results = {}\n",
    "dataset_autoaugment = {}\n",
    "dataset_noaug = {}\n",
    "dataset_randaug = {}\n",
    "correlations = []\n",
    "correlations_pos = []\n",
    "\n",
    "for dataset in DATASETS:\n",
    "    dataset_dir = f\"{RESULTS_DIR}/{dataset}\"\n",
    "    df = pd.read_csv(f\"{dataset_dir}/zcosts_{dataset}.csv\", index_col=0).reset_index()\n",
    "\n",
    "    # .assign returns a new frame, so tagging the filtered slice raises no SettingWithCopyWarning\n",
    "    corr = pd.read_csv(f\"{dataset_dir}/corr_{dataset}.csv\", index_col=0).reset_index()\n",
    "    corr = corr[corr[\"index\"] == \"test_acc\"].assign(dataset=dataset)\n",
    "    correlations.append(corr)\n",
    "\n",
    "    corr_pos = pd.read_csv(f\"{dataset_dir}/comb_corr_{dataset}_pos.csv\", index_col=0).reset_index()\n",
    "    corr_pos[\"dataset\"] = dataset\n",
    "    correlations_pos.append(corr_pos)\n",
    "\n",
    "    # 1 = highest test accuracy; equal accuracies share a position\n",
    "    df[\"pos\"] = df[\"test_acc\"].rank(ascending=False, method=\"dense\")\n",
    "\n",
    "    dataset_autoaugment[dataset] = position_of_aug(df, 22)\n",
    "    dataset_noaug[dataset] = position_of_aug(df, 0)\n",
    "    dataset_randaug[dataset] = position_of_aug(df, 1)\n",
    "\n",
    "    dataset_results[dataset] = {}\n",
    "    for metric in METRICS:\n",
    "        if df.empty or metric not in df.columns or f\"rank_{metric}\" not in df.columns:\n",
    "            print(metric)  # proxy (or its rank column) missing for this dataset\n",
    "            continue\n",
    "        # stable sort: ties on the metric resolve to the earliest row in file order\n",
    "        picked = df.sort_values(by=metric, ascending=PICK_ASCENDING, kind=\"mergesort\")\n",
    "        dataset_results[dataset][metric] = picked[\"pos\"].iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5578b5ad-19a2-4b2e-a681-42ca5f1c9fe2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "col1       col2     \n",
       "epe_nas    fisher      -0.342857\n",
       "           grad_norm   -0.157143\n",
       "           jacob_cov   -0.560714\n",
       "           nwot        -0.567857\n",
       "           plain       -0.439286\n",
       "           snip        -0.235714\n",
       "fisher     grad_norm    0.307143\n",
       "           jacob_cov   -0.432143\n",
       "jacob_cov  grad_norm   -0.285714\n",
       "nwot       fisher      -0.357143\n",
       "           grad_norm   -0.207143\n",
       "           jacob_cov   -0.653571\n",
       "           plain       -0.428571\n",
       "           snip        -0.228571\n",
       "plain      fisher       0.150000\n",
       "           grad_norm    0.303571\n",
       "           jacob_cov   -0.446429\n",
       "           snip         0.207143\n",
       "snip       fisher       0.278571\n",
       "           grad_norm    0.410714\n",
       "           jacob_cov   -0.335714\n",
       "Name: corr, dtype: float64"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(correlations_pos).groupby([\"col1\",\"col2\"])[\"corr\"].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5475ba40-d639-4eb9-ab04-528d21ece47d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Frames to stack: proxy picks (one row per metric) followed by the three baseline rows;\n",
    "# every frame has one column per dataset.\n",
    "results = [\n",
    "    pd.DataFrame(dataset_results),\n",
    "    pd.DataFrame(dataset_autoaugment, index=[\"AutoAugment\"]),\n",
    "    pd.DataFrame(dataset_noaug, index=[\"NoAugmentation\"]),\n",
    "    pd.DataFrame(dataset_randaug, index=[\"RandAugment\"]),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "5da6577d-33dd-422d-b57f-f226660f6679",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Adaline</th>\n",
       "      <th>Caitie</th>\n",
       "      <th>Chester</th>\n",
       "      <th>CIFAR10</th>\n",
       "      <th>Gutenberg</th>\n",
       "      <th>in16</th>\n",
       "      <th>LaMelo</th>\n",
       "      <th>Mateo</th>\n",
       "      <th>Sokoto</th>\n",
       "      <th>Volga</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>epe_nas</th>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nwot</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>plain</th>\n",
       "      <td>1.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>snip</th>\n",
       "      <td>14.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>15.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher</th>\n",
       "      <td>4.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>9.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>jacob_cov</th>\n",
       "      <td>9.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>grad_norm</th>\n",
       "      <td>18.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher_jacob</th>\n",
       "      <td>9.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AutoAugment</th>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>9.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>NoAugmentation</th>\n",
       "      <td>11.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RandAugment</th>\n",
       "      <td>6.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                Adaline  Caitie  Chester  CIFAR10  Gutenberg  in16  LaMelo  \\\n",
       "epe_nas            15.0    16.0      5.0     16.0        1.0   8.0    12.0   \n",
       "nwot                1.0     1.0      1.0      1.0        1.0   1.0     1.0   \n",
       "plain               1.0    21.0     14.0      6.0        7.0  18.0    12.0   \n",
       "snip               14.0     3.0     11.0     23.0       20.0  23.0    19.0   \n",
       "fisher              4.0     1.0     13.0     23.0       20.0  21.0    14.0   \n",
       "jacob_cov           9.0    16.0      7.0      9.0       19.0   7.0     7.0   \n",
       "grad_norm          18.0     1.0      5.0     21.0       16.0  23.0    18.0   \n",
       "fisher_jacob        9.0    16.0      7.0      5.0        6.0   7.0     7.0   \n",
       "AutoAugment        15.0    16.0      NaN      9.0        7.0  22.0     7.0   \n",
       "NoAugmentation     11.0    15.0      9.0     22.0        5.0   4.0     8.0   \n",
       "RandAugment         6.0     8.0      NaN      2.0       14.0  18.0     9.0   \n",
       "\n",
       "                Mateo  Sokoto  Volga  \n",
       "epe_nas           2.0    15.0    5.0  \n",
       "nwot              1.0     1.0    1.0  \n",
       "plain            13.0     4.0   12.0  \n",
       "snip             17.0    17.0   15.0  \n",
       "fisher           21.0    17.0    9.0  \n",
       "jacob_cov        10.0     4.0    5.0  \n",
       "grad_norm         8.0    10.0    8.0  \n",
       "fisher_jacob     10.0     3.0    1.0  \n",
       "AutoAugment      11.0    11.0    NaN  \n",
       "NoAugmentation    6.0    14.0   12.0  \n",
       "RandAugment       2.0     1.0    NaN  "
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e0ef356c-9f23-4111-9c7e-a806264c09a4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "epe_nas           11.500\n",
       "nwot               9.500\n",
       "plain             12.800\n",
       "snip               6.700\n",
       "fisher             7.700\n",
       "jacob_cov          8.200\n",
       "grad_norm          8.300\n",
       "fisher_jacob       7.500\n",
       "AutoAugment        8.250\n",
       "NoAugmentation    15.800\n",
       "RandAugment        6.375\n",
       "dtype: float64"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(results).mean(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "2c07c30c-2015-4f78-ba5e-f366a8ee89f8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Adaline</th>\n",
       "      <th>Caitie</th>\n",
       "      <th>Chester</th>\n",
       "      <th>CIFAR10</th>\n",
       "      <th>Gutenberg</th>\n",
       "      <th>in16</th>\n",
       "      <th>LaMelo</th>\n",
       "      <th>Mateo</th>\n",
       "      <th>Sokoto</th>\n",
       "      <th>Volga</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>epe_nas</th>\n",
       "      <td>10.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nwot</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>plain</th>\n",
       "      <td>12.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>snip</th>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher</th>\n",
       "      <td>14.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>jacob_cov</th>\n",
       "      <td>6.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>grad_norm</th>\n",
       "      <td>5.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher_jacob</th>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AutoAugment</th>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>9.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>NoAugmentation</th>\n",
       "      <td>11.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RandAugment</th>\n",
       "      <td>6.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                Adaline  Caitie  Chester  CIFAR10  Gutenberg  in16  LaMelo  \\\n",
       "epe_nas            10.0     8.0      7.0     16.0       19.0   2.0    12.0   \n",
       "nwot                1.0     1.0      1.0      1.0        1.0   1.0     1.0   \n",
       "plain              12.0    12.0      5.0     22.0       14.0  23.0     6.0   \n",
       "snip                1.0     3.0      6.0      3.0       11.0  14.0    15.0   \n",
       "fisher             14.0     3.0      8.0     16.0       16.0  10.0    21.0   \n",
       "jacob_cov           6.0    16.0      9.0      5.0       10.0  14.0     6.0   \n",
       "grad_norm           5.0    22.0     11.0     16.0        9.0  15.0     3.0   \n",
       "fisher_jacob        9.0    10.0      7.0      5.0       10.0  14.0     9.0   \n",
       "AutoAugment        15.0    16.0      NaN      9.0        7.0  22.0     7.0   \n",
       "NoAugmentation     11.0    15.0      9.0     22.0        5.0   4.0     8.0   \n",
       "RandAugment         6.0     8.0      NaN      2.0       14.0  18.0     9.0   \n",
       "\n",
       "                Mateo  Sokoto  Volga  \n",
       "epe_nas          11.0     6.0    6.0  \n",
       "nwot              1.0     1.0    1.0  \n",
       "plain            12.0     4.0    3.0  \n",
       "snip             13.0    10.0    8.0  \n",
       "fisher            6.0    14.0    8.0  \n",
       "jacob_cov        10.0     3.0    5.0  \n",
       "grad_norm        15.0    10.0    8.0  \n",
       "fisher_jacob     10.0     3.0    5.0  \n",
       "AutoAugment      11.0    11.0    NaN  \n",
       "NoAugmentation    6.0    14.0   12.0  \n",
       "RandAugment       2.0     1.0    NaN  "
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "dfd613a8-a043-4625-8d3c-7bad90a997ca",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Adaline</th>\n",
       "      <th>Caitie</th>\n",
       "      <th>Chester</th>\n",
       "      <th>CIFAR10</th>\n",
       "      <th>Gutenberg</th>\n",
       "      <th>in16</th>\n",
       "      <th>LaMelo</th>\n",
       "      <th>Mateo</th>\n",
       "      <th>Sokoto</th>\n",
       "      <th>Volga</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>epe_nas</th>\n",
       "      <td>15.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nwot</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>plain</th>\n",
       "      <td>15.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>snip</th>\n",
       "      <td>21.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher</th>\n",
       "      <td>23.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>jacob_cov</th>\n",
       "      <td>7.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>grad_norm</th>\n",
       "      <td>13.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher_jacob</th>\n",
       "      <td>7.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AutoAugment</th>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>9.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>NoAugmentation</th>\n",
       "      <td>11.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RandAugment</th>\n",
       "      <td>6.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                Adaline  Caitie  Chester  CIFAR10  Gutenberg  in16  LaMelo  \\\n",
       "epe_nas            15.0     8.0      7.0     16.0       19.0   2.0    12.0   \n",
       "nwot                1.0     1.0      1.0      1.0        1.0   1.0     1.0   \n",
       "plain              15.0    12.0      5.0     22.0       14.0  23.0     6.0   \n",
       "snip               21.0     3.0      6.0      3.0       11.0  14.0    15.0   \n",
       "fisher             23.0     3.0      8.0     16.0       16.0  10.0    21.0   \n",
       "jacob_cov           7.0    16.0      9.0      5.0       10.0  14.0     6.0   \n",
       "grad_norm          13.0    22.0     11.0     16.0        9.0  15.0     3.0   \n",
       "fisher_jacob        7.0    10.0      7.0      5.0       10.0  14.0     9.0   \n",
       "AutoAugment        15.0    16.0      NaN      9.0        7.0  22.0     7.0   \n",
       "NoAugmentation     11.0    15.0      9.0     22.0        5.0   4.0     8.0   \n",
       "RandAugment         6.0     8.0      NaN      2.0       14.0  18.0     9.0   \n",
       "\n",
       "                Mateo  Sokoto  Volga  \n",
       "epe_nas          11.0     6.0    6.0  \n",
       "nwot              1.0     1.0    1.0  \n",
       "plain            12.0     4.0    3.0  \n",
       "snip             13.0    10.0    8.0  \n",
       "fisher            6.0    14.0    8.0  \n",
       "jacob_cov        10.0     3.0    5.0  \n",
       "grad_norm        15.0    10.0    8.0  \n",
       "fisher_jacob     10.0     3.0    5.0  \n",
       "AutoAugment      11.0    11.0    NaN  \n",
       "NoAugmentation    6.0    14.0   12.0  \n",
       "RandAugment       2.0     1.0    NaN  "
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a2b378ff-b116-4722-b7cf-6f240e7cd6df",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Adaline</th>\n",
       "      <th>Caitie</th>\n",
       "      <th>Chester</th>\n",
       "      <th>CIFAR10</th>\n",
       "      <th>Gutenberg</th>\n",
       "      <th>in16</th>\n",
       "      <th>LaMelo</th>\n",
       "      <th>Mateo</th>\n",
       "      <th>Sokoto</th>\n",
       "      <th>Volga</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>epe_nas</th>\n",
       "      <td>10.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nwot</th>\n",
       "      <td>19.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>plain</th>\n",
       "      <td>15.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>snip</th>\n",
       "      <td>20.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher</th>\n",
       "      <td>10.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>9.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>jacob_cov</th>\n",
       "      <td>8.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>grad_norm</th>\n",
       "      <td>10.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher_jacob</th>\n",
       "      <td>10.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AutoAugment</th>\n",
       "      <td>6.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>6.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>NoAugmentation</th>\n",
       "      <td>16.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>15.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RandAugment</th>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                Adaline  Caitie  Chester  CIFAR10  Gutenberg  in16  LaMelo  \\\n",
       "epe_nas            10.0     2.0      1.0     20.0        9.0   4.0    19.0   \n",
       "nwot               19.0     9.0      1.0      8.0       14.0   5.0     4.0   \n",
       "plain              15.0    13.0      6.0     15.0        4.0   5.0     4.0   \n",
       "snip               20.0     7.0      3.0     12.0        7.0  17.0     9.0   \n",
       "fisher             10.0     8.0      3.0      5.0        7.0   3.0     9.0   \n",
       "jacob_cov           8.0    21.0      1.0      8.0        9.0   1.0    11.0   \n",
       "grad_norm          10.0     7.0      3.0     12.0        7.0  17.0     9.0   \n",
       "fisher_jacob       10.0     9.0      3.0      8.0        7.0   1.0    11.0   \n",
       "AutoAugment         6.0    11.0      NaN      6.0        9.0  10.0    11.0   \n",
       "NoAugmentation     16.0    16.0     14.0     23.0       16.0  16.0    15.0   \n",
       "RandAugment         3.0     7.0      NaN      1.0        5.0   4.0     8.0   \n",
       "\n",
       "                Mateo  Sokoto  Volga  \n",
       "epe_nas          16.0     7.0    6.0  \n",
       "nwot              6.0     6.0    2.0  \n",
       "plain            11.0    13.0    4.0  \n",
       "snip              4.0    15.0   10.0  \n",
       "fisher            4.0    15.0    9.0  \n",
       "jacob_cov         6.0     6.0    1.0  \n",
       "grad_norm         4.0    15.0   10.0  \n",
       "fisher_jacob      6.0     6.0    2.0  \n",
       "AutoAugment       7.0     6.0    NaN  \n",
       "NoAugmentation   16.0    11.0   15.0  \n",
       "RandAugment       4.0    19.0    NaN  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f5c68c11-6cb0-402c-a702-5c65b4d6d1a3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Adaline</th>\n",
       "      <th>Caitie</th>\n",
       "      <th>Chester</th>\n",
       "      <th>CIFAR10</th>\n",
       "      <th>Gutenberg</th>\n",
       "      <th>in16</th>\n",
       "      <th>LaMelo</th>\n",
       "      <th>Mateo</th>\n",
       "      <th>Sokoto</th>\n",
       "      <th>Volga</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>epe_nas</th>\n",
       "      <td>9.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nwot</th>\n",
       "      <td>16.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>plain</th>\n",
       "      <td>15.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>snip</th>\n",
       "      <td>5.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher</th>\n",
       "      <td>5.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>9.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>jacob_cov</th>\n",
       "      <td>16.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>grad_norm</th>\n",
       "      <td>5.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher_jacob</th>\n",
       "      <td>5.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AutoAugment</th>\n",
       "      <td>6.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>6.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>NoAugmentation</th>\n",
       "      <td>16.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>15.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RandAugment</th>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                Adaline  Caitie  Chester  CIFAR10  Gutenberg  in16  LaMelo  \\\n",
       "epe_nas             9.0     2.0      1.0     20.0        9.0   4.0    19.0   \n",
       "nwot               16.0     9.0      1.0      8.0       14.0   5.0     4.0   \n",
       "plain              15.0    13.0      6.0     15.0        4.0   5.0     4.0   \n",
       "snip                5.0     7.0      3.0     12.0        7.0  17.0     9.0   \n",
       "fisher              5.0     8.0      3.0      5.0        7.0   3.0     9.0   \n",
       "jacob_cov          16.0    21.0      1.0      8.0        9.0   1.0    11.0   \n",
       "grad_norm           5.0     7.0      3.0     12.0        7.0  17.0     9.0   \n",
       "fisher_jacob        5.0     9.0      3.0      8.0        7.0   1.0    11.0   \n",
       "AutoAugment         6.0    11.0      NaN      6.0        9.0  10.0    11.0   \n",
       "NoAugmentation     16.0    16.0     14.0     23.0       16.0  16.0    15.0   \n",
       "RandAugment         3.0     7.0      NaN      1.0        5.0   4.0     8.0   \n",
       "\n",
       "                Mateo  Sokoto  Volga  \n",
       "epe_nas          16.0     7.0    6.0  \n",
       "nwot              6.0     6.0    2.0  \n",
       "plain            11.0    13.0    4.0  \n",
       "snip              4.0    15.0   10.0  \n",
       "fisher            4.0    15.0    9.0  \n",
       "jacob_cov         6.0     6.0    1.0  \n",
       "grad_norm         4.0    15.0   10.0  \n",
       "fisher_jacob      6.0     6.0    2.0  \n",
       "AutoAugment       7.0     6.0    NaN  \n",
       "NoAugmentation   16.0    11.0   15.0  \n",
       "RandAugment       4.0    19.0    NaN  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "27ddd738-fef0-4523-a9ee-60b4d955cc5b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Adaline</th>\n",
       "      <th>Caitie</th>\n",
       "      <th>Chester</th>\n",
       "      <th>CIFAR10</th>\n",
       "      <th>Gutenberg</th>\n",
       "      <th>in16</th>\n",
       "      <th>LaMelo</th>\n",
       "      <th>Mateo</th>\n",
       "      <th>Sokoto</th>\n",
       "      <th>Volga</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>epe_nas</th>\n",
       "      <td>8.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nwot</th>\n",
       "      <td>16.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>plain</th>\n",
       "      <td>13.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>snip</th>\n",
       "      <td>6.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher</th>\n",
       "      <td>6.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>9.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>jacob_cov</th>\n",
       "      <td>7.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>grad_norm</th>\n",
       "      <td>6.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher_jacob</th>\n",
       "      <td>7.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AutoAugment</th>\n",
       "      <td>6.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>6.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>NoAugmentation</th>\n",
       "      <td>16.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>15.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RandAugment</th>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                Adaline  Caitie  Chester  CIFAR10  Gutenberg  in16  LaMelo  \\\n",
       "epe_nas             8.0     2.0      1.0     20.0        9.0   4.0    19.0   \n",
       "nwot               16.0     9.0      1.0      8.0       14.0   5.0     4.0   \n",
       "plain              13.0    13.0      6.0     15.0        4.0   5.0     4.0   \n",
       "snip                6.0     7.0      3.0     12.0        7.0  17.0     9.0   \n",
       "fisher              6.0     8.0      3.0      5.0        7.0   3.0     9.0   \n",
       "jacob_cov           7.0    21.0      1.0      8.0        9.0   1.0    11.0   \n",
       "grad_norm           6.0     7.0      3.0     12.0        7.0  17.0     9.0   \n",
       "fisher_jacob        7.0     9.0      3.0      8.0        7.0   1.0    11.0   \n",
       "AutoAugment         6.0    11.0      NaN      6.0        9.0  10.0    11.0   \n",
       "NoAugmentation     16.0    16.0     14.0     23.0       16.0  16.0    15.0   \n",
       "RandAugment         3.0     7.0      NaN      1.0        5.0   4.0     8.0   \n",
       "\n",
       "                Mateo  Sokoto  Volga  \n",
       "epe_nas          16.0     7.0    6.0  \n",
       "nwot              6.0     6.0    2.0  \n",
       "plain            11.0    13.0    4.0  \n",
       "snip              4.0    15.0   10.0  \n",
       "fisher            4.0    15.0    9.0  \n",
       "jacob_cov         6.0     6.0    1.0  \n",
       "grad_norm         4.0    15.0   10.0  \n",
       "fisher_jacob      6.0     6.0    2.0  \n",
       "AutoAugment       7.0     6.0    NaN  \n",
       "NoAugmentation   16.0    11.0   15.0  \n",
       "RandAugment       4.0    19.0    NaN  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ca660b97-6267-4a00-ac57-378f5edbdc4b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Adaline</th>\n",
       "      <th>Caitie</th>\n",
       "      <th>Chester</th>\n",
       "      <th>CIFAR10</th>\n",
       "      <th>Gutenberg</th>\n",
       "      <th>in16</th>\n",
       "      <th>LaMelo</th>\n",
       "      <th>Mateo</th>\n",
       "      <th>Sokoto</th>\n",
       "      <th>Volga</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>epe_nas</th>\n",
       "      <td>19.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>nwot</th>\n",
       "      <td>4.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>plain</th>\n",
       "      <td>3.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>snip</th>\n",
       "      <td>14.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher</th>\n",
       "      <td>14.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>9.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>jacob_cov</th>\n",
       "      <td>16.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>grad_norm</th>\n",
       "      <td>14.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fisher_jacob</th>\n",
       "      <td>3.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AutoAugment</th>\n",
       "      <td>6.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>6.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>NoAugmentation</th>\n",
       "      <td>16.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>15.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RandAugment</th>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                Adaline  Caitie  Chester  CIFAR10  Gutenberg  in16  LaMelo  \\\n",
       "epe_nas            19.0     2.0      1.0     20.0        9.0   4.0    19.0   \n",
       "nwot                4.0     9.0      1.0      8.0       14.0   5.0     4.0   \n",
       "plain               3.0    13.0      6.0     15.0        4.0   5.0     4.0   \n",
       "snip               14.0     7.0      3.0     12.0        7.0  17.0     9.0   \n",
       "fisher             14.0     8.0      3.0      5.0        7.0   3.0     9.0   \n",
       "jacob_cov          16.0    21.0      1.0      8.0        9.0   1.0    11.0   \n",
       "grad_norm          14.0     7.0      3.0     12.0        7.0  17.0     9.0   \n",
       "fisher_jacob        3.0     9.0      3.0      8.0        7.0   1.0    11.0   \n",
       "AutoAugment         6.0    11.0      NaN      6.0        9.0  10.0    11.0   \n",
       "NoAugmentation     16.0    16.0     14.0     23.0       16.0  16.0    15.0   \n",
       "RandAugment         3.0     7.0      NaN      1.0        5.0   4.0     8.0   \n",
       "\n",
       "                Mateo  Sokoto  Volga  \n",
       "epe_nas          16.0     7.0    6.0  \n",
       "nwot              6.0     6.0    2.0  \n",
       "plain            11.0    13.0    4.0  \n",
       "snip              4.0    15.0   10.0  \n",
       "fisher            4.0    15.0    9.0  \n",
       "jacob_cov         6.0     6.0    1.0  \n",
       "grad_norm         4.0    15.0   10.0  \n",
       "fisher_jacob      6.0     6.0    2.0  \n",
       "AutoAugment       7.0     6.0    NaN  \n",
       "NoAugmentation   16.0    11.0   15.0  \n",
       "RandAugment       4.0    19.0    NaN  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8f58c28e-cbd8-4574-8d56-0046d74a57e9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "epe_nas            9.100\n",
       "nwot               8.800\n",
       "plain             10.400\n",
       "snip              11.100\n",
       "fisher             8.000\n",
       "jacob_cov          5.800\n",
       "grad_norm         11.100\n",
       "fisher_jacob       6.700\n",
       "AutoAugment        8.250\n",
       "NoAugmentation    15.800\n",
       "RandAugment        6.375\n",
       "dtype: float64"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(results).mean(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "id": "d9f0ed7b-c2c1-416f-8f4b-451015d89cc7",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "epe_nas           10.00\n",
       "nwot               1.00\n",
       "plain             12.60\n",
       "snip              10.40\n",
       "fisher            10.10\n",
       "jacob_cov         14.70\n",
       "grad_norm         11.00\n",
       "AutoAugment       12.25\n",
       "NoAugmentation    10.60\n",
       "RandAugment        7.50\n",
       "dtype: float64"
      ]
     },
     "execution_count": 227,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.concat(results).mean(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "913963d5-af7b-4108-89e1-b81164750719",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#pd.concat(results).to_csv(\"EfficientNet_b0_pos.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "ebbbc011-0a01-4939-becf-a72c4b51ca2a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "epe_nas           11.100\n",
       "nwot              10.900\n",
       "plain             13.600\n",
       "snip               7.600\n",
       "fisher             8.600\n",
       "jacob_cov          8.200\n",
       "grad_norm          9.300\n",
       "fisher_jacob       5.900\n",
       "AutoAugment        8.250\n",
       "NoAugmentation    15.800\n",
       "RandAugment        6.375\n",
       "dtype: float64"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(\"RegNetY_400MF_pos.csv\", index_col=0).mean(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "ebf86fba-f81d-4893-a3dd-a4e44f5828d1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "epe_nas            9.50\n",
       "nwot               1.00\n",
       "plain             10.80\n",
       "snip              16.20\n",
       "fisher            14.30\n",
       "jacob_cov          9.30\n",
       "grad_norm         12.80\n",
       "fisher_jacob       7.10\n",
       "AutoAugment       12.25\n",
       "NoAugmentation    10.60\n",
       "RandAugment        7.50\n",
       "dtype: float64"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(\"EfficientNet_b0_pos.csv\", index_col=0).mean(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "5c3da7f8-e77a-4a76-bb7c-c756a3540832",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "epe_nas           10.500\n",
       "nwot               8.100\n",
       "plain              5.000\n",
       "snip               9.000\n",
       "fisher             8.000\n",
       "jacob_cov          7.700\n",
       "grad_norm          8.600\n",
       "fisher_jacob       7.500\n",
       "AutoAugment       11.000\n",
       "NoAugmentation    13.700\n",
       "RandAugment        5.375\n",
       "dtype: float64"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(\"ResNet18_pos.csv\", index_col=0).mean(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f619610d-29f1-4836-8f74-ce244b4d2748",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "index\n",
       "epe_nas     -0.209635\n",
       "fisher      -0.236562\n",
       "grad_norm   -0.123767\n",
       "grasp        0.112866\n",
       "jacob_cov   -0.362412\n",
       "nwot        -0.250534\n",
       "plain       -0.006321\n",
       "snip        -0.163792\n",
       "Name: test_acc, dtype: float64"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(\"correlations_RegNetY_400MF.csv\", index_col=0).reset_index().groupby(\"index\")[\"test_acc\"].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "4af54105-e3d1-4f44-947e-b7d8f4d3a18c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "index\n",
       "epe_nas     -0.232228\n",
       "fisher       0.041336\n",
       "grad_norm    0.186782\n",
       "grasp        0.013587\n",
       "jacob_cov   -0.303095\n",
       "nwot              NaN\n",
       "plain       -0.085789\n",
       "snip         0.067450\n",
       "Name: test_acc, dtype: float64"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(\"correlations_EfficientNet_b0.csv\", index_col=0).reset_index().groupby(\"index\")[\"test_acc\"].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4360f3c4-d1fd-47a4-b90d-df1c90e58f4a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PyTorch 2.2.0",
   "language": "python",
   "name": "pytorch-2.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
