{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'voc'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[1;32m/home/seongha/LT-ML/notebooks/EDA.ipynb Cell 1'\u001b[0m in \u001b[0;36m<cell line: 7>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/EDA.ipynb#ch0000000vscode-remote?line=4'>5</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mpandas\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mpd\u001b[39;00m\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/EDA.ipynb#ch0000000vscode-remote?line=5'>6</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mos\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/EDA.ipynb#ch0000000vscode-remote?line=6'>7</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mvoc\u001b[39;00m \u001b[39mimport\u001b[39;00m read_object_labels_csv\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'voc'"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import pickle, gzip\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"4\"\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "from voc import read_object_labels_csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "coco_lt_path = \"data/coco\"\n",
    "voc_lt_path = \"data/voc/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VOC2007"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_csv = '../data/voc/files/VOC2007'\n",
    "trainval_csv = os.path.join(path_csv, 'classification_trainval.csv')\n",
    "test_csv = os.path.join(path_csv, 'classification_test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['name', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',\n",
      "       'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',\n",
      "       'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'],\n",
      "      dtype='object')\n",
      "(5011, 21)\n"
     ]
    }
   ],
   "source": [
    "hi = pd.read_csv(trainval_csv)\n",
    "print(hi.columns)\n",
    "print(hi.shape)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['name', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',\n",
      "       'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',\n",
      "       'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'],\n",
      "      dtype='object')\n",
      "(5011, 21)\n",
      "(5011, 20)\n",
      "[[0 0 0 ... 0 0 0]\n",
      " [0 0 0 ... 0 0 0]\n",
      " [0 0 0 ... 0 0 0]\n",
      " ...\n",
      " [0 1 0 ... 0 0 0]\n",
      " [0 0 0 ... 0 0 0]\n",
      " [0 0 0 ... 0 0 0]]\n"
     ]
    }
   ],
   "source": [
    "hi = pd.read_csv(trainval_csv)\n",
    "print(hi.columns)\n",
    "print(hi.shape)\n",
    "## generate gt pkl file\n",
    "hi=hi.replace(0,1)\n",
    "hi=hi.replace(-1,0)\n",
    "gt_labels = hi.iloc[:,1:].to_numpy()\n",
    "print(gt_labels.shape)\n",
    "img_id2idx = dict()\n",
    "idx2img_id = dict()\n",
    "print(gt_labels)\n",
    "\n",
    "for i, row in hi.iterrows():\n",
    "  img_id2idx[row['name']] = i\n",
    "  idx2img_id[i] = row['name']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = {\n",
    "  \"gt_labels\": gt_labels,\n",
    "  \"img_id2idx\": img_id2idx,\n",
    "  \"idx2img_id\": idx2img_id\n",
    "}\n",
    "with open('test_terse_gt_2007.pkl', 'wb') as f:\n",
    "  pickle.dump(res, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate VOC ADJ PKL file\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "corelation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['name', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',\n",
      "       'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',\n",
      "       'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'],\n",
      "      dtype='object')\n",
      "(5011, 21)\n",
      "colname:  aeroplane\n",
      "(240, 20)\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  bicycle\n",
      "(255, 20)\n",
      "aeroplane\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  bird\n",
      "(333, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  boat\n",
      "(188, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  bottle\n",
      "(262, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  bus\n",
      "(197, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  car\n",
      "(761, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  cat\n",
      "(344, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  chair\n",
      "(572, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  cow\n",
      "(146, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  diningtable\n",
      "(263, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  dog\n",
      "(430, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  horse\n",
      "(294, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  motorbike\n",
      "(249, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  person\n",
      "(2095, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  pottedplant\n",
      "(273, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  sheep\n",
      "(97, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sofa\n",
      "train\n",
      "tvmonitor\n",
      "colname:  sofa\n",
      "(372, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "train\n",
      "tvmonitor\n",
      "colname:  train\n",
      "(263, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "tvmonitor\n",
      "colname:  tvmonitor\n",
      "(279, 20)\n",
      "aeroplane\n",
      "bicycle\n",
      "bird\n",
      "boat\n",
      "bottle\n",
      "bus\n",
      "car\n",
      "cat\n",
      "chair\n",
      "cow\n",
      "diningtable\n",
      "dog\n",
      "horse\n",
      "motorbike\n",
      "person\n",
      "pottedplant\n",
      "sheep\n",
      "sofa\n",
      "train\n",
      "117 7544\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(552, 2)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## count number of correlations\n",
    "hi = pd.read_csv(trainval_csv)\n",
    "print(hi.columns)\n",
    "print(hi.shape)\n",
    "## generate gt pkl file\n",
    "hi=hi.replace(0,1)\n",
    "hi=hi.replace(-1,0)\n",
    "\n",
    "hi=hi.iloc[:,1:]\n",
    "di = dict()\n",
    "adj = []\n",
    "nums = []\n",
    "from collections import defaultdict\n",
    "cnt  = defaultdict(int)\n",
    "for col_name, val in hi.iteritems():\n",
    "  if col_name =='name':\n",
    "    continue\n",
    "  print(\"colname: \",col_name)\n",
    "  df = hi[hi[col_name].isin([1.0])]\n",
    "  print(df.shape)\n",
    "  for col_name_df, val_df in df.iteritems():\n",
    "    if col_name_df == col_name:\n",
    "      continue\n",
    "    print(col_name_df)\n",
    "    if val_df.value_counts().get(1):\n",
    "      k = [col_name, col_name_df]\n",
    "      k.sort()\n",
    "      k = tuple(k)\n",
    "      cnt[k] += val_df.value_counts().get(1)\n",
    "\n",
    "\n",
    "\n",
    "cnt = {k:v for k, v in sorted(cnt.items(), key=lambda item: item[1], reverse=True)}\n",
    "print(len(cnt), sum(cnt.values()) )\n",
    "max(cnt.values()), min(cnt.values())\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v in cnt.items():\n",
    "  \n",
    "  print(\"{},{}\".format(k[0],k[1]), v)\n",
    "\n",
    "voc_joint = pd.DataFrame({\"k0\": [k[0] for k in cnt.keys()],\n",
    "    \"k1\" : [k[1] for k in cnt.keys()],\n",
    "    \"v\" : cnt.values()\n",
    "})\n",
    "\n",
    "voc_joint.to_csv(\"voc_joint.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['name', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',\n",
      "       'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',\n",
      "       'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'],\n",
      "      dtype='object')\n",
      "(5011, 21)\n",
      "284\n",
      "512\n",
      "375\n",
      "275\n",
      "643\n",
      "401\n",
      "1285\n",
      "492\n",
      "1477\n",
      "210\n",
      "766\n",
      "709\n",
      "556\n",
      "488\n",
      "4321\n",
      "648\n",
      "123\n",
      "923\n",
      "340\n",
      "629\n",
      "[240, 255, 333, 188, 262, 197, 761, 344, 572, 146, 263, 430, 294, 249, 2095, 273, 97, 372, 263, 279]\n",
      "2095 97\n",
      "[[ 240    0    0    3    0    1    7    0    1    0    0    0    0    0\n",
      "    30    0    1    0    0    1]\n",
      " [   0  255    0    1    4   13   27    1   11    1    1    1    0    8\n",
      "   174    5    0    5    0    5]\n",
      " [   0    0  333    4    0    0    3    0    2    4    0    1    0    0\n",
      "    26    2    0    0    0    0]\n",
      " [   3    1    4  188    0    2    9    0    2    2    2    3    0    1\n",
      "    54    2    0    1    1    0]\n",
      " [   0    4    0    0  262    1    5    4   66    1   69    5    1    0\n",
      "   161   20    2   18    0   24]\n",
      " [   1   13    0    2    1  197   89    0    0    0    0    1    0    7\n",
      "    89    0    0    0    1    0]\n",
      " [   7   27    3    9    5   89  761    2    6    3    0    7   12   38\n",
      "   276    8    2   10   18    2]\n",
      " [   0    1    0    0    4    0    2  344   34    0    7   10    0    0\n",
      "    33   18    0   27    0   12]\n",
      " [   1   11    2    2   66    0    6   34  572    0  183   32    1    2\n",
      "   250   84    0  132    0   99]\n",
      " [   0    1    4    2    1    0    3    0    0  146    0    1    7    2\n",
      "    41    0    2    0    0    0]\n",
      " [   0    1    0    2   69    0    0    7  183    0  263    4    0    0\n",
      "   146   43    0   30    0   18]\n",
      " [   0    1    1    3    5    1    7   10   32    1    4  430    8    2\n",
      "   131   11    5   47    0   10]\n",
      " [   0    0    0    0    1    0   12    0    1    7    0    8  294    0\n",
      "   227    4    1    0    1    0]\n",
      " [   0    8    0    1    0    7   38    0    2    2    0    2    0  249\n",
      "   175    3    0    1    0    0]\n",
      " [  30  174   26   54  161   89  276   33  250   41  146  131  227  175\n",
      "  2095   92   13  162   55   91]\n",
      " [   0    5    2    2   20    0    8   18   84    0   43   11    4    3\n",
      "    92  273    0   56    1   26]\n",
      " [   1    0    0    0    2    0    2    0    0    2    0    5    1    0\n",
      "    13    0   97    0    0    0]\n",
      " [   0    5    0    1   18    0   10   27  132    0   30   47    0    1\n",
      "   162   56    0  372    0   62]\n",
      " [   0    0    0    1    0    1   18    0    0    0    0    0    1    0\n",
      "    55    1    0    0  263    0]\n",
      " [   1    5    0    0   24    0    2   12   99    0   18   10    0    0\n",
      "    91   26    0   62    0  279]]\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "Object of type ndarray is not JSON serializable",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m/home/seongha/LT-ML/notebooks/EDA.ipynb Cell 13'\u001b[0m in \u001b[0;36m<cell line: 42>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/EDA.ipynb#ch0000012vscode-remote?line=40'>41</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mjson\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/EDA.ipynb#ch0000012vscode-remote?line=41'>42</a>\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mvoc_class_num.json\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mw\u001b[39m\u001b[39m'\u001b[39m) \u001b[39mas\u001b[39;00m f:\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/EDA.ipynb#ch0000012vscode-remote?line=42'>43</a>\u001b[0m   json\u001b[39m.\u001b[39;49mdump(di, f)\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/json/__init__.py:179\u001b[0m, in \u001b[0;36mdump\u001b[0;34m(obj, fp, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)\u001b[0m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/__init__.py?line=172'>173</a>\u001b[0m     iterable \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39m(skipkeys\u001b[39m=\u001b[39mskipkeys, ensure_ascii\u001b[39m=\u001b[39mensure_ascii,\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/__init__.py?line=173'>174</a>\u001b[0m         check_circular\u001b[39m=\u001b[39mcheck_circular, allow_nan\u001b[39m=\u001b[39mallow_nan, indent\u001b[39m=\u001b[39mindent,\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/__init__.py?line=174'>175</a>\u001b[0m         separators\u001b[39m=\u001b[39mseparators,\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/__init__.py?line=175'>176</a>\u001b[0m         default\u001b[39m=\u001b[39mdefault, sort_keys\u001b[39m=\u001b[39msort_keys, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkw)\u001b[39m.\u001b[39miterencode(obj)\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/__init__.py?line=176'>177</a>\u001b[0m \u001b[39m# could accelerate with writelines in some versions of Python, at\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/__init__.py?line=177'>178</a>\u001b[0m \u001b[39m# a debuggability cost\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/__init__.py?line=178'>179</a>\u001b[0m \u001b[39mfor\u001b[39;00m chunk \u001b[39min\u001b[39;00m iterable:\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/__init__.py?line=179'>180</a>\u001b[0m     fp\u001b[39m.\u001b[39mwrite(chunk)\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py:431\u001b[0m, in \u001b[0;36m_make_iterencode.<locals>._iterencode\u001b[0;34m(o, _current_indent_level)\u001b[0m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=428'>429</a>\u001b[0m     \u001b[39myield from\u001b[39;00m _iterencode_list(o, _current_indent_level)\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=429'>430</a>\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(o, \u001b[39mdict\u001b[39m):\n\u001b[0;32m--> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=430'>431</a>\u001b[0m     \u001b[39myield from\u001b[39;00m _iterencode_dict(o, _current_indent_level)\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=431'>432</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=432'>433</a>\u001b[0m     \u001b[39mif\u001b[39;00m markers \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py:405\u001b[0m, in \u001b[0;36m_make_iterencode.<locals>._iterencode_dict\u001b[0;34m(dct, _current_indent_level)\u001b[0m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=402'>403</a>\u001b[0m         \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=403'>404</a>\u001b[0m             chunks \u001b[39m=\u001b[39m _iterencode(value, _current_indent_level)\n\u001b[0;32m--> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=404'>405</a>\u001b[0m         \u001b[39myield from\u001b[39;00m chunks\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=405'>406</a>\u001b[0m \u001b[39mif\u001b[39;00m newline_indent \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=406'>407</a>\u001b[0m     _current_indent_level \u001b[39m-\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py:438\u001b[0m, in \u001b[0;36m_make_iterencode.<locals>._iterencode\u001b[0;34m(o, _current_indent_level)\u001b[0m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=435'>436</a>\u001b[0m         \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mCircular reference detected\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=436'>437</a>\u001b[0m     markers[markerid] \u001b[39m=\u001b[39m o\n\u001b[0;32m--> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=437'>438</a>\u001b[0m o \u001b[39m=\u001b[39m _default(o)\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=438'>439</a>\u001b[0m \u001b[39myield from\u001b[39;00m _iterencode(o, _current_indent_level)\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=439'>440</a>\u001b[0m \u001b[39mif\u001b[39;00m markers \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py:179\u001b[0m, in \u001b[0;36mJSONEncoder.default\u001b[0;34m(self, o)\u001b[0m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=159'>160</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdefault\u001b[39m(\u001b[39mself\u001b[39m, o):\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=160'>161</a>\u001b[0m     \u001b[39m\"\"\"Implement this method in a subclass such that it returns\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=161'>162</a>\u001b[0m \u001b[39m    a serializable object for ``o``, or calls the base implementation\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=162'>163</a>\u001b[0m \u001b[39m    (to raise a ``TypeError``).\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=176'>177</a>\u001b[0m \n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=177'>178</a>\u001b[0m \u001b[39m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=178'>179</a>\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mObject of type \u001b[39m\u001b[39m{\u001b[39;00mo\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/json/encoder.py?line=179'>180</a>\u001b[0m                     \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mis not JSON 
serializable\u001b[39m\u001b[39m'\u001b[39m)\n",
      "\u001b[0;31mTypeError\u001b[0m: Object of type ndarray is not JSON serializable"
     ]
    }
   ],
   "source": [
    "hi = pd.read_csv(trainval_csv)\n",
    "print(hi.columns)\n",
    "print(hi.shape)\n",
    "## generate gt pkl file\n",
    "hi=hi.replace(0,1)\n",
    "hi=hi.replace(-1,0)\n",
    "\n",
    "hi=hi.iloc[:,1:]\n",
    "di = dict()\n",
    "adj = []\n",
    "nums = []\n",
    "from collections import Counter\n",
    "cnt  = Counter()\n",
    "for col_name, val in hi.iteritems():\n",
    "  # print(col_name, hi[hi[col_name].isin([1.0])].count().to_)\n",
    "  if col_name =='name':\n",
    "    continue\n",
    "  # di[col_name] = val.value_counts().to_dict()\n",
    "  M = hi[hi[col_name].isin([1.0])].sum(axis=0)\n",
    "  # print(M)\n",
    "  nums.append(M[col_name])\n",
    "  # M[col_name] = 0\n",
    "  # print(cond_prob[col_name])\n",
    "  # print(M)\n",
    "  print(sum(M.to_list()))\n",
    "  adj.append(M.to_list())\n",
    "  \n",
    "  # print(hi[hi[col_name].isin([1.0])].sum(axis=0))\n",
    "# nums.sort()\n",
    "print(nums)\n",
    "print(max(nums), min(nums))\n",
    "\n",
    "adj = np.asarray(adj)\n",
    "nums = np.asarray(nums)\n",
    "\n",
    "\n",
    "di['adj'] = adj\n",
    "di['nums'] = nums\n",
    "print(di['adj'])\n",
    "\n",
    "import json\n",
    "with open(\"voc_class_num.json\", 'w') as f:\n",
    "  json.dump(di, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open('voc_lt_adj.pkl', 'wb') as f:\n",
    "  pickle.dump(di, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['name', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',\n",
      "       'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',\n",
      "       'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'],\n",
      "      dtype='object')\n",
      "(5011, 21)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "defaultdict(int, {1: 2808, 2: 1644, 3: 443, 4: 98, 5: 13, 6: 4, 7: 1})"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hi = pd.read_csv(trainval_csv)\n",
    "print(hi.columns)\n",
    "print(hi.shape)\n",
    "## generate gt pkl file\n",
    "hi=hi.replace(0,1)\n",
    "hi=hi.replace(-1,0)\n",
    "\n",
    "hi=hi.iloc[:,1:]\n",
    "di = dict()\n",
    "adj = []\n",
    "nums = []\n",
    "from collections import defaultdict\n",
    "cnt  = defaultdict(int)\n",
    "\n",
    "\n",
    "Counter(hi.astype(bool).sum(axis=1).to_dict())\n",
    "item_lab_cnt = hi.astype(bool).sum(axis=1).to_dict()\n",
    "for k, v in item_lab_cnt.items():\n",
    "  cnt[v] += 1\n",
    "\n",
    "cnt\n",
    "\n",
    "pd.DataFrame({\n",
    "  \n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 1. 1. 0. 1. 0. 1.]\n",
      " [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 1. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 1. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0.]\n",
      " [0. 0. 1. 0. 1. 0. 0. 1. 0. 0. 1. 1. 0. 0. 0. 1. 0. 1. 0. 1.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 0. 1. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 1. 0. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1.]\n",
      " [0. 0. 0. 1. 1. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 0. 0. 1. 1. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 1. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 1. 0. 1. 0. 0.]]\n",
      "[[0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.\n",
      "  0.    0.    0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.    0.    0.    0.    0.05  0.    0.    0.    0.    0.    0.\n",
      "  0.    0.04  0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.    0.    0.    0.    0.    0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.    0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.    0.    0.    0.    0.    0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.04  0.017 0.02  0.    0.    0.    0.   ]\n",
      " [0.    0.    0.    0.    0.    0.    0.    0.    0.025 0.067 0.029 0.\n",
      "  0.    0.    0.017 0.02  0.    0.025 0.    0.029]\n",
      " [0.    0.05  0.    0.    0.    0.    0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.04  0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.05  0.067 0.05  0.    0.05  0.    0.    0.    0.    0.    0.\n",
      "  0.    0.04  0.    0.    0.1   0.    0.067 0.   ]\n",
      " [0.    0.    0.    0.    0.    0.    0.    0.    0.025 0.    0.029 0.\n",
      "  0.    0.    0.    0.02  0.    0.025 0.    0.   ]\n",
      " [0.    0.    0.067 0.    0.029 0.    0.    0.033 0.    0.    0.029 0.033\n",
      "  0.    0.    0.    0.02  0.    0.025 0.    0.029]\n",
      " [0.    0.    0.    0.    0.029 0.    0.    0.    0.    0.    0.    0.\n",
      "  0.067 0.    0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.    0.    0.    0.029 0.    0.    0.033 0.025 0.    0.    0.\n",
      "  0.    0.    0.    0.02  0.    0.025 0.    0.029]\n",
      " [0.    0.    0.    0.    0.    0.    0.    0.    0.025 0.    0.    0.\n",
      "  0.067 0.    0.017 0.02  0.    0.025 0.    0.029]\n",
      " [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.067 0.    0.033\n",
      "  0.    0.    0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.05  0.    0.05  0.    0.05  0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.    0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.2   0.05  0.067 0.05  0.029 0.05  0.029 0.033 0.025 0.067 0.029 0.033\n",
      "  0.067 0.04  0.    0.02  0.1   0.025 0.067 0.029]\n",
      " [0.    0.    0.    0.05  0.029 0.    0.    0.033 0.025 0.    0.029 0.033\n",
      "  0.    0.    0.    0.    0.    0.025 0.067 0.029]\n",
      " [0.    0.    0.    0.    0.    0.    0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.    0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.    0.    0.    0.029 0.    0.    0.033 0.025 0.    0.029 0.033\n",
      "  0.    0.    0.    0.02  0.    0.    0.    0.029]\n",
      " [0.    0.    0.    0.    0.    0.    0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.    0.017 0.02  0.    0.    0.    0.   ]\n",
      " [0.    0.    0.    0.    0.029 0.    0.    0.033 0.025 0.    0.029 0.033\n",
      "  0.    0.    0.    0.02  0.    0.025 0.    0.   ]]\n",
      "[[0.8   0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.\n",
      "  0.    0.    0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.8   0.    0.    0.    0.05  0.    0.    0.    0.    0.    0.\n",
      "  0.    0.04  0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.    0.8   0.    0.    0.    0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.    0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.    0.    0.8   0.    0.    0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.04  0.017 0.02  0.    0.    0.    0.   ]\n",
      " [0.    0.    0.    0.    0.8   0.    0.    0.    0.025 0.067 0.029 0.\n",
      "  0.    0.    0.017 0.02  0.    0.025 0.    0.029]\n",
      " [0.    0.05  0.    0.    0.    0.8   0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.04  0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.05  0.067 0.05  0.    0.05  0.8   0.    0.    0.    0.    0.\n",
      "  0.    0.04  0.    0.    0.1   0.    0.067 0.   ]\n",
      " [0.    0.    0.    0.    0.    0.    0.    0.8   0.025 0.    0.029 0.\n",
      "  0.    0.    0.    0.02  0.    0.025 0.    0.   ]\n",
      " [0.    0.    0.067 0.    0.029 0.    0.    0.033 0.8   0.    0.029 0.033\n",
      "  0.    0.    0.    0.02  0.    0.025 0.    0.029]\n",
      " [0.    0.    0.    0.    0.029 0.    0.    0.    0.    0.8   0.    0.\n",
      "  0.067 0.    0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.    0.    0.    0.029 0.    0.    0.033 0.025 0.    0.8   0.\n",
      "  0.    0.    0.    0.02  0.    0.025 0.    0.029]\n",
      " [0.    0.    0.    0.    0.    0.    0.    0.    0.025 0.    0.    0.8\n",
      "  0.067 0.    0.017 0.02  0.    0.025 0.    0.029]\n",
      " [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.067 0.    0.033\n",
      "  0.8   0.    0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.    0.05  0.    0.05  0.    0.05  0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.8   0.017 0.    0.    0.    0.    0.   ]\n",
      " [0.2   0.05  0.067 0.05  0.029 0.05  0.029 0.033 0.025 0.067 0.029 0.033\n",
      "  0.067 0.04  0.8   0.02  0.1   0.025 0.067 0.029]\n",
      " [0.    0.    0.    0.05  0.029 0.    0.    0.033 0.025 0.    0.029 0.033\n",
      "  0.    0.    0.    0.8   0.    0.025 0.067 0.029]\n",
      " [0.    0.    0.    0.    0.    0.    0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.    0.017 0.    0.8   0.    0.    0.   ]\n",
      " [0.    0.    0.    0.    0.029 0.    0.    0.033 0.025 0.    0.029 0.033\n",
      "  0.    0.    0.    0.02  0.    0.8   0.    0.029]\n",
      " [0.    0.    0.    0.    0.    0.    0.029 0.    0.    0.    0.    0.\n",
      "  0.    0.    0.017 0.02  0.    0.    0.8   0.   ]\n",
      " [0.    0.    0.    0.    0.029 0.    0.    0.033 0.025 0.    0.029 0.033\n",
      "  0.    0.    0.    0.02  0.    0.025 0.    0.8  ]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_40963/1120479978.py:18: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  P = P + np.identity(len(nums), np.int) * 0.8\n"
     ]
    }
   ],
   "source": [
    "##check\n",
    "# Build a re-weighted correlation matrix from the co-occurrence counts in `adj`.\n",
    "# Row i of adj is divided by nums[i].\n",
    "nums2 = nums[:, np.newaxis]\n",
    "P = np.divide(adj, nums2)\n",
    "\n",
    "# Column-wise mean of P, used as the binarisation threshold of each column.\n",
    "t = np.mean(P, axis=0)\n",
    "np.set_printoptions(3)\n",
    "P[P < t] = 0 #balance\n",
    "P[P >= t] = 1\n",
    "print(P)\n",
    "# Rescale every column so its entries sum to (just under) 0.2; 1e-6 avoids 0/0 on empty columns.\n",
    "P = P * 0.2 / (P.sum(0, keepdims=True) + 1e-6) #reweight\n",
    "print(P)\n",
    "# Add 0.8 on the diagonal. `np.int` was deprecated in NumPy 1.20 and removed in 1.24,\n",
    "# so the builtin int is used as dtype instead.\n",
    "P = P + np.identity(len(nums), dtype=int) * 0.8\n",
    "print(P)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1366\n",
      "9963\n",
      "0       000001\n",
      "2       000003\n",
      "3       000004\n",
      "4       000006\n",
      "5       000008\n",
      "         ...  \n",
      "4941    009941\n",
      "4944    009951\n",
      "4948    009957\n",
      "4950    009962\n",
      "4951    009963\n",
      "Name: name, Length: 1366, dtype: object\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_40963/2512746191.py:14: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  LT['name'] = LT['name'].apply(lambda x:format(x, '06d'))\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Read the image ids of the long-tail test split: one integer per line, blank lines skipped.\n",
    "with open(voc_lt_path + \"/testlongtail2007/img_id.txt\", \"r\") as f:\n",
    "  data = [int(each) for each in f if each.strip()]\n",
    "print(len(data))\n",
    "print(max(data))\n",
    "\n",
    "# .copy() makes LT an independent frame, so the column assignment below writes to LT itself\n",
    "# instead of triggering pandas' SettingWithCopyWarning on a slice of `hi`.\n",
    "LT = hi.loc[hi['name'].isin(data)].copy()\n",
    "\n",
    "# Zero-pad the numeric ids to six digits.\n",
    "LT['name'] = LT['name'].apply(lambda x:format(x, '06d'))\n",
    "print(LT['name'])\n",
    "\n",
    "LT_path = os.path.join(path_csv, 'classification_test_' + \"LT.csv\")\n",
    "LT.to_csv(LT_path, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# COCO2014"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base directory of the COCO files; the cells below append '/data/<file>.json' to it.\n",
    "# Note: this rebinds path_csv, which the VOC cell above uses for its csv output.\n",
    "path_csv = '../data/coco/'\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['airplane', 'apple', 'backpack', 'banana', 'baseball bat', 'baseball glove', 'bear', 'bed', 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', 'bowl', 'broccoli', 'bus', 'cake', 'car', 'carrot', 'cat', 'cell phone', 'chair', 'clock', 'couch', 'cow', 'cup', 'dining table', 'dog', 'donut', 'elephant', 'fire hydrant', 'fork', 'frisbee', 'giraffe', 'hair drier', 'handbag', 'horse', 'hot dog', 'keyboard', 'kite', 'knife', 'laptop', 'microwave', 'motorcycle', 'mouse', 'orange', 'oven', 'parking meter', 'person', 'pizza', 'potted plant', 'refrigerator', 'remote', 'sandwich', 'scissors', 'sheep', 'sink', 'skateboard', 'skis', 'snowboard', 'spoon', 'sports ball', 'stop sign', 'suitcase', 'surfboard', 'teddy bear', 'tennis racket', 'tie', 'toaster', 'toilet', 'toothbrush', 'traffic light', 'train', 'truck', 'tv', 'umbrella', 'vase', 'wine glass', 'zebra'])\n",
      "80\n",
      "dict_values([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79])\n",
      "82081 {'file_name': 'COCO_train2014_000000057870.jpg', 'labels': [12, 77, 51, 22, 27]}\n",
      "79\n",
      "RangeIndex(start=0, stop=80, step=1)\n",
      "0        0.0\n",
      "1        0.0\n",
      "2        0.0\n",
      "3        0.0\n",
      "4        0.0\n",
      "        ... \n",
      "82076    0.0\n",
      "82077    0.0\n",
      "82078    0.0\n",
      "82079    0.0\n",
      "82080    0.0\n",
      "Name: 0, Length: 82081, dtype: float64\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "Length of values (2) does not match length of index (82081)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[1;32m/home/seongha/LT-ML/EDA.ipynb Cell 18'\u001b[0m in \u001b[0;36m<cell line: 10>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/EDA.ipynb#ch0000017vscode-remote?line=30'>31</a>\u001b[0m   \u001b[39mif\u001b[39;00m col_name \u001b[39m==\u001b[39m \u001b[39m'\u001b[39m\u001b[39mname\u001b[39m\u001b[39m'\u001b[39m: \u001b[39mcontinue\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/EDA.ipynb#ch0000017vscode-remote?line=31'>32</a>\u001b[0m   \u001b[39mprint\u001b[39m(val)\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/EDA.ipynb#ch0000017vscode-remote?line=32'>33</a>\u001b[0m   df[col_name] \u001b[39m=\u001b[39m val\u001b[39m.\u001b[39mvalue_counts()\u001b[39m.\u001b[39mto_dict()\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/EDA.ipynb#ch0000017vscode-remote?line=34'>35</a>\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mcoco_class_num\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mw\u001b[39m\u001b[39m'\u001b[39m) \u001b[39mas\u001b[39;00m f:\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/EDA.ipynb#ch0000017vscode-remote?line=35'>36</a>\u001b[0m   json\u001b[39m.\u001b[39mdump(di, f)\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py:3655\u001b[0m, in \u001b[0;36mDataFrame.__setitem__\u001b[0;34m(self, key, value)\u001b[0m\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3651'>3652</a>\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_setitem_array([key], value)\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3652'>3653</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3653'>3654</a>\u001b[0m     \u001b[39m# set column\u001b[39;00m\n\u001b[0;32m-> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3654'>3655</a>\u001b[0m     \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_set_item(key, value)\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py:3832\u001b[0m, in \u001b[0;36mDataFrame._set_item\u001b[0;34m(self, key, value)\u001b[0m\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3821'>3822</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_set_item\u001b[39m(\u001b[39mself\u001b[39m, key, value) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3822'>3823</a>\u001b[0m     \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3823'>3824</a>\u001b[0m \u001b[39m    Add series to DataFrame in specified column.\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3824'>3825</a>\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3829'>3830</a>\u001b[0m \u001b[39m    ensure homogeneity.\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3830'>3831</a>\u001b[0m \u001b[39m    \"\"\"\u001b[39;00m\n\u001b[0;32m-> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3831'>3832</a>\u001b[0m     value \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_sanitize_column(value)\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3833'>3834</a>\u001b[0m     \u001b[39mif\u001b[39;00m (\n\u001b[1;32m   <a 
href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3834'>3835</a>\u001b[0m         key \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcolumns\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3835'>3836</a>\u001b[0m         \u001b[39mand\u001b[39;00m value\u001b[39m.\u001b[39mndim \u001b[39m==\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3836'>3837</a>\u001b[0m         \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m is_extension_array_dtype(value)\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3837'>3838</a>\u001b[0m     ):\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3838'>3839</a>\u001b[0m         \u001b[39m# broadcast across multiple columns if necessary\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=3839'>3840</a>\u001b[0m         \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcolumns\u001b[39m.\u001b[39mis_unique \u001b[39mor\u001b[39;00m \u001b[39misinstance\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcolumns, MultiIndex):\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py:4535\u001b[0m, in \u001b[0;36mDataFrame._sanitize_column\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=4531'>4532</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m _reindex_for_setitem(value, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindex)\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=4533'>4534</a>\u001b[0m \u001b[39mif\u001b[39;00m is_list_like(value):\n\u001b[0;32m-> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=4534'>4535</a>\u001b[0m     com\u001b[39m.\u001b[39;49mrequire_length_match(value, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindex)\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/frame.py?line=4535'>4536</a>\u001b[0m \u001b[39mreturn\u001b[39;00m sanitize_array(value, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindex, copy\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, allow_2d\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py:557\u001b[0m, in \u001b[0;36mrequire_length_match\u001b[0;34m(data, index)\u001b[0m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py?line=552'>553</a>\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py?line=553'>554</a>\u001b[0m \u001b[39mCheck the length of data matches the length of the index.\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py?line=554'>555</a>\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py?line=555'>556</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(data) \u001b[39m!=\u001b[39m \u001b[39mlen\u001b[39m(index):\n\u001b[0;32m--> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py?line=556'>557</a>\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py?line=557'>558</a>\u001b[0m         \u001b[39m\"\u001b[39m\u001b[39mLength of values \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py?line=558'>559</a>\u001b[0m         \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m(\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mlen\u001b[39m(data)\u001b[39m}\u001b[39;00m\u001b[39m) \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py?line=559'>560</a>\u001b[0m         \u001b[39m\"\u001b[39m\u001b[39mdoes not match 
length of index \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py?line=560'>561</a>\u001b[0m         \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m(\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mlen\u001b[39m(index)\u001b[39m}\u001b[39;00m\u001b[39m)\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/pandas/core/common.py?line=561'>562</a>\u001b[0m     )\n",
      "\u001b[0;31mValueError\u001b[0m: Length of values (2) does not match length of index (82081)"
     ]
    }
   ],
   "source": [
    "\n",
    "# Category name -> class index mapping.\n",
    "with open(path_csv+ '/data/category.json') as f:\n",
    "  category = json.load(f)\n",
    "print(category.keys())\n",
    "print(len(category.items()))\n",
    "print(category.values())\n",
    "\n",
    "with open(path_csv + '/data/train_anno.json') as f:\n",
    "  train = json.load(f)\n",
    "print(len(train), train[0])\n",
    "\n",
    "# Multi-hot training labels (one row per image, 80 columns) plus image-id <-> row-index lookups.\n",
    "li = []\n",
    "gt_labels = np.zeros((len(train),80))\n",
    "img_id2idx = dict()\n",
    "idx2img_id = []\n",
    "for i,each in enumerate(train):\n",
    "  li += each['labels']\n",
    "  gt_labels[i, each['labels']] = 1\n",
    "  # The numeric id is the last '_'-separated token of the file name without its extension.\n",
    "  img_id = int(each['file_name'].split('.')[0].split('_')[-1])\n",
    "  idx2img_id.append(img_id)\n",
    "  img_id2idx[img_id] = i\n",
    "print(max(li))\n",
    "\n",
    "df = pd.DataFrame(gt_labels)\n",
    "print(df.columns)\n",
    "\n",
    "# Per-class value counts ({0.0: n_absent, 1.0: n_present}).\n",
    "# The original assigned the dict back into df[col_name], which raised\n",
    "# 'Length of values (2) does not match length of index' and left `di` empty.\n",
    "# DataFrame.items() replaces iteritems(), which was removed in pandas 2.0.\n",
    "di = dict()\n",
    "for col_name, col in df.items():\n",
    "  di[int(col_name)] = {float(k): int(v) for k, v in col.value_counts().items()}\n",
    "\n",
    "with open('coco_class_num', 'w') as f:\n",
    "  json.dump(di, f)\n",
    "\n",
    "with open(path_csv + '/data/val_anno.json') as f:\n",
    "  val = json.load(f)\n",
    "print(len(val), val[0])\n",
    "\n",
    "# Multi-hot labels for the validation annotations, same layout as gt_labels.\n",
    "test_gt_labels = np.zeros((len(val),80))\n",
    "for i,each in enumerate(val):\n",
    "  test_gt_labels[i, each['labels']] = 1\n",
    "print(test_gt_labels)\n",
    "\n",
    "res = {\n",
    "  \"gt_labels\": gt_labels,\n",
    "  \"img_id2idx\": img_id2idx,\n",
    "  \"idx2img_id\": idx2img_id,\n",
    "  \"test_gt_labels\": test_gt_labels\n",
    "}\n",
    "\n",
    "with open('terse_gt_2014.pkl', 'wb') as f:\n",
    "  pickle.dump(res, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Explore COCO-MLT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "82081 {'file_name': 'COCO_train2014_000000057870.jpg', 'labels': [12, 77, 51, 22, 27]}\n",
      "[2243.0, 1171.0, 3924.0, 1618.0, 1804.0, 1884.0, 668.0, 2539.0, 3844.0, 2287.0, 2241.0, 2098.0, 3734.0, 5968.0, 5028.0, 1340.0, 2791.0, 2080.0, 8606.0, 1186.0, 2818.0, 3322.0, 8950.0, 3159.0, 3170.0, 1389.0, 6518.0, 8378.0, 3041.0, 1062.0, 1518.0, 1205.0, 2537.0, 1511.0, 1798.0, 128.0, 4861.0, 2068.0, 821.0, 1471.0, 1625.0, 3097.0, 2475.0, 1089.0, 2442.0, 1290.0, 1216.0, 2003.0, 481.0, 45174.0, 2202.0, 3084.0, 1671.0, 2180.0, 1645.0, 673.0, 1105.0, 3291.0, 2511.0, 2209.0, 1170.0, 2493.0, 2986.0, 1214.0, 1631.0, 2343.0, 1510.0, 2368.0, 2667.0, 151.0, 2317.0, 700.0, 2893.0, 2464.0, 4321.0, 3191.0, 2749.0, 2530.0, 1771.0, 1324.0]\n",
      "45174.0 128.0\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "with open(path_csv + '/data/train_anno.json') as f:\n",
    "  train = json.load(f)\n",
    "print(len(train), train[0])\n",
    "\n",
    "# Multi-hot label matrix (one row per image, 80 columns) plus image-id <-> row-index lookups.\n",
    "li = []\n",
    "gt_labels = np.zeros((len(train),80))\n",
    "img_id2idx = dict()\n",
    "idx2img_id = []\n",
    "for i,each in enumerate(train):\n",
    "  li += each['labels']\n",
    "  gt_labels[i, each['labels']] = 1\n",
    "  img_id = int(each['file_name'].split('.')[0].split('_')[-1])\n",
    "  idx2img_id.append(img_id)\n",
    "  img_id2idx[img_id] = i\n",
    "\n",
    "# Number of images per class, indexed by class id.\n",
    "nums = gt_labels.sum(axis=0)\n",
    "\n",
    "# adj[i][j] = number of images labelled with both class i and class j; the diagonal is zeroed.\n",
    "adj = []\n",
    "for i,col in enumerate(gt_labels.T):\n",
    "  cond_prob = gt_labels[col == 1].sum(axis=0)\n",
    "  cond_prob[i] = 0\n",
    "  adj.append(cond_prob)\n",
    "\n",
    "print(list(nums))\n",
    "# Take max/min without sorting: the original nums.sort() reordered the counts in place, so the\n",
    "# 'nums' stored in `di` no longer lined up with the class ids / rows of 'adj'.\n",
    "print(nums.max(), nums.min())\n",
    "di={'adj': np.asarray(adj), \"nums\": np.asarray(nums)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "82081 {'file_name': 'COCO_train2014_000000057870.jpg', 'labels': [12, 77, 51, 22, 27]}\n"
     ]
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "import json\n",
    "\n",
    "# Split the training annotations by how many labels each image carries and write one json per group.\n",
    "cnt = defaultdict(int)\n",
    "label_1 = list()\n",
    "label_2 = list()\n",
    "label_3andmore = list()\n",
    "\n",
    "with open(path_csv + '/data/train_anno.json') as f:\n",
    "  train = json.load(f)\n",
    "print(len(train), train[0])\n",
    "\n",
    "# The original loop also appended to `li`, rewrote `gt_labels` and reset `adj` to zeros. None of\n",
    "# that is needed for the split; it failed on a fresh kernel (`li` undefined here) and clobbered\n",
    "# the adjacency computed by another cell, so it is removed.\n",
    "for each in train:\n",
    "  n_labels = len(each['labels'])\n",
    "  cnt[n_labels] += 1\n",
    "  if n_labels == 1:\n",
    "    label_1.append(each)\n",
    "  elif n_labels == 2:\n",
    "    label_2.append(each)\n",
    "  else:\n",
    "    # NOTE(review): an entry with zero labels would also land in this bucket -- confirm none exist.\n",
    "    label_3andmore.append(each)\n",
    "\n",
    "for name, j in zip(['train_anno_1', 'train_anno_2', 'train_anno_3andmore'], [label_1, label_2, label_3andmore]):\n",
    "  with open(path_csv + '/data/'+'{}.json'.format(name), 'w') as f:\n",
    "    json.dump(j, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Persist `di` as a pickle and the per-class positive counts as a .npy file.\n",
    "with open('coco_lt_adj.pkl', 'wb') as f:\n",
    "  pickle.dump(di, f)\n",
    "np.save('coco_class_num.npy', gt_labels.sum(axis=0))\n",
    "\n",
    "# json cannot encode NumPy arrays or scalars; `default` converts them through .tolist(), so\n",
    "# dumping a `di` that holds ndarrays no longer raises TypeError and leaves a truncated file.\n",
    "with open('coco_class_num', 'w') as f:\n",
    "  json.dump(di, f, default=lambda obj: obj.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "82081 {'file_name': 'COCO_train2014_000000057870.jpg', 'labels': [12, 77, 51, 22, 27]}\n",
      "1861\n"
     ]
    }
   ],
   "source": [
    "# Read the selected image ids: the file is opened in binary mode and parsed line by line,\n",
    "# one integer per line (despite the .pkl suffix it is not unpickled).\n",
    "with open(\"data/coco/img_id.pkl\", \"rb\") as f:\n",
    "  data = [int(each) for each in f.readlines()]\n",
    "# A set gives O(1) membership tests; the original scanned the list once per training annotation.\n",
    "lt_img_ids = set(data)\n",
    "\n",
    "with open('data/coco/data/train_anno.json') as f:\n",
    "  train = json.load(f)\n",
    "print(len(train), train[0])\n",
    "\n",
    "# Keep only the annotations whose numeric image id is in the selected set.\n",
    "train_anno = []\n",
    "for each in train:\n",
    "  img_id = int(each['file_name'].split('.')[0].split('_')[-1])\n",
    "  if img_id in lt_img_ids:\n",
    "    train_anno.append(each)\n",
    "print(len(train_anno))\n",
    "\n",
    "with open(\"data/coco/data/train_anno_LT.json\", \"w\") as f:\n",
    "  json.dump(train_anno, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "765b26ec2fc4c2066cf7ce8a0dde5a8255de29dd3973b3be957926608459ba30"
  },
  "kernelspec": {
   "display_name": "Python 3.9.13 ('MGSSL')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
