{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4f270637-c76d-499b-bc84-0347531dd176",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f336e4e6-e202-40a7-bf65-07d7d6db9da3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from utils import *\n",
    "from agents import *\n",
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import copy\n",
    "import torch.nn.functional as F\n",
    "from copy import deepcopy\n",
    "import argparse\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "import numpy as np\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from models import resnet18\n",
    "import random\n",
    "import timm\n",
    "import math\n",
    "from ov_utils import *\n",
    "from proto_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5fb6b426-27e2-4070-beab-5ea3b04c1bb0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "unlearn classes: [0]\n",
      "Attack subset size: 50000\n",
      "  → Unlearn class samples: 5000\n",
      "  → Remain class samples: 45000\n"
     ]
    }
   ],
   "source": [
    "# Experiment configuration. Passing an explicit (empty) argument string means the defaults\n",
    "# below are always used and the kernel's own command line is never parsed.\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--batch_size', type=int, default=128)\n",
    "parser.add_argument('--dataset', type=str, default='cifar10')\n",
    "parser.add_argument('--model', type=str, default='resnet18')\n",
    "args = parser.parse_args(\"\")\n",
    "\n",
    "exp_path = f'checkpoints_{args.dataset}'\n",
    "device = 'cuda'\n",
    "seed_everything(42) # Choose the same seed with baseline trainer\n",
    "\n",
    "# Per-dataset settings, declared once so transforms, dataset class and class count cannot drift apart.\n",
    "# NOTE(review): the cifar10 entry normalizes with the mean/std commonly used for ImageNet rather than\n",
    "# CIFAR-10 statistics; presumably this matches how the checkpoints were trained -- confirm before changing.\n",
    "dataset_configs = {\n",
    "    'cifar100': {\n",
    "        'dataset_cls': torchvision.datasets.CIFAR100,\n",
    "        'mean': (0.5071, 0.4867, 0.4408),\n",
    "        'std': (0.2675, 0.2565, 0.2761),\n",
    "        'num_classes': 100,\n",
    "    },\n",
    "    'cifar10': {\n",
    "        'dataset_cls': torchvision.datasets.CIFAR10,\n",
    "        'mean': [0.485, 0.456, 0.406],\n",
    "        'std': [0.229, 0.224, 0.225],\n",
    "        'num_classes': 10,\n",
    "    },\n",
    "}\n",
    "# Fail early with a clear message instead of a NameError further down the cell.\n",
    "if args.dataset not in dataset_configs:\n",
    "    raise ValueError(f\"Unsupported dataset {args.dataset!r}; expected one of {sorted(dataset_configs)}\")\n",
    "dataset_cfg = dataset_configs[args.dataset]\n",
    "\n",
    "# Train and test use the same deterministic pipeline (no augmentation).\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=dataset_cfg['mean'], std=dataset_cfg['std']),\n",
    "])\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=dataset_cfg['mean'], std=dataset_cfg['std']),\n",
    "])\n",
    "\n",
    "trainset = dataset_cfg['dataset_cls'](root='./data', train=True, download=True, transform=transform_train)\n",
    "testset = dataset_cfg['dataset_cls'](root='./data', train=False, download=True, transform=transform_test)\n",
    "\n",
    "# Keyword arguments shared by every DataLoader built in this cell.\n",
    "loader_kwargs = dict(batch_size=args.batch_size, num_workers=4, pin_memory=True)\n",
    "\n",
    "train_loader = DataLoader(trainset, shuffle=True, **loader_kwargs)\n",
    "test_loader = DataLoader(testset, shuffle=False, **loader_kwargs)\n",
    "\n",
    "num_unlearn_classes = 1 if args.dataset == 'cifar10' else 10\n",
    "num_classes = dataset_cfg['num_classes']\n",
    "all_classes = list(range(num_classes))\n",
    "\n",
    "unlearn_classes = [0] # Use same unlearn classes with excluded labels\n",
    "remain_classes = [cls for cls in all_classes if cls not in unlearn_classes]\n",
    "print(\"unlearn classes:\", unlearn_classes)\n",
    "args.unlearn_class = unlearn_classes\n",
    "\n",
    "attack_subset, unlearn_attack_loader, remain_attack_loader, counts = create_attack_loaders(trainset, 100, unlearn_classes, remain_classes, args.batch_size, shuffle=True, seed=42)\n",
    "\n",
    "# Split train and test indices into forget (unlearn) and retain (remain) partitions by label.\n",
    "unlearn_class_set = set(unlearn_classes)\n",
    "unlearn_indices = [i for i, target in enumerate(trainset.targets) if target in unlearn_class_set]\n",
    "remain_indices = [i for i, target in enumerate(trainset.targets) if target not in unlearn_class_set]\n",
    "test_unlearn_indices = [i for i, target in enumerate(testset.targets) if target in unlearn_class_set]\n",
    "test_remain_indices = [i for i, target in enumerate(testset.targets) if target not in unlearn_class_set]\n",
    "\n",
    "unlearn_trainset = Subset(trainset, unlearn_indices)\n",
    "remain_trainset = Subset(trainset, remain_indices)\n",
    "unlearn_testset = Subset(testset, test_unlearn_indices)\n",
    "remain_testset = Subset(testset, test_remain_indices)\n",
    "\n",
    "unlearn_train_loader = DataLoader(unlearn_trainset, shuffle=True, **loader_kwargs)\n",
    "remain_train_loader = DataLoader(remain_trainset, shuffle=True, **loader_kwargs)\n",
    "unlearn_test_loader = DataLoader(unlearn_testset, shuffle=True, **loader_kwargs)\n",
    "remain_test_loader = DataLoader(remain_testset, shuffle=True, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bcd51eeb-5bf5-430e-bc41-a275d2d2a159",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/seungb/miniconda3/envs/mu/lib/python3.10/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML\n",
      "  warnings.warn(\"Can't initialize NVML\")\n"
     ]
    }
   ],
   "source": [
    "# Original Model Load\n",
    "model = resnet18(num_classes=num_classes)\n",
    "# map_location places the checkpoint tensors on `device` regardless of the device they were saved from.\n",
    "model.load_state_dict(torch.load(f'{exp_path}/resnet18_{args.dataset}_best.pth', map_location=device))\n",
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f16e6c43-8c2a-4f9d-bf2d-13c1f2ae88dd",
   "metadata": {},
   "source": [
    "## Perturbed Set Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "269a997c-8e3f-43c8-9afa-0ab8ab30f3fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the perturbed sets with the model in inference mode: a freshly constructed nn.Module is in\n",
    "# training mode, where BatchNorm layers update their running statistics on every forward pass.\n",
    "# (Harmless if the helpers below already switch the model to eval mode themselves.)\n",
    "model.eval()\n",
    "\n",
    "# PGD settings shared by the evaluation set and the training set so the two cannot drift apart.\n",
    "pgd_kwargs = dict(\n",
    "    epsilon=0.03,\n",
    "    step_size=0.01,\n",
    "    num_steps=3,\n",
    "    num_adv_per_sample=1,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# PGD-perturbed forget samples used for the OU@epsilon evaluation.\n",
    "adv_test_soft_loader = create_adversarial_loader_pgd_multi(\n",
    "    model,\n",
    "    unlearn_train_loader,\n",
    "    **pgd_kwargs\n",
    ")\n",
    "\n",
    "# Gaussian-noise forget samples used for the OU@epsilon evaluation.\n",
    "# Uses the notebook-wide `device` instead of a hard-coded 'cuda' string.\n",
    "gaussian_test_soft_loader = create_adversarial_loader_gaussian(\n",
    "    model,\n",
    "    unlearn_train_loader,\n",
    "    noise_std=0.01,\n",
    "    num_adv_per_sample=1,\n",
    "    device=device\n",
    ")\n",
    "\n",
    "# PGD-perturbed forget samples consumed by the unlearning training loop.\n",
    "adv_train_loader = create_adversarial_loader_pgd_multi(\n",
    "    model,\n",
    "    unlearn_train_loader,\n",
    "    **pgd_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e24ef98e-6055-43db-9246-5a8960ff55c4",
   "metadata": {},
   "source": [
    "## Original-Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bf8d6907-7a6f-45ab-b7a2-99845fba8813",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Forget Acc.] Loss: 0.001, Accuracy: 100.00% (5000/5000)\n",
      "[Retain Acc.] Loss: 0.001, Accuracy: 100.00% (45000/45000)\n",
      "[Test-Forget Acc.] Loss: 0.179, Accuracy: 94.70% (947/1000)\n",
      "[Test-Retain Acc.] Loss: 0.238, Accuracy: 93.98% (8458/9000)\n",
      "{'Original': 0.6213402444839478}\n",
      "{'Original': 0.5742272691726684}\n",
      "[Forget Acc.] Loss: 0.001, Accuracy: 100.00% (5000/5000)\n",
      "[Retain Acc.] Loss: 0.001, Accuracy: 100.00% (45000/45000)\n",
      "tensor(0.9177, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "ul = 'Original'\n",
    "model.eval()\n",
    "# Baseline accuracy of the original (not unlearned) model on the forget/retain splits.\n",
    "evaluate(model, unlearn_train_loader, \"Forget Acc.\", device, criterion)\n",
    "evaluate(model, remain_train_loader, \"Retain Acc.\", device, criterion)\n",
    "evaluate(model, unlearn_test_loader, \"Test-Forget Acc.\", device, criterion)\n",
    "evaluate(model, remain_test_loader, \"Test-Retain Acc.\", device, criterion)\n",
    "models_Fu = {}\n",
    "models_Fu[ul] = model\n",
    "\n",
    "# Prototypical Relearning Attack\n",
    "# The attack runs on a copy so that `model`, which later cells use as the reference/teacher,\n",
    "# cannot be altered. NOTE(review): whether update_fc_with_prototypes edits its argument in\n",
    "# place is not visible from this notebook; the copy is a safeguard either way.\n",
    "eval_model = update_fc_with_prototypes(\n",
    "    copy.deepcopy(model),\n",
    "    unlearn_attack_loader,\n",
    "    unlearn_classes,\n",
    "    num_samples_per_class=5,\n",
    "    metric='cosine',\n",
    "    device=device\n",
    ")\n",
    "\n",
    "# Post-attack results use the same labels as the other evaluation cells so they are not\n",
    "# confused with the pre-attack \"Forget Acc.\" / \"Retain Acc.\" lines printed above.\n",
    "evaluate(eval_model, unlearn_train_loader, \"Proto-Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_train_loader, \"Retain^* Acc.\", device, criterion)\n",
    "\n",
    "print(compute_dispersion_loss(eval_model, unlearn_train_loader))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df5d208b-63d8-4598-8453-b953285f3d9b",
   "metadata": {},
   "source": [
    "## Retrain-Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "eb70e8bb-deeb-4dbe-9e52-e00fe2d554ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Forget Acc.] Loss: 9.394, Accuracy: 0.00% (0/5000)\n",
      "[Retain Acc.] Loss: 0.001, Accuracy: 100.00% (45000/45000)\n",
      "[Test-Forget Acc.] Loss: 9.547, Accuracy: 0.00% (0/1000)\n",
      "[Test-Retain Acc.] Loss: 0.218, Accuracy: 94.71% (8524/9000)\n",
      "PGD OU@epsilon : 0.23585119132995605\n",
      "Gaussian OU@epsilon : 0.17794286270141602\n",
      "[Proto-Forget Acc.] Loss: 1.081, Accuracy: 56.02% (2801/5000)\n",
      "[Retain^* Acc.] Loss: 0.016, Accuracy: 99.91% (44961/45000)\n",
      "tensor(0.6909, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "ul = 'retrain'\n",
    "eval_model = resnet18(num_classes=num_classes).to(device)\n",
    "# map_location places the checkpoint tensors on `device` regardless of the device they were saved from.\n",
    "eval_model.load_state_dict(torch.load(f'{exp_path}/resnet18_{args.dataset}_{ul}_best.pth', map_location=device))\n",
    "# A freshly constructed module is in training mode; switch to eval mode before any measurement.\n",
    "eval_model.eval()\n",
    "\n",
    "evaluate(eval_model, unlearn_train_loader, \"Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_train_loader, \"Retain Acc.\", device, criterion)\n",
    "evaluate(eval_model, unlearn_test_loader, \"Test-Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_test_loader, \"Test-Retain Acc.\", device, criterion)\n",
    "\n",
    "# Compare this model against the original `model` on the PGD and Gaussian perturbed forget sets.\n",
    "models_Fu = {}\n",
    "models_Fu[ul] = eval_model\n",
    "overunlearning_scores = compare_unlearning_methods(model, models_Fu, adv_test_soft_loader, unlearn_classes, 'JS', device=device)\n",
    "print(\"PGD OU@epsilon :\", overunlearning_scores[ul])\n",
    "overunlearning_scores = compare_unlearning_methods(model, models_Fu, gaussian_test_soft_loader, unlearn_classes, 'JS', device=device)\n",
    "print(\"Gaussian OU@epsilon :\", overunlearning_scores[ul])\n",
    "\n",
    "# Prototypical Relearning Attack\n",
    "eval_model = update_fc_with_prototypes(\n",
    "        eval_model,\n",
    "        unlearn_attack_loader,\n",
    "        unlearn_classes,\n",
    "        num_samples_per_class=5,\n",
    "        metric='cosine',\n",
    "        device=device,\n",
    "        alpha=1\n",
    ")\n",
    "\n",
    "evaluate(eval_model, unlearn_train_loader, \"Proto-Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_train_loader, \"Retain^* Acc.\", device, criterion)\n",
    "\n",
    "print(compute_dispersion_loss(eval_model, unlearn_train_loader))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c95d7310-0cd3-499b-bfd0-4b00c1324eae",
   "metadata": {},
   "source": [
    "## Spotter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9ecca1d8-459c-4fc9-99b4-5d43f42c33ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10 - Total Loss: 64.7059 - UL Loss: 30.9605 - Adv Loss: 33.1906 - Disp Loss: 33.0763\n",
      "Epoch 2/10 - Total Loss: 34.4623 - UL Loss: 10.3102 - Adv Loss: 9.8214 - Disp Loss: 24.2988\n",
      "Epoch 3/10 - Total Loss: 25.3921 - UL Loss: 10.0319 - Adv Loss: 7.4803 - Disp Loss: 16.1257\n",
      "Epoch 4/10 - Total Loss: 21.1686 - UL Loss: 9.4888 - Adv Loss: 6.1974 - Disp Loss: 12.6672\n",
      "Epoch 5/10 - Total Loss: 18.9043 - UL Loss: 9.0272 - Adv Loss: 5.5101 - Disp Loss: 10.9322\n",
      "Epoch 6/10 - Total Loss: 17.3590 - UL Loss: 8.5898 - Adv Loss: 5.0265 - Disp Loss: 9.8383\n",
      "Epoch 7/10 - Total Loss: 16.3054 - UL Loss: 8.2322 - Adv Loss: 4.7764 - Disp Loss: 9.1100\n",
      "Epoch 8/10 - Total Loss: 15.4909 - UL Loss: 7.9031 - Adv Loss: 4.5481 - Disp Loss: 8.5943\n",
      "Epoch 9/10 - Total Loss: 14.9105 - UL Loss: 7.7533 - Adv Loss: 4.4172 - Disp Loss: 8.1580\n",
      "Epoch 10/10 - Total Loss: 14.3139 - UL Loss: 7.4897 - Adv Loss: 4.2707 - Disp Loss: 7.7899\n"
     ]
    }
   ],
   "source": [
    "ul = 'Spotter'\n",
    "model.eval()\n",
    "# The student starts as an exact copy of the original model; the original serves as the teacher.\n",
    "mu_model = copy.deepcopy(model)\n",
    "\n",
    "optimizer = torch.optim.SGD(mu_model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)\n",
    "num_epochs = 10\n",
    "lambda_1 = 0.7  # weight of the clean KL term; (1 - lambda_1) weights the perturbed KL term\n",
    "lambda_2 = 1  # weight of the dispersion term\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # Student in train mode, but every BatchNorm layer is switched back to eval mode.\n",
    "    mu_model.train()\n",
    "    for m in mu_model.modules():\n",
    "        if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):\n",
    "            m.eval()\n",
    "\n",
    "    # Running sums of the per-batch losses for this epoch (sums, not averages).\n",
    "    total_loss, ul_loss, adv_loss, disp_loss = 0.0, 0.0, 0.0, 0.0\n",
    "\n",
    "    # zip() stops as soon as the shorter of the two loaders is exhausted.\n",
    "    for (images, labels), (adv_images, adv_labels) in zip(unlearn_train_loader, adv_train_loader):\n",
    "        images, labels = images.to(device), labels.to(device)  # labels of the classes to forget\n",
    "        adv_images, adv_labels = adv_images.to(device), adv_labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Teacher targets: softmax of the original model's logits passed through mask_full_target_soft.\n",
    "        with torch.no_grad():\n",
    "            teacher_target = mask_full_target_soft(F.softmax(model(images), dim=1), unlearn_classes)  # \\tilde{p}(x)\n",
    "            adv_teacher_target = mask_full_target_soft(F.softmax(model(adv_images), dim=1), unlearn_classes)\n",
    "\n",
    "        # Outputs of mu_model (student) on the clean and the perturbed batch\n",
    "        mu_log_probs = F.log_softmax(mu_model(images), dim=1)\n",
    "        adv_mu_log_probs = F.log_softmax(mu_model(adv_images), dim=1)\n",
    "\n",
    "        # KL divergence between the student's log-probabilities and the teacher targets.\n",
    "        loss_unlearn = F.kl_div(mu_log_probs, teacher_target, reduction='batchmean')\n",
    "        adv_loss_unlearn = F.kl_div(adv_mu_log_probs, adv_teacher_target, reduction='batchmean')\n",
    "\n",
    "        # Dispersion term computed on the student's features of the clean batch.\n",
    "        loss_disp = dispersion_loss(get_features(mu_model, images), labels, metric='cosine')\n",
    "\n",
    "        loss = lambda_1 * loss_unlearn + (1 - lambda_1) * adv_loss_unlearn + lambda_2 * loss_disp\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        ul_loss += loss_unlearn.item()\n",
    "        adv_loss += adv_loss_unlearn.item()\n",
    "        disp_loss += loss_disp.item()\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{num_epochs} - Total Loss: {total_loss:.4f} - UL Loss: {ul_loss:.4f} - Adv Loss: {adv_loss:.4f} - Disp Loss: {disp_loss:.4f}\")\n",
    "    \n",
    "torch.save(mu_model.state_dict(), f\"{exp_path}/{args.model}_{args.dataset}_{ul}.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "99d608b0-7f83-40d8-84aa-721c79274b4b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Forget Acc.] Loss: 3.103, Accuracy: 0.00% (0/5000)\n",
      "[Retain Acc.] Loss: 0.009, Accuracy: 99.98% (44989/45000)\n",
      "[Test-Forget Acc.] Loss: 3.551, Accuracy: 0.00% (0/1000)\n",
      "[Test-Retain Acc.] Loss: 0.210, Accuracy: 93.82% (8444/9000)\n",
      "PGD OU@epsilon : 0.02864329786300659\n",
      "Gaussian OU@epsilon : 0.026672714617848398\n",
      "[Proto-Forget Acc.] Loss: 2.253, Accuracy: 0.10% (5/5000)\n",
      "[Retain^* Acc.] Loss: 0.013, Accuracy: 99.98% (44989/45000)\n",
      "tensor(0.1901, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "ul = 'Spotter'\n",
    "eval_model = resnet18(num_classes=num_classes).to(device)\n",
    "# map_location places the checkpoint tensors on `device` regardless of the device they were saved from.\n",
    "eval_model.load_state_dict(torch.load(f\"{exp_path}/{args.model}_{args.dataset}_{ul}.pth\", map_location=device))\n",
    "# A freshly constructed module is in training mode; switch to eval mode before any measurement.\n",
    "eval_model.eval()\n",
    "\n",
    "evaluate(eval_model, unlearn_train_loader, \"Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_train_loader, \"Retain Acc.\", device, criterion)\n",
    "evaluate(eval_model, unlearn_test_loader, \"Test-Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_test_loader, \"Test-Retain Acc.\", device, criterion)\n",
    "\n",
    "# Compare this model against the original `model` on the PGD and Gaussian perturbed forget sets.\n",
    "models_Fu = {}\n",
    "models_Fu[ul] = eval_model\n",
    "\n",
    "overunlearning_scores = compare_unlearning_methods(model, models_Fu, adv_test_soft_loader, unlearn_classes, 'JS', device=device)\n",
    "print(\"PGD OU@epsilon :\", overunlearning_scores[ul])\n",
    "overunlearning_scores = compare_unlearning_methods(model, models_Fu, gaussian_test_soft_loader, unlearn_classes, 'JS', device=device)\n",
    "print(\"Gaussian OU@epsilon :\", overunlearning_scores[ul])\n",
    "\n",
    "# Prototypical Relearning Attack\n",
    "eval_model = update_fc_with_prototypes(\n",
    "        eval_model,\n",
    "        unlearn_attack_loader,\n",
    "        unlearn_classes,\n",
    "        num_samples_per_class=5,\n",
    "        metric='cosine',\n",
    "        device=device,\n",
    "        alpha=0.8\n",
    ")\n",
    "\n",
    "evaluate(eval_model, unlearn_train_loader, \"Proto-Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_train_loader, \"Retain^* Acc.\", device, criterion)\n",
    "\n",
    "print(compute_dispersion_loss(eval_model, unlearn_train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81ee7239-4054-4392-a5c2-680f776cd22c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
