{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "db347b70",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T18:17:58.829900Z",
     "iopub.status.busy": "2025-05-13T18:17:58.829412Z",
     "iopub.status.idle": "2025-05-13T18:18:05.165912Z",
     "shell.execute_reply": "2025-05-13T18:18:05.164143Z"
    },
    "papermill": {
     "duration": 6.345673,
     "end_time": "2025-05-13T18:18:05.168650",
     "exception": false,
     "start_time": "2025-05-13T18:17:58.822977",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import math\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bc48ef53",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T18:18:05.177100Z",
     "iopub.status.busy": "2025-05-13T18:18:05.176538Z",
     "iopub.status.idle": "2025-05-13T18:18:05.182359Z",
     "shell.execute_reply": "2025-05-13T18:18:05.180906Z"
    },
    "papermill": {
     "duration": 0.01269,
     "end_time": "2025-05-13T18:18:05.184690",
     "exception": false,
     "start_time": "2025-05-13T18:18:05.172000",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "output_path = '/kaggle/input/imagenet-model-outputs'\n",
    "noised_output_path = '/kaggle/input/std0-3-noised-model-outputs'\n",
    "pd.set_option(\"display.precision\", 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fe91ed41",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T18:18:05.192673Z",
     "iopub.status.busy": "2025-05-13T18:18:05.192310Z",
     "iopub.status.idle": "2025-05-13T18:18:05.213404Z",
     "shell.execute_reply": "2025-05-13T18:18:05.211443Z"
    },
    "papermill": {
     "duration": 0.028578,
     "end_time": "2025-05-13T18:18:05.216477",
     "exception": false,
     "start_time": "2025-05-13T18:18:05.187899",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def compute_L1_loss(train_probs, train_img_names, train_labels, num_classes):\n",
    "    train_losses = {}\n",
    "    \n",
    "    for i in tqdm(range(len(train_labels))):\n",
    "        prob = train_probs[i].float()\n",
    "        label = train_labels[i]\n",
    "        one_hot = F.one_hot(torch.tensor(label), num_classes)\n",
    "        loss = torch.sum(abs(prob - one_hot))\n",
    "        train_losses[train_img_names[i]] = loss.item()\n",
    "    return train_losses\n",
    "\n",
    "def compute_01_loss(train_probs, train_img_names, train_labels, num_classes):\n",
    "    train_losses = {}\n",
    "    \n",
    "    for i in tqdm(range(len(train_labels))):\n",
    "        prob = train_probs[i].float()\n",
    "        label = train_labels[i]\n",
    "        train_losses[train_img_names[i]] = 0 if label == prob.argmax() else 1\n",
    "    return train_losses\n",
    "\n",
    "def get_losses(path, model_name):\n",
    "    num_classes = 1000\n",
    "    with open('/kaggle/input/imagenet-clusters/class_list.json', 'r') as f:\n",
    "        class_list: list = json.load(f)\n",
    "    class_to_id = {class_name: i for i, class_name in enumerate(class_list)}\n",
    "\n",
    "    try:\n",
    "        train_outputs_0 = torch.load(f'{path}/{model_name}/{model_name}_train_0.pth')\n",
    "        train_outputs_1 = torch.load(f'{path}/{model_name}/{model_name}_train_1.pth')\n",
    "        train_probs = torch.cat((train_outputs_0['probs'], train_outputs_1['probs']), dim=0)\n",
    "        train_img_names = train_outputs_0['img_names'] + train_outputs_1['img_names']\n",
    "    except:\n",
    "        train_outputs = torch.load(f'{path}/{model_name}/{model_name}_train.pth')\n",
    "        train_probs = train_outputs['probs']\n",
    "        train_img_names = train_outputs['img_names']\n",
    "    \n",
    "    train_labels = list(map(lambda x: class_to_id[x.split('_')[0]], train_img_names))\n",
    "    return compute_01_loss(train_probs, train_img_names, train_labels, num_classes)\n",
    "\n",
    "def get_noised_losses(path, model_name):\n",
    "    num_classes = 1000\n",
    "    with open('/kaggle/input/imagenet-clusters/class_list.json', 'r') as f:\n",
    "        class_list: list = json.load(f)\n",
    "    class_to_id = {class_name: i for i, class_name in enumerate(class_list)}\n",
    "\n",
    "    train_outputs = torch.load(f'{path}/{model_name}_train_0.pth')\n",
    "    train_probs = train_outputs['probs']\n",
    "    train_img_names = train_outputs['img_names']\n",
    "    \n",
    "    train_labels = list(map(lambda x: class_to_id[x.split('_')[0]], train_img_names))\n",
    "    return compute_01_loss(train_probs, train_img_names, train_labels, num_classes)\n",
    "    \n",
    "def compute_g(train_clusters, delta, K, C):\n",
    "    n_T = len(train_clusters.keys())\n",
    "    N = 0\n",
    "    for cluster_id in train_clusters.keys():        \n",
    "        N += len(train_clusters[cluster_id])\n",
    "    \n",
    "    ln = math.log(2 * K / delta)\n",
    "    term = n_T * ln / N\n",
    "    \n",
    "    return C * ((math.sqrt(2)+1) * math.sqrt(term) + 2 * term)\n",
    "\n",
    "def compute_train_loss(train_losses: list, train_clusters: dict, K=10000):\n",
    "    train_loss = 0\n",
    "    N = 0\n",
    "    \n",
    "    for cluster_id in tqdm(train_clusters.keys()):\n",
    "        Ni = len(train_clusters[cluster_id])\n",
    "        N += Ni\n",
    "        for train_point in train_clusters[cluster_id]:\n",
    "            train_loss += train_losses[train_point]\n",
    "    train_loss /= N\n",
    "    return train_loss\n",
    "\n",
    "def compute_ab(train_clusters, delta, K):\n",
    "    N = 0\n",
    "    a = 0\n",
    "    for cluster_id in train_clusters.keys():\n",
    "        Ni = len(train_clusters[cluster_id])\n",
    "        a += (Ni*Ni)\n",
    "        N += Ni\n",
    "    a /= (2 * N*N)\n",
    "    a += math.sqrt(2/N * math.log(2*K/delta))\n",
    "    return a, 1/(2*N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "30ea1748",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T18:18:05.225205Z",
     "iopub.status.busy": "2025-05-13T18:18:05.224702Z",
     "iopub.status.idle": "2025-05-13T18:18:05.234859Z",
     "shell.execute_reply": "2025-05-13T18:18:05.233370Z"
    },
    "papermill": {
     "duration": 0.017473,
     "end_time": "2025-05-13T18:18:05.237378",
     "exception": false,
     "start_time": "2025-05-13T18:18:05.219905",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_reports = {\n",
    "        # 'resnet18_v1':{\n",
    "        #     'test_accuracy': 69.76\n",
    "        # },\n",
    "        # 'resnet34_v1':{\n",
    "        #     'test_accuracy': 73.31\n",
    "        # },\n",
    "        # 'resnet50_v1':{\n",
    "        #     'test_accuracy': 76.13\n",
    "        # },\n",
    "        # 'resnet101_v1':{\n",
    "        #     'test_accuracy': 77.37\n",
    "        # },\n",
    "        # 'resnet152_v1':{\n",
    "        #     'test_accuracy': 78.31\n",
    "        # },\n",
    "        # 'resnet50_v2':{\n",
    "        #     'test_accuracy': 80.86\n",
    "        # },\n",
    "        # 'resnet101_v2':{\n",
    "        #     'test_accuracy': 81.89\n",
    "        # },\n",
    "        # 'resnet152_v2':{\n",
    "        #     'test_accuracy': 82.28\n",
    "        # },\n",
    "#         'swin_b_v1':{\n",
    "#             'test_accuracy': 83.58\n",
    "#         },\n",
    "#         'swin_b_v2':{\n",
    "#             'test_accuracy': 84.11\n",
    "#         },\n",
    "#         'swin_t_v1':{\n",
    "#             'test_accuracy': 81.47\n",
    "#         },\n",
    "        # 'swin_t_v2':{\n",
    "        #     'test_accuracy': 82.07\n",
    "        # },\n",
    "        # 'vgg13_v1':{\n",
    "        #     'test_accuracy': 69.93\n",
    "        # },\n",
    "#         'vgg13_bn_v1':{\n",
    "#             'test_accuracy': 71.59\n",
    "#         },\n",
    "#         'vgg19_v1':{\n",
    "#             'test_accuracy': 72.38\n",
    "#         },\n",
    "#         'vgg19_bn_v1':{\n",
    "#             'test_accuracy': 74.22\n",
    "#         },\n",
    "#         'densenet121_v1':{\n",
    "#             'test_accuracy': 74.43\n",
    "#         },\n",
    "        # 'densenet161_v1':{\n",
    "        #     'test_accuracy': 77.14\n",
    "        # },\n",
    "        # 'densenet169_v1':{\n",
    "        #     'test_accuracy': 75.60\n",
    "        # },\n",
    "        # 'densenet201_v1':{\n",
    "        #     'test_accuracy': 76.90\n",
    "        # },\n",
    "        # 'convnext_base_v1':{\n",
    "        #     'test_accuracy': 84.062\n",
    "        # },\n",
    "        # 'convnext_large_v1':{\n",
    "        #     'test_accuracy': 84.414\n",
    "        # },\n",
    "#         'regnet_y_128gf_e2e':{\n",
    "#             'test_accuracy': 88.228\n",
    "#         },\n",
    "#         'regnet_y_128gf_linear':{\n",
    "#             'test_accuracy': 86.068\n",
    "#         },\n",
    "#         'regnet_y_32gf_e2e':{\n",
    "#             'test_accuracy': 86.838\n",
    "#         },\n",
    "#         'regnet_y_32gf_linear':{\n",
    "#             'test_accuracy': 84.622\n",
    "#         },\n",
    "        'regnet_y_32gf_v2':{\n",
    "            'test_accuracy': 81.982\n",
    "        },\n",
    "#         'vit_b_16_linear':{\n",
    "#             'test_accuracy': 81.886\n",
    "#         },\n",
    "#         'vit_b_16_v1':{\n",
    "#             'test_accuracy': 81.072\n",
    "#         },\n",
    "#         'vit_h_14_linear':{\n",
    "#             'test_accuracy': 85.708\n",
    "#         },\n",
    "#         'vit_l_16_linear':{\n",
    "#             'test_accuracy': 85.146\n",
    "#         },\n",
    "        # 'vit_l_16_v1':{\n",
    "        #     'test_accuracy': 79.662\n",
    "        # },\n",
    "    }\n",
    "\n",
    "def get_results(model_train_losses, model_valid_losses, K, run_id):\n",
    "    print(f'run_id: {run_id}')\n",
    "    train_group_path = f'/kaggle/input/imagenet-clusters/run_{run_id}/train_group.json'\n",
    "#     valid_group_path = f'/kaggle/input/imagenet-clusters/run_{run_id}/valid_group.json'\n",
    "    valid_group_path = f'/kaggle/input/imagenet-clusters/run_{run_id}/train_group.json'\n",
    "    with open(train_group_path, 'r') as f:\n",
    "        train_clusters: dict = json.load(f)\n",
    "    with open(valid_group_path, 'r') as f:\n",
    "        valid_clusters: dict = json.load(f)\n",
    "            \n",
    "    results = test_reports\n",
    "    \n",
    "    for model_name in results.keys():\n",
    "        results[model_name].update(compute_bound(model_train_losses[model_name], \n",
    "                                                 model_valid_losses[model_name], \n",
    "                                                 train_clusters, valid_clusters, \n",
    "                                                 K=10000))\n",
    "        \n",
    "    with open(f'results_{run_id}.json', 'w') as f:\n",
    "        json.dump(results, f, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9df36bb4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T18:18:05.245127Z",
     "iopub.status.busy": "2025-05-13T18:18:05.244716Z",
     "iopub.status.idle": "2025-05-13T18:19:36.075787Z",
     "shell.execute_reply": "2025-05-13T18:19:36.073240Z"
    },
    "papermill": {
     "duration": 90.83873,
     "end_time": "2025-05-13T18:19:36.079185",
     "exception": false,
     "start_time": "2025-05-13T18:18:05.240455",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1281167/1281167 [00:26<00:00, 48883.90it/s]\n",
      "100%|██████████| 1281167/1281167 [00:28<00:00, 45109.15it/s]\n"
     ]
    }
   ],
   "source": [
    "model_train_losses = {}\n",
    "noised_model_train_losses = {}\n",
    "for model_name in test_reports.keys():\n",
    "    train_losses = get_losses(output_path, model_name)\n",
    "    model_train_losses[model_name] = train_losses\n",
    "    noised_train_losses = get_noised_losses(noised_output_path, model_name)\n",
    "    noised_model_train_losses[model_name] = noised_train_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "57005790",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T18:19:36.159922Z",
     "iopub.status.busy": "2025-05-13T18:19:36.158894Z",
     "iopub.status.idle": "2025-05-13T18:19:36.174166Z",
     "shell.execute_reply": "2025-05-13T18:19:36.172089Z"
    },
    "papermill": {
     "duration": 0.060572,
     "end_time": "2025-05-13T18:19:36.178560",
     "exception": false,
     "start_time": "2025-05-13T18:19:36.117988",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from multiprocessing import Pool, cpu_count\n",
    "from functools import partial\n",
    "\n",
    "def compute_minus_loss(cluster_id, args):\n",
    "    \"\"\" Compute minus loss for a given cluster. \"\"\"\n",
    "    clusters, noised_clusters, losses = args\n",
    "    if cluster_id not in noised_clusters:\n",
    "        return 0  # Skip if not in noised clusters\n",
    "    \n",
    "    ni = len(clusters[cluster_id])\n",
    "    mi = len(noised_clusters[cluster_id])\n",
    "    tmp = sum(losses[z] for z in clusters[cluster_id])\n",
    "    \n",
    "    return (mi * tmp / ni)\n",
    "\n",
    "def compute_gap_loss(cluster_id, args):\n",
    "    \"\"\" Compute gap loss for a given cluster. \"\"\"\n",
    "    clusters, noised_clusters, losses, noised_losses = args\n",
    "    if cluster_id not in noised_clusters:\n",
    "        return 0  # Skip if not in noised clusters\n",
    "    \n",
    "    ni = len(clusters[cluster_id])\n",
    "    tmp = sum(\n",
    "        abs(losses[z] - noised_losses[s])\n",
    "        for z in clusters[cluster_id]\n",
    "        for s in noised_clusters[cluster_id]\n",
    "    )\n",
    "    \n",
    "    return (tmp / ni)\n",
    "\n",
    "def compute_additional_term(losses, noised_losses, clusters, noised_clusters, m=1281167):\n",
    "    gap_loss = 0\n",
    "    minus_loss = 0\n",
    "    plus_loss = sum(noised_losses.values()) / m\n",
    "\n",
    "    cluster_ids = list(clusters.keys())\n",
    "    \n",
    "    # Use multiprocessing for minus_loss computation\n",
    "    partial_compute_minus_loss = partial(compute_minus_loss, \n",
    "                                         args=(clusters, noised_clusters, losses))\n",
    "    with Pool(processes=cpu_count()) as pool:\n",
    "        minus_loss_results = pool.map(partial_compute_minus_loss, [cid for cid in cluster_ids])\n",
    "    \n",
    "    minus_loss = sum(minus_loss_results)\n",
    "\n",
    "    # Use multiprocessing for gap_loss computation\n",
    "    partial_compute_gap_loss = partial(compute_gap_loss, \n",
    "                                       args=(clusters, noised_clusters, losses, noised_losses))\n",
    "    with Pool(processes=cpu_count()) as pool:\n",
    "        gap_loss_results = pool.map(partial_compute_gap_loss, [cid for cid in cluster_ids])\n",
    "\n",
    "    gap_loss = sum(gap_loss_results)\n",
    "\n",
    "    return {\n",
    "        'additional_term': (gap_loss - minus_loss) / m + plus_loss,\n",
    "        'gap_loss': gap_loss / m,\n",
    "        'minus_loss': minus_loss / m,\n",
    "        'plus_loss': plus_loss\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3137ea98",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-13T18:19:36.268151Z",
     "iopub.status.busy": "2025-05-13T18:19:36.267554Z",
     "iopub.status.idle": "2025-05-13T20:56:44.249566Z",
     "shell.execute_reply": "2025-05-13T20:56:44.247718Z"
    },
    "papermill": {
     "duration": 9428.062105,
     "end_time": "2025-05-13T20:56:44.285516",
     "exception": false,
     "start_time": "2025-05-13T18:19:36.223411",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "regnet_y_32gf_v2\n",
      "{'additional_term': 0.6602707042218152, 'gap_loss': 0.3548522503069701, 'minus_loss': 0.03762893959435399, 'plus_loss': 0.34304739350919905}\n"
     ]
    }
   ],
   "source": [
    "with open('/kaggle/input/imagenet-train-clusters/seed_42/200/train_group.json', 'r') as f:\n",
    "    train_clusters: dict = json.load(f)\n",
    "with open('/kaggle/input/noised-imagenet-train-clusters/seed_42/std0.3/train_group.json', 'r') as f:\n",
    "    noised_train_clusters: dict = json.load(f)\n",
    "        \n",
    "for model_name in test_reports.keys():\n",
    "    print(model_name)\n",
    "    train_losses = model_train_losses[model_name]\n",
    "    noised_train_losses = noised_model_train_losses[model_name]\n",
    "    print(compute_additional_term(train_losses, noised_train_losses, train_clusters, noised_train_clusters))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50af64a",
   "metadata": {
    "papermill": {
     "duration": 0.031234,
     "end_time": "2025-05-13T20:56:44.347291",
     "exception": false,
     "start_time": "2025-05-13T20:56:44.316057",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "none",
   "dataSources": [
    {
     "datasetId": 5011828,
     "sourceId": 8794701,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 4780783,
     "sourceId": 9100309,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 6268723,
     "sourceId": 10284152,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 6403427,
     "sourceId": 11307171,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 7066932,
     "sourceId": 11799875,
     "sourceType": "datasetVersion"
    }
   ],
   "dockerImageVersionId": 30664,
   "isGpuEnabled": false,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 9533.126276,
   "end_time": "2025-05-13T20:56:47.831248",
   "environment_variables": {},
   "exception": null,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2025-05-13T18:17:54.704972",
   "version": "2.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
