{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "0501ab29-614b-4db8-ac53-217f0b339c14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "import random\n",
    "import pickle\n",
    "import argparse\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from numpy import linalg as LA\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.autonotebook import tqdm\n",
    "from tinyimagenet import TinyImageNet\n",
    "\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.func import grad, vmap, functional_call\n",
    "from torchvision import datasets, transforms, models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "ce17145b-d95d-43cc-bf15-2ae8d933cb4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_loader, device, use_noisy_labels=False):\n",
    "\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for _, inputs, true_labels, noisy_labels in tqdm(test_loader):\n",
    "            \n",
    "            if use_noisy_labels == False:\n",
    "                labels = true_labels\n",
    "            else:\n",
    "                labels = noisy_labels\n",
    "            \n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            outputs = model(inputs)\n",
    "            loss = torch.nn.CrossEntropyLoss()(outputs, labels)\n",
    "    \n",
    "            running_loss += loss.item() * inputs.size(0)\n",
    "            _, predicted = outputs.max(1)\n",
    "            total += labels.size(0)\n",
    "            correct += predicted.eq(labels).sum().item()\n",
    "\n",
    "    epoch_loss = running_loss / total\n",
    "    epoch_acc = correct / total\n",
    "    \n",
    "    return epoch_loss, epoch_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "ebad84ca-1cb7-441e-9ccd-94674461596d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cost_model(input_, target, params, buffers, model):\n",
    "    \n",
    "    input_ = input_.unsqueeze(0)\n",
    "    target = target.unsqueeze(0)\n",
    "\n",
    "    out_ = functional_call(model, (params, buffers), input_)\n",
    "    cost = torch.nn.CrossEntropyLoss()(out_, target)\n",
    "\n",
    "    return cost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "a317f752-74e5-484a-b39a-0e4398e29836",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_similarity(task_vector, sample_grads, method=\"normed_proj\", temperature=1.0):\n",
    "    \n",
    "    tv_flat = task_vector.flatten()\n",
    "    normed_tv_flat = torch.nn.functional.normalize(tv_flat, p=2, dim=0)\n",
    "    \n",
    "    score_tensor = torch.empty(len(sample_grads), device=tv_flat.device)\n",
    "\n",
    "    for i, neg_grad in enumerate(sample_grads):\n",
    "        neg_grad_flat = neg_grad.flatten()\n",
    "        \n",
    "        ## compute mimic score ##\n",
    "        if method.endswith(\"cos\"):\n",
    "            normed_neg_grad_flat = torch.nn.functional.normalize(neg_grad_flat, p=2, dim=0)\n",
    "            score = torch.dot(normed_neg_grad_flat, normed_tv_flat)\n",
    "        elif method.endswith(\"proj\"):\n",
    "            score = torch.dot(neg_grad_flat, normed_tv_flat)\n",
    "            \n",
    "        score_tensor[i] = score.item()\n",
    "\n",
    "    if method.startswith(\"normed\"):\n",
    "        score_tensor = score_tensor / temperature\n",
    "        normed_score_tensor = torch.nn.Softmax(dim=0)(score_tensor)\n",
    "        return normed_score_tensor\n",
    "    else:\n",
    "        return score_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "4399a7e4-095c-4430-aa18-80318750199e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_task_vector(current_model, reference_model, mimic_layer_name):\n",
    "\n",
    "    current_model_state_dict = current_model.state_dict()\n",
    "    reference_model_state_dict = reference_model.state_dict()\n",
    "    task_vector = reference_model_state_dict[mimic_layer_name] - current_model_state_dict[mimic_layer_name]\n",
    "    \n",
    "    return task_vector    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "764c4ebd-d22a-41d0-8ce3-85819a970390",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"\\ndef get_calibration_results(w, neg_grad, task_vector):\\n\\n    calibrated_grads = np.sum((w.reshape(len(w), 1, 1) * neg_grad), axis=0)\\n    uncalibrated_grads = np.mean(neg_grad, axis=0)\\n\\n    d1 = LA.norm(calibrated_grads - uncalibrated_grads, 'fro')\\n    d2 = LA.norm(uncalibrated_grads - task_vector, 'fro')\\n    d3 = LA.norm(calibrated_grads - task_vector, 'fro')\\n\\n    return d1, d2, d3\\n\""
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"\n",
    "def get_calibration_results(w, neg_grad, task_vector):\n",
    "\n",
    "    calibrated_grads = np.sum((w.reshape(len(w), 1, 1) * neg_grad), axis=0)\n",
    "    uncalibrated_grads = np.mean(neg_grad, axis=0)\n",
    "\n",
    "    d1 = LA.norm(calibrated_grads - uncalibrated_grads, 'fro')\n",
    "    d2 = LA.norm(uncalibrated_grads - task_vector, 'fro')\n",
    "    d3 = LA.norm(calibrated_grads - task_vector, 'fro')\n",
    "\n",
    "    return d1, d2, d3\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "be7755be-b30c-4df0-b26d-6d2238d3bf8f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"\\ndef task_vector_norm(model, ref_model, layer):\\n    \\n    tau = (ref_model.state_dict()[layer] - model.state_dict()[layer]).cpu().numpy()\\n    \\n    return LA.norm(tau, 'fro')\\n\""
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"\n",
    "def task_vector_norm(model, ref_model, layer):\n",
    "    \n",
    "    tau = (ref_model.state_dict()[layer] - model.state_dict()[layer]).cpu().numpy()\n",
    "    \n",
    "    return LA.norm(tau, 'fro')\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "662340df-b5f0-4ca6-a3bc-117a9e8fc25c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def objective_fn_matrix(neg_grad, task_vector, w, lambd, norm_way=\"l2\"):\n",
    "    \n",
    "    slice_norms = []\n",
    "    \n",
    "    for i in range(neg_grad.shape[0]):\n",
    "        neg_grad_slice = neg_grad[i, :, :] * w[i]\n",
    "        slice_norm = cp.norm(neg_grad_slice - task_vector, 'fro')\n",
    "        slice_norms.append(slice_norm)\n",
    "\n",
    "    if norm_way == \"l2\":\n",
    "        total_norm = cp.sum(cp.hstack(slice_norms)) + lambd * cp.norm2(w)\n",
    "    elif norm_way == \"l1\":\n",
    "        total_norm = cp.sum(cp.hstack(slice_norms)) + lambd * cp.norm1(w)\n",
    "\n",
    "    return total_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "16423888-3596-4302-965a-4250bd3037c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def objective_fn_vector(neg_grad, task_vector, w, lambd, norm_way=\"l2\"):\n",
    "\n",
    "    if norm_way == \"l2\":\n",
    "        return cp.norm2((w @ neg_grad) - task_vector) + lambd * cp.norm2(w)\n",
    "    elif norm_way == \"l1\":\n",
    "        return cp.norm2((w @ neg_grad) - task_vector) + lambd * cp.norm1(w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "7687d3b5-911a-4ae8-adca-0e8a4fa2f170",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_noise_to_labels(labels, num_class, noise_ratio=0.1):\n",
    "\n",
    "    labels = np.array(labels)\n",
    "    num_labels = len(labels)\n",
    "    num_noisy_labels = int(noise_ratio * num_labels)\n",
    "    \n",
    "    # Indices for which labels will be noisy\n",
    "    noisy_indices = random.sample(range(num_labels), num_noisy_labels)\n",
    "    \n",
    "    # Make a copy of the original labels\n",
    "    noisy_labels = labels.copy()\n",
    "    \n",
    "    for idx in noisy_indices:\n",
    "        \n",
    "        # Assign a new random label different from the original one\n",
    "        original_label = labels[idx]\n",
    "        new_label = (original_label + np.random.randint(1, num_class)) % num_class\n",
    "        noisy_labels[idx] = new_label\n",
    "    \n",
    "    return noisy_labels, noisy_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "af00ff2d-bc31-4565-ae4a-e2b22f89d6a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_per_sample_gradients(model, inputs, targets):\n",
    "\n",
    "    params = {k: v.detach() for k, v in model.named_parameters() if v.requires_grad == True}\n",
    "    buffers = {k: v.detach() for k, v in model.named_buffers() if v.requires_grad == True}\n",
    "\n",
    "    ft_grad = grad(cost_model, argnums=2)\n",
    "    ft_all_grads = vmap(ft_grad, in_dims = (0, 0, None, None, None))\n",
    "    ft_per_sample_grads = ft_all_grads(inputs, targets, params, buffers, model)\n",
    "\n",
    "    return ft_per_sample_grads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "52df46c5-42fe-4740-8adf-0155a7d604bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_subset_selection(ref, inputs, num_datapoint, lambd_value, norm_way, device):\n",
    "\n",
    "    selection_weights = cp.Variable(num_datapoint)\n",
    "    constraints = []\n",
    "    # constraints = [0 <= selection_weights, selection_weights <= 1, cp.sum(selection_weights) == 1]\n",
    "    lambd = cp.Parameter(nonneg=True)\n",
    "    lambd.value = lambd_value\n",
    "    if len(ref.shape) == 2:\n",
    "        problem = cp.Problem(cp.Minimize(objective_fn_matrix(inputs, ref, selection_weights, lambd, norm_way)), constraints=constraints)\n",
    "    elif len(ref.shape) == 1:\n",
    "        problem = cp.Problem(cp.Minimize(objective_fn_vector(inputs, ref, selection_weights, lambd, norm_way)), constraints=constraints)\n",
    "    problem.solve(solver=cp.MOSEK, verbose=False)\n",
    "    selection_weights = torch.Tensor(selection_weights.value).to(device)\n",
    "    \n",
    "    return selection_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "239789fc-6963-4405-a8c0-cac1e07c6f46",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradient_calibration(model, per_sample_weights, per_sample_grads, mimic_layer_name, calibrate_mimic_layer_only=True):\n",
    "\n",
    "    if calibrate_mimic_layer_only == True:\n",
    "        for name, param in model.named_parameters():\n",
    "            if param.requires_grad == True:\n",
    "                if name == mimic_layer_name:\n",
    "                    reshape_array = [1 for s in range(len(param.shape) + 1)]\n",
    "                    reshape_array[0] = len(per_sample_weights)\n",
    "                    reshaped_new_weights = per_sample_weights.reshape(reshape_array)\n",
    "                    param.grad = torch.sum(reshaped_new_weights * per_sample_grads[name], axis=0)\n",
    "                else:\n",
    "                    param.grad = torch.mean(per_sample_grads[name], axis=0)\n",
    "                    \n",
    "    else:\n",
    "        for name, param in model.named_parameters():\n",
    "            if param.requires_grad == True:\n",
    "                reshape_array = [1 for s in range(len(param.shape) + 1)]\n",
    "                reshape_array[0] = len(per_sample_weights)\n",
    "                reshaped_new_weights = per_sample_weights.reshape(reshape_array)\n",
    "                param.grad = torch.sum(reshaped_new_weights * per_sample_grads[name], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "d8c94e03-cc51-45d7-86d7-3a227dde91c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom Dataset class with index tracking and label noise addition\n",
    "class IndexedDataset(Dataset):\n",
    "    \n",
    "    def __init__(self, root, name, train=True, transform=None, download=False, noise_ratio=0.0):\n",
    "        \n",
    "        if name == \"dtd\":\n",
    "            if train == True: train = \"train\"\n",
    "            else: train = \"test\"\n",
    "            self.dataset = datasets.DTD(root=root, split=train, transform=transform, download=download)\n",
    "            self.dataset.targets = np.array([label for _, label in self.dataset])\n",
    "            \n",
    "        if name == \"stl10\":\n",
    "            if train == True: train = \"train\"\n",
    "            else: train = \"test\"\n",
    "            self.dataset = datasets.STL10(root=root, split=train, transform=transform, download=download)\n",
    "            self.dataset.targets = self.dataset.labels\n",
    "            \n",
    "        if name == \"cifar10\":\n",
    "            self.dataset = datasets.CIFAR10(root=root, train=train, transform=transform, download=download)\n",
    "\n",
    "        if name == \"flower102\":\n",
    "            if train == True: train = \"train\"\n",
    "            else: train = \"test\"\n",
    "            self.dataset = datasets.Flowers102(root=root, split=train, transform=transform, download=download)\n",
    "            self.dataset.targets = np.array([label for _, label in self.dataset])\n",
    "\n",
    "        if name == \"country211\":\n",
    "            if train == True: train = \"train\"\n",
    "            else: train = \"test\"\n",
    "            self.dataset = datasets.Country211(root=root, split=train, transform=transform, download=download)\n",
    "            \n",
    "        if name == \"cifar100\":\n",
    "            self.dataset = datasets.CIFAR100(root=root, train=train, transform=transform, download=download)\n",
    "\n",
    "        if name == \"tinyimagenet\":\n",
    "            if train == True: train = \"train\"\n",
    "            else: train = \"test\"\n",
    "            self.dataset = TinyImageNet(f\"{root}/tiny-imagenet-200\", split=train, transform=transform, imagenet_idx=True)\n",
    "\n",
    "        if name == \"pet\":\n",
    "            if train == True: train = \"trainval\"\n",
    "            else: train = \"test\"\n",
    "            self.dataset = datasets.OxfordIIITPet(root=root, split=train, target_types=\"category\", transform=transform, download=download)\n",
    "            self.dataset.targets = [self.dataset[i][1] for i in range(len(self.dataset))]\n",
    "\n",
    "        if name == \"mnist\":\n",
    "            self.dataset = datasets.MNIST(root=root, train=train, transform=transform, download=download)\n",
    "\n",
    "        if name == \"svhn\":\n",
    "            if train == True: train = \"train\"\n",
    "            else: train = \"test\"\n",
    "            self.dataset = datasets.SVHN(root=root, split=train, transform=transform, download=download)\n",
    "            self.dataset.targets = self.dataset.labels\n",
    "\n",
    "        self.true_labels = self.dataset.targets\n",
    "        \n",
    "        if name == \"flower102\":\n",
    "            self.num_class = 102\n",
    "        elif name == \"svhn\":\n",
    "            self.num_class = 10\n",
    "        else:\n",
    "            self.num_class = len(self.dataset.classes)\n",
    "        \n",
    "        # Apply noise to labels if noise_ratio is specified\n",
    "        if noise_ratio > 0:\n",
    "            self.noise_ratio = noise_ratio\n",
    "            self._apply_noise_to_labels()\n",
    "        else:\n",
    "            self.noise_ratio = 0\n",
    "            self.noisy_labels = self.dataset.targets\n",
    "            self.noisy_indices = []\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        data, true_target = self.dataset[index]\n",
    "        noisy_target = self.noisy_labels[index]\n",
    "        return index, data, true_target, noisy_target\n",
    "\n",
    "    def _apply_noise_to_labels(self):\n",
    "        self.noisy_labels, self.noisy_indices = add_noise_to_labels(self.true_labels, self.num_class, self.noise_ratio)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "feecfca0-f6b2-418b-a82e-a8bf42c47d0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_init_model(arch_name, num_class, device, seed, pretrained=True, linear_probing=True):\n",
    "\n",
    "    if pretrained == True:\n",
    "        ## using pretrained weights as backbone ##\n",
    "        if os.path.exists(f\"./saved_models/pretrained_linear_probing/init_{arch_name}_{num_class}classes_seed{seed}.pt\"):\n",
    "            model = torch.load(f\"./saved_models/pretrained_linear_probing/init_{arch_name}_{num_class}classes_seed{seed}.pt\")\n",
    "        else:\n",
    "            if arch_name == 'vit-b':\n",
    "                model = models.vit_b_16(weights='IMAGENET1K_V1')\n",
    "            elif arch_name == 'vit-l':\n",
    "                model = models.vit_l_16(weights='IMAGENET1K_V1')\n",
    "            model.heads.head = torch.nn.Linear(model.heads.head.in_features, num_class)\n",
    "            torch.save(model, f'./saved_models/pretrained_linear_probing/init_{arch_name}_{num_class}classes_seed{seed}.pt')\n",
    "        if linear_probing == True:\n",
    "            for name, param in model.named_parameters():\n",
    "                if \"heads.head\" not in name: param.requires_grad = False\n",
    "                else: param.requires_grad = True\n",
    "    else:\n",
    "        ## train from scratch ##\n",
    "        if os.path.exists(f\"./saved_models/train_from_scratch/init_{arch_name}_{num_class}classes_seed{seed}.pt\"):\n",
    "            model = torch.load(f\"./saved_models/train_from_scratch/init_{arch_name}_{num_class}classes_seed{seed}.pt\")\n",
    "        else:\n",
    "            if arch_name == 'vit-b':\n",
    "                model = models.vit_b_16(weights=None)\n",
    "            elif arch_name == 'vit-l':\n",
    "                model = models.vit_l_16(weights=None)\n",
    "            model.heads.head = torch.nn.Linear(model.heads.head.in_features, num_class)\n",
    "            torch.save(model, f'./saved_models/train_from_scratch/init_{arch_name}_{num_class}classes_seed{seed}.pt')\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "474cee92-4577-45c8-bd0d-97e375cd4ac5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_arguments():\n",
    "    \n",
    "    parser = argparse.ArgumentParser(description='Argument parser for training script')\n",
    "\n",
    "    parser.add_argument('--mode', type=str, choices=['grad-mimic', 'grad-match', 'grad-descent'], required=True,\n",
    "                        help='Select the mode for training: grad-mimic, grad-match, or grad-descent (baseline)')\n",
    "    parser.add_argument('--method', type=str, default='cos', choices=['opt', 'cos', 'proj', 'normed_cos', 'normed_proj'], required=False,\n",
    "                        help='Select the method for reweighting gradients: optimization, (normed) cosine similarity, or (normed) projection length')    \n",
    "    parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['svhn', 'mnist', 'dtd', 'pet', 'stl10', 'cifar10', 'flower102', 'country211', 'cifar100', 'tinyimagenet'], required=False,\n",
    "                        help='Name of the dataset to use')\n",
    "    parser.add_argument('--noisy_level', type=float, default=0.0, required=False,\n",
    "                        help='Percentage of involved noisy labels')\n",
    "    parser.add_argument('--num_epoch', type=int, default=5, required=False,\n",
    "                        help='Number of epochs to train')\n",
    "    parser.add_argument('--model_arch', type=str, default='vit-b', choices=['vit-b', 'vit-l'], required=False,\n",
    "                        help='Type of model architecture to use')\n",
    "    parser.add_argument('--pretrained', default=True, required=False, action=argparse.BooleanOptionalAction,\n",
    "                        help='To load pretrained weights (IMAGENET_V1) or not')\n",
    "    parser.add_argument('--linear_probing', default=True, required=False, action=argparse.BooleanOptionalAction,\n",
    "                        help='Fine-tune on the top of backbone only')\n",
    "    parser.add_argument('--mimic_layer', type=str, default='heads.head.weight', required=False, \n",
    "                        help='Specific layer name to use')\n",
    "    parser.add_argument('--calibrate_mimic_layer_only', default=True, required=False, action=argparse.BooleanOptionalAction,\n",
    "                        help='Only calibrate gradients for mimic layer or not')\n",
    "    parser.add_argument('--temperature', type=float, default=1.0, required=False,\n",
    "                        help='To control the smoonthness of selection weights')\n",
    "    parser.add_argument('--starting_epoch', type=int, default=0, required=False,\n",
    "                        help='Starting epoch for using grad-mimic')\n",
    "    parser.add_argument('--training_batch_size', type=int, default=32, choices=[16, 32, 64, 128, 256], required=False,\n",
    "                        help='Training batch size')\n",
    "    parser.add_argument('--dataset_dir', type=str, default='./data', required=False,\n",
    "                        help='Directory path to load data')\n",
    "    parser.add_argument('--optimizer', type=str, default='adamw', choices=['adamw', 'sgd'], required=False,\n",
    "                        help='Optimizer to use')\n",
    "    parser.add_argument('--learning_rate', type=float, default=1e-4, required=False,\n",
    "                        help='Learning rate to train')\n",
    "    parser.add_argument('--norm_way', type=str, default=\"l2\", choices=[\"l2\", \"l1\"], required=False,\n",
    "                        help='Regularization approach')\n",
    "    parser.add_argument('--lambda_value', type=float, default=0.0, required=False,\n",
    "                        help='Regularization parameter')\n",
    "    parser.add_argument('--cuda_device', type=int, default=0, required=False,\n",
    "                        help='Select the CUDA device ID to use for training')\n",
    "    parser.add_argument('--seed', type=int, default=123, required=False,\n",
    "                        help='Seed value for random number generator')\n",
    "    \n",
    "    args = parser.parse_args()\n",
    "    \n",
    "    return args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "270aa998-1775-4b71-987b-499566695281",
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.argv = ['ipykernel_launcher.py', \\\n",
    "            '--mode', 'grad-mimic', '--method', 'normed_proj', '--model_arch', 'vit-b', \\\n",
    "            '--pretrained', '--no-linear_probing', '--mimic_layer', 'heads.head.weight', \\\n",
    "            '--no-calibrate_mimic_layer_only', '--temperature', '0.1', '--cuda_device', '2', \\\n",
    "            '--dataset_name', 'dtd', '--noisy_level', '0.3', '--num_epoch', '10']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "58dbfb70-8862-4ff7-a102-3ff23212e8a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------Arguments-----------------------\n",
      "Mode: grad-mimic\n",
      "Method: normed_proj\n",
      "Dataset Name: dtd\n",
      "Noisy Level: 0.3\n",
      "Number of Epochs: 10\n",
      "Model Arch: vit-b\n",
      "Pretrained: True\n",
      "Linear Probing: False\n",
      "Mimic Layer: heads.head.weight\n",
      "Calibrate Mimic Layer only: False\n",
      "Temperature: 0.1\n",
      "Starting Epoch: 0\n",
      "Training Batch Size: 32\n",
      "Dataset Directory: ./data\n",
      "Optimizer: adamw\n",
      "Learning Rate: 0.0001\n",
      "Norm Way: l2\n",
      "Lambda: 0.0\n",
      "CUDA Device: 2\n",
      "Seed: 123\n",
      "-------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "args = parse_arguments()\n",
    "print(\"-----------------------Arguments-----------------------\")\n",
    "print(f\"Mode: {args.mode}\")\n",
    "print(f\"Method: {args.method}\")\n",
    "print(f\"Dataset Name: {args.dataset_name}\")\n",
    "print(f\"Noisy Level: {args.noisy_level}\")\n",
    "print(f\"Number of Epochs: {args.num_epoch}\")\n",
    "print(f\"Model Arch: {args.model_arch}\")\n",
    "print(f\"Pretrained: {args.pretrained}\")\n",
    "print(f\"Linear Probing: {args.linear_probing}\")\n",
    "print(f\"Mimic Layer: {args.mimic_layer}\")\n",
    "print(f\"Calibrate Mimic Layer only: {args.calibrate_mimic_layer_only}\")\n",
    "print(f\"Temperature: {args.temperature}\")\n",
    "print(f\"Starting Epoch: {args.starting_epoch}\")\n",
    "print(f\"Training Batch Size: {args.training_batch_size}\")\n",
    "print(f\"Dataset Directory: {args.dataset_dir}\")\n",
    "print(f\"Optimizer: {args.optimizer}\")\n",
    "print(f\"Learning Rate: {args.learning_rate}\")\n",
    "print(f\"Norm Way: {args.norm_way}\")\n",
    "print(f\"Lambda: {args.lambda_value}\")\n",
    "print(f\"CUDA Device: {args.cuda_device}\")\n",
    "print(f\"Seed: {args.seed}\")\n",
    "print(\"-------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "416fc5ef-b1c8-4ebb-aedb-dd6d2baf1c6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Setup Everything ##\n",
    "np.random.seed(args.seed)\n",
    "random.seed(args.seed)\n",
    "torch.manual_seed(args.seed)\n",
    "torch.cuda.manual_seed(args.seed)\n",
    "\n",
    "test_batch_size = 512\n",
    "device = torch.device(f'cuda:{args.cuda_device}' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "## Data preprocessing and augmentation ##\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "])\n",
    "\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "])\n",
    "\n",
    "## Create instances of IndexedDataset ##\n",
    "train_dataset = IndexedDataset(root=args.dataset_dir, name=args.dataset_name, train=True, download=False, transform=transform_train, noise_ratio=args.noisy_level)\n",
    "test_dataset = IndexedDataset(root=args.dataset_dir, name=args.dataset_name, train=False, download=False, transform=transform_test, noise_ratio=0.0)\n",
    "num_class = train_dataset.num_class\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, batch_size=args.training_batch_size, shuffle=True, num_workers=4)\n",
    "test_loader = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "5167b04c-bccf-4f45-81e4-1eef40b8f1c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/tmp/ipykernel_1756324/1201812452.py:6: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model = torch.load(f\"./saved_models/pretrained_linear_probing/init_{arch_name}_{num_class}classes_seed{seed}.pt\")\n"
     ]
    }
   ],
   "source": [
    "## Load initial model ##\n",
    "model = load_init_model(arch_name=args.model_arch, num_class=num_class, device=device, seed=args.seed, pretrained=args.pretrained, linear_probing=args.linear_probing)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "1eb75cf4-e5f1-4892-8fef-ac7e1284e4e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 59/59 [00:08<00:00,  7.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Init. Model -- Train True Loss: 3.9483, Train True Acc: 0.0133\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 59/59 [00:07<00:00,  7.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Init. Model -- Train Noisy Loss: 3.9404, Train Noisy Acc: 0.0223\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:09<00:00,  2.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Init. Model -- Test Loss: 3.9386, Test Acc: 0.0112\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "## Evaluate the initial model ##\n",
    "## Training data, scored against the true labels (evaluate's default) ##\n",
    "init_true_train_loss, init_true_train_acc = evaluate(model, train_loader, device)\n",
    "print(f'Init. Model -- Train True Loss: {init_true_train_loss:.4f}, Train True Acc: {init_true_train_acc:.4f}')\n",
    "\n",
    "## Training data, scored against the noisy labels ##\n",
    "init_noisy_train_loss, init_noisy_train_acc = evaluate(model, train_loader, device, use_noisy_labels=True)\n",
    "print(f'Init. Model -- Train Noisy Loss: {init_noisy_train_loss:.4f}, Train Noisy Acc: {init_noisy_train_acc:.4f}')\n",
    "\n",
    "## Test data, scored against the true labels ##\n",
    "init_test_loss, init_test_acc = evaluate(model, test_loader, device)\n",
    "print(f'Init. Model -- Test Loss: {init_test_loss:.4f}, Test Acc: {init_test_acc:.4f}')\n",
    "\n",
    "## Record learning results ##\n",
    "## Each series holds the initial-model value followed by args.num_epoch zero placeholders ##\n",
    "learning_results = {\n",
    "    series_name: [initial_value] + [0] * args.num_epoch\n",
    "    for series_name, initial_value in (\n",
    "        (\"true_training_loss\", init_true_train_loss),\n",
    "        (\"noisy_training_loss\", init_noisy_train_loss),\n",
    "        (\"testing_loss\", init_test_loss),\n",
    "        (\"true_training_accuracy\", init_true_train_acc),\n",
    "        (\"noisy_training_accuracy\", init_noisy_train_acc),\n",
    "        (\"testing_accuracy\", init_test_acc),\n",
    "    )\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "022d1fc0-e6ff-4130-88a8-539522cb0f3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Record per-sample weights ##\n",
    "## One entry per training-set index: a status flag plus one weight slot per epoch ##\n",
    "result_collection = {}\n",
    "for i in range(len(train_dataset)):\n",
    "    result_collection[i] = {\n",
    "        ## Indices listed in train_dataset.noisy_indices are tagged \"incorrect\" ##\n",
    "        \"status\": \"incorrect\" if i in train_dataset.noisy_indices else \"correct\",\n",
    "        \"per_sample_weights\": [0] * args.num_epoch,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "e10cc087-3b3e-4c9f-80f7-f9633f66e066",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/tmp/ipykernel_1756324/2872778365.py:15: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  reference_model = torch.load(reference_model_path).to(device)\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ref. Model -- Test Loss: 1.6038, Test Acc: 0.6069\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "## Load reference model ##\n",
    "## Only the grad-mimic mode uses a reference model; every other mode skips this cell's work ##\n",
    "if args.mode == \"grad-mimic\":\n",
    "\n",
    "    ## Pick the checkpoint sub-directory that matches the training regime ##\n",
    "    if args.pretrained == True and args.linear_probing == True:\n",
    "        reference_model_dir = \"pretrained_linear_probing\"\n",
    "    elif args.pretrained == True and args.linear_probing == False:\n",
    "        reference_model_dir = \"pretrained_fine_tune_all\"\n",
    "    else:\n",
    "        reference_model_dir = \"train_from_scratch\"\n",
    "\n",
    "    ## The file name always points at the grad-descent / adamw / noisy0.0 / seed123 run ##\n",
    "    reference_model_path = f\"./saved_models/{reference_model_dir}/{args.model_arch}_{args.dataset_name}_run{args.num_epoch}epochs_grad-descent_optadamw_noisy0.0_seed123.pt\"\n",
    "\n",
    "    if not os.path.exists(reference_model_path):\n",
    "        ## No checkpoint on disk: fall back to plain gradient descent instead of failing ##\n",
    "        print(\"** Reference model is not existed. Change to grad-descent's mode. Please run grad-mimic later. **\")\n",
    "        args.mode = \"grad-descent\"\n",
    "    else:\n",
    "        ## The checkpoint is a whole pickled model object (it is used as a module below), so it\n",
    "        ## must be loaded with weights_only=False. Unpickling can execute arbitrary code: only\n",
    "        ## load checkpoints produced by this project. map_location places the tensors on the\n",
    "        ## current device even if the checkpoint was saved from a different one. ##\n",
    "        reference_model = torch.load(reference_model_path, map_location=device, weights_only=False).to(device)\n",
    "        reference_model.eval()\n",
    "        ## Evaluate the reference model ##\n",
    "        ref_test_loss, ref_test_acc = evaluate(reference_model, test_loader, device)\n",
    "        print(f'Ref. Model -- Test Loss: {ref_test_loss:.4f}, Test Acc: {ref_test_acc:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "227ef7a6-f812-4211-98b8-69e5a8ca1757",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Build the optimizer selected by args.optimizer (both use weight_decay=1e-5) ##\n",
    "if args.optimizer == \"adamw\":\n",
    "    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)\n",
    "elif args.optimizer == \"sgd\":\n",
    "    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=1e-5, momentum=0.9)\n",
    "else:\n",
    "    ## Fail fast: otherwise `optimizer` stays undefined, or stale from an earlier run of this cell ##\n",
    "    raise ValueError(f\"Unsupported optimizer: {args.optimizer!r} (expected 'adamw' or 'sgd')\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "c6aa9b78-9528-451d-bed9-0ef080e9fa38",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Manually force the run mode to \"agra\"; the training loop dispatches on args.mode ##\n",
    "## NOTE(review): this override runs after the reference-model cell, which only loads a reference model when args.mode is \"grad-mimic\" -- confirm this ordering is intended ##\n",
    "args.mode = \"agra\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "aa2e7188-31f7-4ffe-aceb-18cc0740d6e6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                                        | 0/10 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train the model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|                                                                                                                                                                        | 0/59 [00:00<?, ?it/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0000, 0.0000, 0.0000, 0.1114, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0520, 0.0762, 0.0763, 0.0000, 0.0000, 0.0967, 0.1898, 0.0000, 0.0540,\n",
      "        0.0000, 0.0289, 0.2399, 0.0739, 0.0277, 0.0000, 0.0554, 0.0171, 0.0968,\n",
      "        0.0000, 0.0665, 0.1114, 0.0789, 0.0000], device='cuda:2')\n",
      "tensor(17, device='cuda:2')\n",
      "tensor([0.0000, 0.0000, 0.0000, 0.0588, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0588, 0.0588, 0.0588, 0.0000, 0.0000, 0.0588, 0.0588, 0.0000, 0.0588,\n",
      "        0.0000, 0.0588, 0.0588, 0.0588, 0.0588, 0.0000, 0.0588, 0.0588, 0.0588,\n",
      "        0.0000, 0.0588, 0.0588, 0.0588, 0.0000], device='cuda:2')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  2%|██▋                                                                                                                                                             | 1/59 [00:02<02:05,  2.16s/it]\u001b[A\n",
      "  3%|█████▍                                                                                                                                                          | 2/59 [00:03<01:51,  1.95s/it]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0000, 0.1725, 0.1472, 0.0807, 0.0000, 0.0000, 0.0000, 0.0000, 0.0873,\n",
      "        0.0000, 0.0169, 0.0689, 0.0000, 0.0841, 0.0000, 0.0937, 0.0000, 0.1676,\n",
      "        0.0000, 0.0000, 0.2203, 0.0000, 0.0878, 0.0298, 0.0000, 0.0413, 0.0000,\n",
      "        0.0000, 0.1041, 0.0000, 0.0000, 0.0000], device='cuda:2')\n",
      "tensor(14, device='cuda:2')\n",
      "tensor([0.0000, 0.0714, 0.0714, 0.0714, 0.0000, 0.0000, 0.0000, 0.0000, 0.0714,\n",
      "        0.0000, 0.0714, 0.0714, 0.0000, 0.0714, 0.0000, 0.0714, 0.0000, 0.0714,\n",
      "        0.0000, 0.0000, 0.0714, 0.0000, 0.0714, 0.0714, 0.0000, 0.0714, 0.0000,\n",
      "        0.0000, 0.0714, 0.0000, 0.0000, 0.0000], device='cuda:2')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█████▍                                                                                                                                                          | 2/59 [00:05<02:26,  2.56s/it]\n",
      "  0%|                                                                                                                                                                        | 0/10 [00:05<?, ?it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[93], line 53\u001b[0m\n\u001b[1;32m     50\u001b[0m comp_noisy_labels \u001b[38;5;241m=\u001b[39m comp_noisy_labels\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m     52\u001b[0m \u001b[38;5;66;03m## Computer per-sample grad for both ##\u001b[39;00m\n\u001b[0;32m---> 53\u001b[0m comp_ft_per_sample_grads \u001b[38;5;241m=\u001b[39m \u001b[43mget_per_sample_gradients\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcomp_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcomp_noisy_labels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     54\u001b[0m ft_per_sample_grads \u001b[38;5;241m=\u001b[39m get_per_sample_gradients(model, inputs, noisy_labels)\n\u001b[1;32m     56\u001b[0m \u001b[38;5;66;03m## Compute mean gradient on compared batch ##\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[71], line 8\u001b[0m, in \u001b[0;36mget_per_sample_gradients\u001b[0;34m(model, inputs, targets)\u001b[0m\n\u001b[1;32m      6\u001b[0m ft_grad \u001b[38;5;241m=\u001b[39m grad(cost_model, argnums\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m      7\u001b[0m ft_all_grads \u001b[38;5;241m=\u001b[39m vmap(ft_grad, in_dims \u001b[38;5;241m=\u001b[39m (\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[0;32m----> 8\u001b[0m ft_per_sample_grads \u001b[38;5;241m=\u001b[39m \u001b[43mft_all_grads\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ft_per_sample_grads\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/_functorch/apis.py:201\u001b[0m, in \u001b[0;36mvmap.<locals>.wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    200\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 201\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mvmap_impl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    202\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_dims\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_dims\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrandomness\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mchunk_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m    203\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/_functorch/vmap.py:331\u001b[0m, in \u001b[0;36mvmap_impl\u001b[0;34m(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)\u001b[0m\n\u001b[1;32m    320\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m _chunked_vmap(\n\u001b[1;32m    321\u001b[0m         func,\n\u001b[1;32m    322\u001b[0m         flat_in_dims,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    327\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m    328\u001b[0m     )\n\u001b[1;32m    330\u001b[0m \u001b[38;5;66;03m# If chunk_size is not specified.\u001b[39;00m\n\u001b[0;32m--> 331\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_flat_vmap\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    332\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    333\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    334\u001b[0m \u001b[43m    \u001b[49m\u001b[43mflat_in_dims\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    335\u001b[0m \u001b[43m    \u001b[49m\u001b[43mflat_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    336\u001b[0m \u001b[43m    \u001b[49m\u001b[43margs_spec\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    337\u001b[0m \u001b[43m    \u001b[49m\u001b[43mout_dims\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    338\u001b[0m \u001b[43m    \u001b[49m\u001b[43mrandomness\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    339\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    340\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/_functorch/vmap.py:48\u001b[0m, in \u001b[0;36mdoesnt_support_saved_tensors_hooks.<locals>.fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     45\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(f)\n\u001b[1;32m     46\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfn\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m     47\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mgraph\u001b[38;5;241m.\u001b[39mdisable_saved_tensors_hooks(message):\n\u001b[0;32m---> 48\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/_functorch/vmap.py:480\u001b[0m, in \u001b[0;36m_flat_vmap\u001b[0;34m(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)\u001b[0m\n\u001b[1;32m    476\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m vmap_increment_nesting(batch_size, randomness) \u001b[38;5;28;01mas\u001b[39;00m vmap_level:\n\u001b[1;32m    477\u001b[0m     batched_inputs \u001b[38;5;241m=\u001b[39m _create_batched_inputs(\n\u001b[1;32m    478\u001b[0m         flat_in_dims, flat_args, vmap_level, args_spec\n\u001b[1;32m    479\u001b[0m     )\n\u001b[0;32m--> 480\u001b[0m     batched_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mbatched_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    481\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/_functorch/apis.py:397\u001b[0m, in \u001b[0;36mgrad.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    396\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 397\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43meager_transforms\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrad_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margnums\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhas_aux\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py:1451\u001b[0m, in \u001b[0;36mgrad_impl\u001b[0;34m(func, argnums, has_aux, args, kwargs)\u001b[0m\n\u001b[1;32m   1450\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mgrad_impl\u001b[39m(func: Callable, argnums: argnums_t, has_aux: \u001b[38;5;28mbool\u001b[39m, args, kwargs):\n\u001b[0;32m-> 1451\u001b[0m     results \u001b[38;5;241m=\u001b[39m \u001b[43mgrad_and_value_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margnums\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhas_aux\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1452\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m has_aux:\n\u001b[1;32m   1453\u001b[0m         grad, (_, aux) \u001b[38;5;241m=\u001b[39m results\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/_functorch/vmap.py:48\u001b[0m, in \u001b[0;36mdoesnt_support_saved_tensors_hooks.<locals>.fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     45\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(f)\n\u001b[1;32m     46\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfn\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m     47\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mgraph\u001b[38;5;241m.\u001b[39mdisable_saved_tensors_hooks(message):\n\u001b[0;32m---> 48\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py:1435\u001b[0m, in \u001b[0;36mgrad_and_value_impl\u001b[0;34m(func, argnums, has_aux, args, kwargs)\u001b[0m\n\u001b[1;32m   1433\u001b[0m \u001b[38;5;66;03m# NB: need create_graph so that backward pass isn't run in no_grad mode\u001b[39;00m\n\u001b[1;32m   1434\u001b[0m flat_outputs \u001b[38;5;241m=\u001b[39m _as_tuple(output)\n\u001b[0;32m-> 1435\u001b[0m flat_grad_input \u001b[38;5;241m=\u001b[39m \u001b[43m_autograd_grad\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1436\u001b[0m \u001b[43m    \u001b[49m\u001b[43mflat_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflat_diff_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m   1437\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1438\u001b[0m grad_input \u001b[38;5;241m=\u001b[39m tree_unflatten(flat_grad_input, spec)\n\u001b[1;32m   1440\u001b[0m grad_input \u001b[38;5;241m=\u001b[39m _undo_create_differentiable(grad_input, level)\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py:181\u001b[0m, in \u001b[0;36m_autograd_grad\u001b[0;34m(outputs, inputs, grad_outputs, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    179\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(diff_outputs) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m    180\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(torch\u001b[38;5;241m.\u001b[39mzeros_like(inp) \u001b[38;5;28;01mfor\u001b[39;00m inp \u001b[38;5;129;01min\u001b[39;00m inputs)\n\u001b[0;32m--> 181\u001b[0m grad_inputs \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrad\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    182\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdiff_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    183\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    184\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgrad_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    185\u001b[0m \u001b[43m    \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    186\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    187\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unused\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    188\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    189\u001b[0m grad_inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(\n\u001b[1;32m    190\u001b[0m     torch\u001b[38;5;241m.\u001b[39mzeros_like(inp) 
\u001b[38;5;28;01mif\u001b[39;00m gi \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m gi\n\u001b[1;32m    191\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m gi, inp \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(grad_inputs, inputs)\n\u001b[1;32m    192\u001b[0m )\n\u001b[1;32m    193\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m grad_inputs\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/autograd/__init__.py:436\u001b[0m, in \u001b[0;36mgrad\u001b[0;34m(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)\u001b[0m\n\u001b[1;32m    432\u001b[0m     result \u001b[38;5;241m=\u001b[39m _vmap_internals\u001b[38;5;241m.\u001b[39m_vmap(vjp, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m, allow_none_pass_through\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)(\n\u001b[1;32m    433\u001b[0m         grad_outputs_\n\u001b[1;32m    434\u001b[0m     )\n\u001b[1;32m    435\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 436\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    437\u001b[0m \u001b[43m        \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    438\u001b[0m \u001b[43m        \u001b[49m\u001b[43mgrad_outputs_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    439\u001b[0m \u001b[43m        \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    440\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    441\u001b[0m \u001b[43m        \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    442\u001b[0m \u001b[43m        \u001b[49m\u001b[43mallow_unused\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    443\u001b[0m \u001b[43m        \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    444\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    445\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m materialize_grads:\n\u001b[1;32m    446\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m(\n\u001b[1;32m    447\u001b[0m         result[i] 
\u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_tensor_like(inputs[i])\n\u001b[1;32m    448\u001b[0m         \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(inputs))\n\u001b[1;32m    449\u001b[0m     ):\n",
      "File \u001b[0;32m/miniforge/envs/iris/lib/python3.10/site-packages/torch/autograd/graph.py:769\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m    767\u001b[0m     unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m    768\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 769\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    770\u001b[0m \u001b[43m        \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m    771\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m    772\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    773\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "for epoch in tqdm(range(args.num_epoch), total=args.num_epoch):\n",
    "        \n",
    "    print(\"Train the model\")\n",
    "    model.train()\n",
    "    running_noisy_loss = 0.0\n",
    "    running_true_loss = 0.0\n",
    "    noisy_correct = 0\n",
    "    true_correct = 0\n",
    "    total = 0\n",
    "\n",
    "    for batch_idx, (indices, inputs, true_labels, noisy_labels) in tqdm(enumerate(train_loader), total=len(train_loader)):\n",
    "    \n",
    "        noisy_labels = noisy_labels.type(torch.int64)\n",
    "        true_labels = true_labels.type(torch.int64)\n",
    "            \n",
    "        optimizer.zero_grad()\n",
    "        inputs, true_labels, noisy_labels = inputs.to(device), true_labels.to(device), noisy_labels.to(device)\n",
    "        \n",
    "        ## Grad-Norm algorithm (our baseline) ##\n",
    "        if args.mode == \"grad-norm\" and epoch >= args.starting_epoch:\n",
    "            \n",
    "            ## Compute per-sample grad ##\n",
    "            ft_per_sample_grads = get_per_sample_gradients(model, inputs, noisy_labels)\n",
    "\n",
    "            ## Compute gradient norm ##\n",
    "            specific_layer_grads = ft_per_sample_grads[args.mimic_layer]\n",
    "            per_sample_weights = torch.zeros(batch_size, device=device)\n",
    "            for g, _grad in enumerate(specific_layer_grads):\n",
    "                per_sample_weights[g] = torch.norm(_grad).item()\n",
    "            per_sample_weights = torch.nn.Softmax(dim=0)(per_sample_weights)\n",
    "\n",
    "            ## Record things ##\n",
    "            for loc, datapoint_index in enumerate(indices):\n",
    "                result_collection[datapoint_index.item()][\"per_sample_weights\"][epoch] = per_sample_weights[loc].item()\n",
    "\n",
    "            ## Calibrate Gradients ##\n",
    "            gradient_calibration(model, per_sample_weights, ft_per_sample_grads, args.mimic_layer, args.calibrate_mimic_layer_only)\n",
    "            optimizer.step()\n",
    "\n",
    "        ## AGRA algorithm (our competitor) ##\n",
    "        elif args.mode == \"agra\" and epoch >= args.starting_epoch:\n",
    "\n",
    "            if args.method != \"cos\":\n",
    "                    args.method = \"cos\"\n",
    "                \n",
    "            ## Grab another batch ##\n",
    "            comp_indices, comp_inputs, comp_true_labels, comp_noisy_labels = next(iter(train_loader))\n",
    "            comp_inputs = comp_inputs.to(device)\n",
    "            comp_noisy_labels = comp_noisy_labels.type(torch.int64)\n",
    "            comp_noisy_labels = comp_noisy_labels.to(device)\n",
    "\n",
    "            ## Computer per-sample grad for both ##\n",
    "            comp_ft_per_sample_grads = get_per_sample_gradients(model, comp_inputs, comp_noisy_labels)\n",
    "            ft_per_sample_grads = get_per_sample_gradients(model, inputs, noisy_labels)\n",
    "\n",
    "            ## Compute mean gradient on compared batch ##\n",
    "            comp_specific_layer_grad_mean = comp_ft_per_sample_grads[args.mimic_layer].mean(axis=0)\n",
    "            specific_layer_grads = ft_per_sample_grads[args.mimic_layer]\n",
    "            \n",
    "            per_sample_weights = compute_similarity(comp_specific_layer_grad_mean, specific_layer_grads, method=args.method)\n",
    "            \n",
    "            ## if similarity is negative, we discard it ##\n",
    "            per_sample_weights = torch.clamp(per_sample_weights, min=0)\n",
    "            print(per_sample_weights)\n",
    "            non_zero_count = (per_sample_weights > 0).sum()\n",
    "            print(non_zero_count)\n",
    "            new_per_sample_weights = torch.zeros_like(per_sample_weights)\n",
    "            new_per_sample_weights[per_sample_weights > 0] = 1.0 / non_zero_count.item()\n",
    "            print(new_per_sample_weights)\n",
    "\n",
    "            ## Calibrate Gradients ##\n",
    "            gradient_calibration(model, new_per_sample_weights, ft_per_sample_grads, args.mimic_layer, args.calibrate_mimic_layer_only)\n",
    "            optimizer.step()\n",
    "        \n",
    "        ## Grad-Match algorithm (our competitor) ##\n",
    "        elif args.mode == \"grad-match\" and epoch >= args.starting_epoch:\n",
    "            \n",
    "            ## Compute per-sample grad ##\n",
    "            ft_per_sample_grads = get_per_sample_gradients(model, inputs, noisy_labels)\n",
    "\n",
    "            ## Compute mean gradient ##\n",
    "            specific_layer_grad_mean = ft_per_sample_grads[args.mimic_layer].mean(axis=0)\n",
    "            specific_layer_grads = ft_per_sample_grads[args.mimic_layer]\n",
    "\n",
    "            ## Solve subset selection problem ##\n",
    "            specific_layer_grad_mean = specific_layer_grad_mean.cpu().detach().numpy()\n",
    "            specific_layer_grads = specific_layer_grads.cpu().detach().numpy()\n",
    "            per_sample_weights = solve_subset_selection(specific_layer_grad_mean, specific_layer_grads, inputs.size(0), args.lambda_value, args.norm_way, device)\n",
    "\n",
    "            ## Record things ##\n",
    "            for loc, datapoint_index in enumerate(indices):\n",
    "                result_collection[datapoint_index.item()][\"per_sample_weights\"][epoch] = per_sample_weights[loc].item()\n",
    "\n",
    "            ## Calibrate Gradients ##\n",
    "            gradient_calibration(model, per_sample_weights, ft_per_sample_grads, args.mimic_layer, args.calibrate_mimic_layer_only)\n",
    "            optimizer.step()\n",
    "\n",
    "        ## Grad-Mimic algorithm (our method) ##\n",
    "        elif args.mode == \"grad-mimic\" and epoch >= args.starting_epoch:\n",
    "            \n",
    "            ## Compute per-sample grad ##\n",
    "            ft_per_sample_grads = get_per_sample_gradients(model, inputs, noisy_labels)\n",
    "\n",
    "            ## Compute task vector and negative gradients ##\n",
    "            specific_layer_task_vector = compute_task_vector(model, reference_model, args.mimic_layer)\n",
    "            specific_layer_neg_grads = -ft_per_sample_grads[args.mimic_layer]\n",
    "\n",
    "            if args.method != \"opt\":\n",
    "                per_sample_weights = compute_similarity(specific_layer_task_vector, specific_layer_neg_grads, args.method, args.temperature)\n",
    "            else:\n",
    "                ## Solve subset selection problem ##\n",
    "                specific_layer_task_vector = specific_layer_task_vector.cpu().detach().numpy()\n",
    "                specific_layer_neg_grads = specific_layer_neg_grads.cpu().detach().numpy()\n",
    "                per_sample_weights = solve_subset_selection(specific_layer_task_vector, specific_layer_neg_grads, inputs.size(0), args.lambda_value, args.norm_way, device)\n",
    "            \n",
    "            ## Record each sample's weight for this epoch in result_collection ##\n",
    "            for loc, datapoint_index in enumerate(indices):\n",
    "                result_collection[datapoint_index.item()][\"per_sample_weights\"][epoch] = per_sample_weights[loc].item()\n",
    "\n",
    "            ## Calibrate Gradients ##\n",
    "            gradient_calibration(model, per_sample_weights, ft_per_sample_grads, args.mimic_layer, args.calibrate_mimic_layer_only)\n",
    "            optimizer.step()\n",
    "\n",
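    "        ## Grad-Descent algorithm (our baseline); also the fallback when no branch above matches, e.g. grad-match / grad-mimic before args.starting_epoch ##\n",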
    "        else:\n",
    "            outputs = model(inputs)\n",
    "            loss = torch.nn.CrossEntropyLoss()(outputs, noisy_labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        updated_outputs = model(inputs)\n",
    "        \n",
    "        updated_noisy_loss = torch.nn.CrossEntropyLoss()(updated_outputs, noisy_labels)\n",
    "        running_noisy_loss += updated_noisy_loss.item() * true_labels.size(0)\n",
    "\n",
    "        updated_true_loss = torch.nn.CrossEntropyLoss()(updated_outputs, true_labels)\n",
    "        running_true_loss += updated_true_loss.item() * true_labels.size(0)\n",
    "        \n",
    "        _, predicted = updated_outputs.max(1)\n",
    "        \n",
    "        total += true_labels.size(0)\n",
    "        noisy_correct += predicted.eq(noisy_labels).sum().item()\n",
    "        true_correct += predicted.eq(true_labels).sum().item()\n",
    "        \n",
    "    noisy_train_loss = running_noisy_loss / total\n",
    "    noisy_train_acc = noisy_correct / total\n",
    "    \n",
    "    true_train_loss = running_true_loss / total\n",
    "    true_train_acc = true_correct / total\n",
    "\n",
    "    print(\"Evaluate the model\")\n",
    "    test_loss, test_acc = evaluate(model, test_loader, device)\n",
    "\n",
    "    print(f'Epoch {epoch + 1}/{args.num_epoch}')\n",
    "    print(f'Noisy Train Loss: {noisy_train_loss:.4f}, Noisy Train Acc: {noisy_train_acc:.4f}')\n",
    "    print(f'True Train Loss: {true_train_loss:.4f}, True Train Acc: {true_train_acc:.4f}')\n",
    "    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')\n",
    "    \n",
    "    learning_results[\"true_training_loss\"][epoch + 1] = true_train_loss\n",
    "    learning_results[\"noisy_training_loss\"][epoch + 1] = noisy_train_loss\n",
    "    learning_results[\"testing_loss\"][epoch + 1] = test_loss\n",
    "\n",
    "    learning_results[\"true_training_accuracy\"][epoch + 1] = true_train_acc\n",
    "    learning_results[\"noisy_training_accuracy\"][epoch + 1] = noisy_train_acc\n",
    "    learning_results[\"testing_accuracy\"][epoch + 1] = test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b98b5ce-8aad-4401-a9c3-c90770d01362",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "516e33f4-08de-448d-9adb-9d05e5f7c1b8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c724707-5ab1-48d9-a5bc-712e8770304b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "## save model and results ##\n",
    "if args.pretrained == True and args.linear_probing == True: sub_folder = \"pretrained_linear_probing\"\n",
    "elif args.pretrained == True and args.linear_probing == False: sub_folder = \"pretrained_fine_tune_all\"\n",
    "else: sub_folder = \"train_from_scratch\"\n",
    "\n",
    "if args.calibrate_mimic_layer_only == True: layer_only = \"only\"\n",
    "else: layer_only = \"not_only\"\n",
    "        \n",
    "if args.mode == \"grad-mimic\" or args.mode == \"grad-match\":\n",
    "\n",
    "    with open(f'./saved_logs/{sub_folder}/{args.model_arch}_{args.mimic_layer}_{layer_only}_{args.dataset_name}_{args.method}_temp{args.temperature}_run{args.num_epoch}epochs_{args.mode}_opt{args.optimizer}_startat{args.starting_epoch}_noisy{args.noisy_level}_reg{args.lambda_value}_{args.norm_way}norm_seed{args.seed}_results.pkl', 'wb') as f:\n",
    "        pickle.dump(learning_results, f)\n",
    "\n",
    "    with open(f'./saved_logs/{sub_folder}/{args.model_arch}_{args.mimic_layer}_{layer_only}_{args.dataset_name}_{args.method}_temp{args.temperature}_run{args.num_epoch}epochs_{args.mode}_opt{args.optimizer}_startat{args.starting_epoch}_noisy{args.noisy_level}_reg{args.lambda_value}_{args.norm_way}norm_seed{args.seed}_weights.pkl', 'wb') as f:\n",
    "        pickle.dump(result_collection, f)\n",
    "    \n",
    "elif args.mode == \"grad-descent\":\n",
    "\n",
    "    with open(f'./saved_logs/{sub_folder}/{args.model_arch}_{args.dataset_name}_run{args.num_epoch}epochs_{args.mode}_opt{args.optimizer}_noisy{args.noisy_level}_seed{args.seed}_results.pkl', 'wb') as f:\n",
    "        pickle.dump(learning_results, f)\n",
    "\n",
    "    if args.noisy_level == 0.0:\n",
    "        torch.save(model, f'./saved_models/{sub_folder}/{args.model_arch}_{args.dataset_name}_run{args.num_epoch}epochs_{args.mode}_opt{args.optimizer}_noisy{args.noisy_level}_seed{args.seed}.pt')\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12f35dcf-1ae4-4d08-985a-1f34ba716b99",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7198cd73-f472-4098-a10b-e92f6ecc75da",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
