{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "# Select the GPU before torch or any project module is imported, so the setting is\n",
    "# guaranteed to be in place before a CUDA context can be created.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"6\"\n",
    "sys.path.append('../')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "# Project-local star imports are kept ahead of the explicit imports below so that,\n",
    "# as before, the explicit names win over anything these modules re-export.\n",
    "from voc import *\n",
    "from coco import *\n",
    "\n",
    "import json\n",
    "import pathlib\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.models import resnet152, resnet101, resnet18, resnet34, resnet50\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.misc\n",
    "from PIL import Image\n",
    "from config import seed_everything\n",
    "\n",
    "%matplotlib inline\n",
    "seed_everything(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "COCO training-set label statistics: per-class label counts and class co-occurrence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-class label statistics of the COCO training annotations.\n",
    "from collections import defaultdict, Counter\n",
    "\n",
    "path_csv = '../data/coco'\n",
    "NUM_CLASSES = 80\n",
    "\n",
    "with open(path_csv + '/data/train_anno.json') as f:\n",
    "  train = json.load(f)\n",
    "print(len(train), train[0])\n",
    "\n",
    "# gt_labels: multi-hot matrix (images x classes); class_num: label occurrences per class.\n",
    "class_num = defaultdict(int)\n",
    "li = []\n",
    "gt_labels = np.zeros((len(train), NUM_CLASSES))\n",
    "for i, each in enumerate(train):\n",
    "  li += each['labels']\n",
    "  gt_labels[i, each['labels']] = 1\n",
    "  for l in each['labels']:\n",
    "    class_num[l] += 1\n",
    "\n",
    "nums = gt_labels.sum(axis=0)\n",
    "# adj[i][j]: number of images labelled with class i that are also labelled with class j\n",
    "# (the diagonal is zeroed).\n",
    "adj = []\n",
    "for i, col in enumerate(gt_labels.T):\n",
    "  has_i = col == 1\n",
    "  if i in [34]:  # debug: how many labels the images of class 34 carry\n",
    "    print(i)\n",
    "    nonzero_cnt = (gt_labels[has_i, :] != 0).sum(1)\n",
    "    print(Counter(nonzero_cnt))\n",
    "  cond_prob = gt_labels[has_i, :].sum(axis=0)\n",
    "  cond_prob[i] = 0\n",
    "  adj.append(cond_prob)\n",
    "\n",
    "# Per-class image counts in descending order (which class each count belongs to is dropped).\n",
    "nums = sorted(nums.tolist(), reverse=True)\n",
    "print(max(nums), min(nums))\n",
    "di = {'adj': np.asarray(adj), \"nums\": np.asarray(nums)}\n",
    "# Class index -> label count, ordered from most to least frequent.\n",
    "class_di = dict(sorted(class_num.items(), key=lambda item: item[1], reverse=True))\n",
    "print(class_di.keys(), class_di.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check on the training split. It gets its own name so it does not clobber\n",
    "# `test_dataset`, which holds the validation split used by the rest of the notebook.\n",
    "train_dataset = COCO2014('../data/coco', phase='train')\n",
    "len(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[dataset] Done!\n",
      "[annotation] Done!\n",
      "[json] Done!\n"
     ]
    }
   ],
   "source": [
    "# Validation split, iterated in order (shuffle=False) so row k of the collected scores\n",
    "# corresponds to test_dataset[k].\n",
    "test_dataset = COCO2014('../data/coco', phase='val')\n",
    "# Channel statistics are the usual torchvision ImageNet values.\n",
    "normalize = transforms.Normalize(\n",
    "    mean=[0.485, 0.456, 0.406],\n",
    "    std=[0.229, 0.224, 0.225],\n",
    ")\n",
    "test_dataset.transform = transforms.Compose([Warp(384), transforms.ToTensor(), normalize])\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=100,\n",
    "    shuffle=False,\n",
    ")\n",
    "\n",
    "from util import AveragePrecisionMeter\n",
    "AP = AveragePrecisionMeter(difficult_examples=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'resnet50': 'resnetv2_50x3_bitm_in21k', 'swin': 'swin_base_patch4_window12_384_in22k', 'swin_large': 'swin_large_patch4_window12_384_in22k', 'convnext': 'convnext_large_in22k', 'resnet101': 'resnetv2_101x1_bitm_in21k'}\n",
      "resnet50 : resnetv2_50x3_bitm_in21k\n",
      "swin : swin_base_patch4_window12_384_in22k\n",
      "swin_large : swin_large_patch4_window12_384_in22k\n",
      "convnext : convnext_large_in22k\n",
      "resnet101 : resnetv2_101x1_bitm_in21k\n"
     ]
    }
   ],
   "source": [
    "from backbones.config import config\n",
    "print(config)\n",
    "# Show each backbone alias with its model name and make sure ../figures/<alias>/ exists.\n",
    "for k, v in config.items():\n",
    "  print(f\"{k} : {v}\")\n",
    "  (pathlib.Path('../figures') / k).mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import *\n",
    "from backbones.config import config\n",
    "import pathlib\n",
    "\n",
    "# Backbone + checkpoint under analysis; the commented lines are alternatives that were tried.\n",
    "# model = base_swin(config['swin_large'], 80, image_size=384, pretrained=True, cond=True, where=0, aggregate='1')\n",
    "model = base_resnet101(config['resnet101'], 80, 384, cond=False)\n",
    "# path = '../checkpoint/coco/coco_l_alpha-@0_swin_large_base_best.pth.tar'\n",
    "# path = '../checkpoint/coco/coco_squeeze-excitation-noscale_resnet101_base_best.pth.tar'\n",
    "# path = '../checkpoint/coco/coco_baseline_swin_large_base_best.pth.tar'\n",
    "# path = '../checkpoint/coco/coco_l_alpha@2_swin_base_best.pth.tar'\n",
    "path = '../checkpoint/coco/coco_baseline-scheduler_resnet101_base_best.pth.tar'\n",
    "# Load onto the CPU so this also works without the GPU the checkpoint was saved from;\n",
    "# the model is moved to the evaluation device in the inference cell.\n",
    "# NOTE(review): `di` reuses the name of the label-statistics dict built earlier.\n",
    "di = torch.load(path, map_location='cpu')\n",
    "print(di['best_score'])\n",
    "print(di.keys())\n",
    "model.load_state_dict(di['state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = model.to(device).eval()\n",
    "# Fresh meter, so re-running this cell does not append a second copy of every score.\n",
    "AP = AveragePrecisionMeter(difficult_examples=False)\n",
    "# NOTE(review): act_li is never filled in this loop (the model call only yields `output`),\n",
    "# so the act.npy saved later is empty. TODO: collect activations here if the model exposes them.\n",
    "act_li = np.array([[]])\n",
    "with torch.no_grad():  # inference only: do not build the autograd graph\n",
    "  for (img, path), target in tqdm(test_loader, total=len(test_loader)):\n",
    "    # Remap labels {-1, 0, 1} -> {0, 1, 1}: 0 counts as positive, -1 as negative.\n",
    "    target[target == 0] = 1\n",
    "    target[target == -1] = 0\n",
    "    output = model(img.float().to(device))\n",
    "    AP.add(output, target)\n",
    "\n",
    "AP.scores.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save inference results (activations, scores, targets, per-class AP) as .npy files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m_name = \"resnet101\"\n",
    "fig_dir = pathlib.Path('../figures') / m_name\n",
    "fig_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "def _to_numpy(x):\n",
    "  \"\"\"Return a tensor (on any device) or array-like `x` as a NumPy array.\"\"\"\n",
    "  return torch.as_tensor(x).detach().cpu().numpy()\n",
    "\n",
    "act_li = act_li.squeeze()\n",
    "np.save(fig_dir / 'act.npy', act_li)\n",
    "np.save(fig_dir / 'apscore.npy', _to_numpy(AP.scores))\n",
    "np.save(fig_dir / 'apvalue.npy', _to_numpy(100 * AP.value()))\n",
    "np.save(fig_dir / 'aptarget.npy', _to_numpy(AP.targets))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the saved inference results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m_name = 'resnet101'\n",
    "fig_dir = pathlib.Path('../figures') / m_name\n",
    "act_li = np.load(fig_dir / 'act.npy')\n",
    "# scale.npy is not written by this notebook; load it only when another run produced it.\n",
    "scale_file = fig_dir / 'scale.npy'\n",
    "sc_li = np.load(scale_file) if scale_file.exists() else None\n",
    "# Same configuration as the meter used for inference (difficult_examples=False).\n",
    "AP = AveragePrecisionMeter(difficult_examples=False)\n",
    "AP.scores = torch.from_numpy(np.load(fig_dir / 'apscore.npy'))\n",
    "AP.targets = torch.from_numpy(np.load(fig_dir / 'aptarget.npy'))\n",
    "AP.scores.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'airplane'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "# Load the class-name -> index mapping and build its inverse (index -> class name).\n",
    "with open('../data/coco/data/category.json', 'r') as f:\n",
    "  category = json.load(f)\n",
    "inverse = dict((idx, name) for name, idx in category.items())\n",
    "\n",
    "inverse[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every class, rank all validation images by that class's score and save the ranking\n",
    "# together with each image's activation (act_li) and ground-truth label.\n",
    "n_images = AP.scores.shape[0]\n",
    "assert len(act_li) == n_images, (\n",
    "    'act_li has {} entries but AP has {} images; act.npy stays empty unless the '\n",
    "    'inference loop collects activations'.format(len(act_li), n_images))\n",
    "for i in range(AP.scores.shape[1]):  # iterate over classes\n",
    "  score = AP.scores[:, i]\n",
    "  target = AP.targets[:, i]\n",
    "  # `score_sorted`, not `sorted`: a global named `sorted` would shadow the builtin.\n",
    "  score_sorted, indices = torch.sort(score, dim=0, descending=True)\n",
    "  order = indices.cpu().numpy()\n",
    "  classwise = {\n",
    "      \"alpha\": act_li[order],\n",
    "      \"score\": score_sorted.cpu().numpy(),\n",
    "      \"target\": target.cpu().numpy()[order],\n",
    "  }\n",
    "  img_df = pd.DataFrame(data=classwise)\n",
    "  img_df.to_csv(\"../figures/{}/class[{}]_alpha_scale.csv\".format(m_name, inverse[i]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot alpha scale distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_outlier(points, thresh=3.5):\n",
    "    \"\"\"\n",
    "    Returns a boolean array with True if points are outliers and False\n",
    "    otherwise.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "        points : A numobservations by numdimensions array of observations, or a\n",
    "            1-D array / pandas Series (treated as one dimension). Anything\n",
    "            np.asarray accepts is fine.\n",
    "        thresh : The modified z-score to use as a threshold. Observations with\n",
    "            a modified z-score (based on the median absolute deviation) greater\n",
    "            than this value will be classified as outliers.\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "        mask : A numobservations-length boolean NumPy array.\n",
    "\n",
    "    References:\n",
    "    ----------\n",
    "        Boris Iglewicz and David Hoaglin (1993), \"Volume 16: How to Detect and\n",
    "        Handle Outliers\", The ASQC Basic References in Quality Control:\n",
    "        Statistical Techniques, Edward F. Mykytka, Ph.D., Editor.\n",
    "    \"\"\"\n",
    "    # Convert up front: indexing a pandas Series with [:, None] is deprecated.\n",
    "    points = np.asarray(points)\n",
    "    if points.ndim == 1:\n",
    "        points = points[:, None]\n",
    "    median = np.median(points, axis=0)\n",
    "    # Euclidean distance of every observation from the per-dimension median.\n",
    "    diff = np.sqrt(np.sum((points - median)**2, axis=-1))\n",
    "    med_abs_deviation = np.median(diff)\n",
    "\n",
    "    if med_abs_deviation == 0:\n",
    "        # More than half of the points sit exactly on the median, so the ratio below\n",
    "        # would be 0/0 (NaN -> not an outlier) or x/0 (inf -> outlier). Return that\n",
    "        # same result directly instead of emitting divide-by-zero warnings.\n",
    "        return diff > 0\n",
    "\n",
    "    # 0.6745 (Iglewicz & Hoaglin) is the 0.75 quantile of the standard normal.\n",
    "    modified_z_score = 0.6745 * diff / med_abs_deviation\n",
    "    return modified_z_score > thresh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect alpha values of positive (target == 1) and negative (target == 0) images\n",
    "# for the selected classes from the per-class CSVs written above.\n",
    "import json\n",
    "selected_classes = ['surfboard']\n",
    "alpha = []\n",
    "con_T = None\n",
    "con_F = None\n",
    "with open('../data/coco/data/category.json', 'r') as f:\n",
    "  category = json.load(f)\n",
    "inverse = {v: k for k, v in category.items()}\n",
    "\n",
    "for k, v in category.items():\n",
    "  print(k, v)\n",
    "  if k not in selected_classes:\n",
    "    continue\n",
    "  path = \"../figures/{}/class[{}]_alpha_scale.csv\".format(m_name, k)\n",
    "  df = pd.read_csv(path)\n",
    "  alpha_t = df.loc[df['target'] == 1, 'alpha']\n",
    "  alpha_f = df.loc[df['target'] == 0, 'alpha']\n",
    "  con_T = pd.concat([alpha_t, con_T])\n",
    "  con_F = pd.concat([alpha_f, con_F])  # negatives come from alpha_f, not alpha_t\n",
    "  print(alpha_t.mean())\n",
    "  alpha.append(alpha_t.mean())\n",
    "# assert len(alpha)==80\n",
    "\n",
    "# df = pd.DataFrame({\"class\": list(category.keys()), \"scale\": alpha})\n",
    "# df.to_csv(\"../figures/{}/scale_per_class.csv\".format(m_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams.update({'font.size': 16})\n",
    "# Pass plain arrays to is_outlier: indexing a pandas Series with [:, None] is deprecated.\n",
    "filtered_T = con_T[~is_outlier(con_T.to_numpy())]\n",
    "filtered_F = con_F[~is_outlier(con_F.to_numpy())]\n",
    "fig, ax = plt.subplots()\n",
    "filtered_T.plot.hist(bins=100, alpha=0.5, ax=ax)\n",
    "# filtered_F.plot.hist(bins=100, alpha=0.5, ax=ax)  # overlay the negatives\n",
    "ax.set_xlabel('alpha')\n",
    "plt.tight_layout()\n",
    "fig.savefig('../figures/{}/ratio_bin_TF.pdf'.format(m_name), dpi=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Top-3 and bottom-3 scoring positive images per class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each class, record the file names of the 3 highest- and 3 lowest-scoring\n",
    "# positive validation images (column \"1\" = best, \"-1\" = worst).\n",
    "rank_of = {\"1\": 0, \"2\": 1, \"3\": 2, \"-1\": -1, \"-2\": -2, \"-3\": -3}\n",
    "classwise_img = {\"class_index\": []}\n",
    "classwise_img.update({key: [] for key in rank_of})\n",
    "for i in range(AP.scores.shape[1]):\n",
    "  score = AP.scores[:, i]\n",
    "  target = AP.targets[:, i]\n",
    "  pos_idx = torch.where(target == 1)[0]\n",
    "  score_sorted, order = torch.sort(score[pos_idx], dim=0, descending=True)\n",
    "  classwise_img[\"class_index\"].append(i)\n",
    "  for key, rank in rank_of.items():\n",
    "    # Load only the six images needed; element [0][1] of a sample is its file name.\n",
    "    sample = test_dataset[int(pos_idx[order[rank]])]\n",
    "    classwise_img[key].append(sample[0][1])\n",
    "\n",
    "img_df = pd.DataFrame(data=classwise_img)\n",
    "img_df.to_csv(\"../figures/{}/topandworst.csv\".format(m_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_101283/2526294419.py:3: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n",
      "  fig = plt.figure(figsize=(12, 12))\n"
     ]
    }
   ],
   "source": [
    "plt.rcParams.update({'font.size': 8})\n",
    "COCO_VAL_DIR = pathlib.Path('/home/seongha/LT-ML/data/coco/data/val2014')  # TODO: make configurable\n",
    "for row_idx, row in img_df.iterrows():\n",
    "  fig = plt.figure(figsize=(12, 12))\n",
    "  # Series.iteritems() was removed in pandas 2.0; items() yields the same pairs.\n",
    "  for subplot_idx, (rank, file_name) in enumerate(row.drop('class_index').items(), start=1):\n",
    "    ax = fig.add_subplot(2, 3, subplot_idx)\n",
    "    ax.set_title('rank {}'.format(rank))\n",
    "    if pd.isna(file_name):  # no image recorded for this rank\n",
    "      ax.axis('off')\n",
    "      continue\n",
    "    with Image.open(COCO_VAL_DIR / file_name) as image:\n",
    "      ax.imshow(np.asarray(image))  # copy pixels before the file handle is closed\n",
    "\n",
    "  fig.savefig('../figures/{}/topandworst_class_{}.png'.format(m_name, row_idx), bbox_inches='tight', dpi=100)\n",
    "  plt.close(fig)  # otherwise every figure stays open (see the >20 figures warning)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Class-wise AP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# VOC class names, kept for VOC runs (use label_li = VOC_LABELS there).\n",
    "VOC_LABELS = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',\n",
    "              'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',\n",
    "              'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']\n",
    "label_li = list(class_di.keys())  # COCO: class indices, in class_di order\n",
    "ap_li = 100 * AP.value()  # per-class AP as a percentage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Class-wise AP, sorted by class frequency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "m_name = \"resnet101\"  # model name used in the figure/CSV output paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>class_index</th>\n",
       "      <th>class_size</th>\n",
       "      <th>AP</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>49</td>\n",
       "      <td>45174</td>\n",
       "      <td>96.732132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>22</td>\n",
       "      <td>8950</td>\n",
       "      <td>67.002098</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>18</td>\n",
       "      <td>8606</td>\n",
       "      <td>40.081284</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>27</td>\n",
       "      <td>8378</td>\n",
       "      <td>80.071243</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>26</td>\n",
       "      <td>6518</td>\n",
       "      <td>95.650810</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75</th>\n",
       "      <td>75</td>\n",
       "      <td>55</td>\n",
       "      <td>673</td>\n",
       "      <td>80.927071</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>76</th>\n",
       "      <td>76</td>\n",
       "      <td>6</td>\n",
       "      <td>668</td>\n",
       "      <td>76.382812</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>77</th>\n",
       "      <td>77</td>\n",
       "      <td>48</td>\n",
       "      <td>481</td>\n",
       "      <td>70.343414</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>78</th>\n",
       "      <td>78</td>\n",
       "      <td>69</td>\n",
       "      <td>151</td>\n",
       "      <td>68.259087</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>79</th>\n",
       "      <td>79</td>\n",
       "      <td>35</td>\n",
       "      <td>128</td>\n",
       "      <td>98.568550</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>80 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    index  class_index  class_size         AP\n",
       "0       0           49       45174  96.732132\n",
       "1       1           22        8950  67.002098\n",
       "2       2           18        8606  40.081284\n",
       "3       3           27        8378  80.071243\n",
       "4       4           26        6518  95.650810\n",
       "..    ...          ...         ...        ...\n",
       "75     75           55         673  80.927071\n",
       "76     76            6         668  76.382812\n",
       "77     77           48         481  70.343414\n",
       "78     78           69         151  68.259087\n",
       "79     79           35         128  98.568550\n",
       "\n",
       "[80 rows x 4 columns]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# class_di keys are class indices in descending-frequency order, while ap_li is\n",
    "# assumed to hold one AP per class index (one per column of AP.scores). Look each\n",
    "# class's AP up by its index instead of pairing the two sequences by position.\n",
    "class_order = list(class_di.keys())\n",
    "data_ = {\"class_index\": class_order,\n",
    "         \"class_size\": list(class_di.values()),\n",
    "         \"AP\": np.asarray(ap_li)[class_order]}\n",
    "\n",
    "df = pd.DataFrame(data=data_)\n",
    "df.to_csv(\"../figures/{}/sorted_ap_baseline{}.csv\".format(m_name, m_name))\n",
    "df.reset_index()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Label-count distribution (CSV) for the best and worst AP classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "m_name = 'convnext'\n",
    "# The class-wise AP cell saves sorted_ap_baseline<m_name>.csv, but this cell reads\n",
    "# sorted_ap_<m_name>.csv; fall back to the former when the latter is missing.\n",
    "ap_csv = pathlib.Path(\"../figures/{}/sorted_ap_{}.csv\".format(m_name, m_name))\n",
    "if not ap_csv.exists():\n",
    "  ap_csv = pathlib.Path(\"../figures/{}/sorted_ap_baseline{}.csv\".format(m_name, m_name))\n",
    "df = pd.read_csv(ap_csv, index_col=0)  # column 0 is the index written by to_csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    Unnamed: 0  class_index  class_size         AP\n",
      "67          67           19        1186  97.547930\n",
      "49          49           34        1798  97.169050\n",
      "34          34           67        2368  96.780850\n",
      "30          30           61        2493  95.840400\n",
      "79          79           35         128  94.742190\n",
      "73          73           38         821  93.922264\n",
      "0            0           49       45174  93.445090\n",
      "6            6           14        5028  92.356060\n",
      "19\n",
      "Counter({3: 290, 4: 239, 2: 210, 5: 160, 6: 88, 7: 67, 1: 36, 8: 31, 9: 28, 11: 14, 10: 12, 12: 9, 13: 1, 14: 1})\n",
      "34\n",
      "Counter({1: 1234, 2: 407, 3: 116, 4: 27, 5: 8, 8: 3, 7: 2, 6: 1})\n",
      "67\n",
      "Counter({3: 1083, 2: 688, 4: 388, 5: 146, 6: 44, 7: 9, 1: 4, 9: 3, 8: 2, 11: 1})\n",
      "61\n",
      "Counter({6: 395, 7: 375, 5: 352, 4: 346, 8: 267, 3: 219, 9: 179, 10: 110, 2: 85, 11: 81, 12: 45, 13: 20, 14: 8, 15: 5, 1: 4, 16: 2})\n",
      "35\n",
      "Counter({3: 42, 4: 27, 5: 22, 2: 17, 6: 10, 7: 3, 8: 3, 1: 3, 10: 1})\n",
      "38\n",
      "Counter({2: 209, 3: 181, 4: 126, 1: 81, 5: 75, 6: 66, 7: 38, 8: 19, 9: 13, 10: 6, 11: 5, 12: 1, 13: 1})\n",
      "49\n",
      "Counter({2: 15997, 3: 12123, 4: 7145, 5: 4298, 6: 2350, 7: 1388, 8: 747, 9: 436, 1: 238, 10: 228, 11: 124, 12: 57, 13: 27, 14: 9, 15: 6, 16: 1})\n",
      "14\n",
      "Counter({4: 853, 5: 731, 6: 690, 3: 680, 7: 596, 8: 404, 2: 350, 9: 311, 10: 180, 11: 115, 12: 57, 13: 28, 1: 15, 14: 9, 15: 6, 16: 2, 18: 1})\n",
      "65\n",
      "Counter({2: 1782, 3: 362, 4: 103, 5: 42, 1: 29, 6: 15, 7: 5, 8: 3, 10: 2})\n",
      "60\n",
      "Counter({2: 711, 3: 345, 4: 83, 5: 17, 1: 9, 6: 3, 7: 2})\n",
      "18\n",
      "Counter({3: 2552, 4: 2111, 5: 1385, 2: 1334, 6: 626, 7: 278, 8: 129, 1: 97, 9: 51, 10: 25, 11: 12, 12: 3, 13: 2, 14: 1})\n",
      "52\n",
      "Counter({3: 250, 5: 215, 4: 207, 6: 183, 2: 176, 7: 164, 8: 144, 9: 107, 1: 82, 10: 59, 11: 38, 12: 22, 13: 11, 14: 6, 15: 4, 16: 2, 18: 1})\n",
      "70\n",
      "Counter({2: 859, 1: 728, 3: 406, 4: 196, 5: 90, 6: 25, 7: 9, 11: 2, 10: 1, 8: 1})\n",
      "4\n",
      "Counter({3: 621, 4: 578, 2: 280, 5: 222, 6: 72, 7: 21, 1: 5, 8: 4, 9: 1})\n",
      "3\n",
      "Counter({2: 402, 3: 337, 4: 257, 1: 204, 5: 140, 6: 94, 7: 63, 8: 46, 9: 26, 10: 22, 11: 11, 12: 9, 13: 3, 14: 2, 16: 1, 15: 1})\n",
      "15\n",
      "Counter({3: 294, 2: 294, 1: 216, 4: 195, 5: 143, 6: 74, 7: 50, 8: 24, 9: 18, 10: 17, 11: 7, 12: 6, 13: 2})\n"
     ]
    }
   ],
   "source": [
    "import builtins\n",
    "from collections import Counter\n",
    "\n",
    "# For each selected class, histogram how many labels its positive images carry.\n",
    "# class_rows: rows of df (e.g. from nlargest/nsmallest); their df index is stored\n",
    "# as 'rank'. Returns a frame indexed by (class_index, rank) whose columns are label\n",
    "# counts and whose cells are image counts (0 where a label count never occurs).\n",
    "def _label_count_distribution(class_rows):\n",
    "  records = []\n",
    "  for class_idx, rank in zip(class_rows.class_index, class_rows.index):\n",
    "    print(class_idx)\n",
    "    positives = gt_labels[np.isin(gt_labels[:, class_idx], [1.0]), :]\n",
    "    labels_per_image = (positives != 0).sum(1)\n",
    "    counts = Counter(labels_per_image)\n",
    "    print(counts)\n",
    "    # builtins.sorted: an earlier cell may rebind the name `sorted`.\n",
    "    # Keys are unique, so this orders the columns by label count.\n",
    "    record = dict(builtins.sorted(counts.items()))\n",
    "    record['class_index'] = class_idx\n",
    "    record['rank'] = rank\n",
    "    records.append(record)\n",
    "  return pd.DataFrame(data=records).fillna(0).set_index(['class_index', 'rank'])\n",
    "\n",
    "N_CLASSES = 8  # number of best / worst classes to export\n",
    "top = df.nlargest(N_CLASSES, 'AP')\n",
    "print(top)\n",
    "df_top5 = _label_count_distribution(top)\n",
    "df_top5.to_csv(\"../figures/{}/label_distribution_top_{}.csv\".format(m_name, m_name))\n",
    "\n",
    "worst = df.nsmallest(N_CLASSES, 'AP')\n",
    "df_worst5 = _label_count_distribution(worst)\n",
    "df_worst5.to_csv(\"../figures/{}/label_distribution_worst_{}.csv\".format(m_name, m_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "81\n",
      "81\n"
     ]
    }
   ],
   "source": [
    "# Each AP file holds one value per line after a first line that is skipped\n",
    "# (81 lines for 80 classes).\n",
    "def _read_ap_file(path):\n",
    "  with open(path, 'r') as f:\n",
    "    lines = f.readlines()\n",
    "  print(len(lines))\n",
    "  return np.array(list(map(float, lines[1:])))\n",
    "\n",
    "baseline = _read_ap_file('./baseline.txt')\n",
    "mixup = _read_ap_file('./mixup.txt')\n",
    "\n",
    "# Reorder both models' APs by ascending baseline AP so rows stay paired per class.\n",
    "ind = np.argsort(baseline)\n",
    "baseline = baseline[ind]\n",
    "mixup = mixup[ind]\n",
    "pd.DataFrame(data={\"baseline\": baseline, \"mixup\": mixup}).to_csv('./sorted_ap_mixup_baseline.csv')"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "765b26ec2fc4c2066cf7ce8a0dde5a8255de29dd3973b3be957926608459ba30"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('MGSSL')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
