{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4f270637-c76d-499b-bc84-0347531dd176",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f336e4e6-e202-40a7-bf65-07d7d6db9da3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from utils import *\n",
    "from agents import *\n",
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import copy\n",
    "import torch.nn.functional as F\n",
    "from copy import deepcopy\n",
    "import argparse\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "import numpy as np\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from models import resnet18\n",
    "import random\n",
    "import timm\n",
    "import math\n",
    "from ov_utils import *\n",
    "from proto_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5fb6b426-27e2-4070-beab-5ea3b04c1bb0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "unlearn classes: [0]\n",
      "Attack subset size: 50000\n",
      "  → Unlearn class samples: 5000\n",
      "  → Remain class samples: 45000\n"
     ]
    }
   ],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--batch_size',     type=int,       default=128)\n",
    "parser.add_argument('--dataset', type=str, default='cifar10')\n",
    "parser.add_argument('--model', type=str, default='resnet18')\n",
    "args = parser.parse_args(\"\")\n",
    "\n",
    "\n",
    "exp_path = f'checkpoints_{args.dataset}'\n",
    "device = 'cuda'\n",
    "seed_everything(42) # Choose the same seed with baseline trainer\n",
    "\n",
    "if args.dataset == 'cifar100':\n",
    "    transform_train = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))\n",
    "    ])\n",
    "    transform_test = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))\n",
    "    ])\n",
    "\n",
    "elif args.dataset == 'cifar10':\n",
    "    transform_train = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "\n",
    "    transform_test = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "\n",
    "if args.dataset == 'cifar100':\n",
    "    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)\n",
    "    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)\n",
    "\n",
    "elif args.dataset == 'cifar10':\n",
    "    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n",
    "    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n",
    "    \n",
    "train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)\n",
    "test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)\n",
    "\n",
    "if args.dataset != 'cifar10':\n",
    "    num_unlearn_classes = 10\n",
    "else:\n",
    "    num_unlearn_classes = 1\n",
    "\n",
    "if args.dataset == 'cifar100':\n",
    "    all_classes = list(range(100))\n",
    "    num_classes = 100\n",
    "elif args.dataset == 'cifar10':\n",
    "    all_classes = list(range(10))\n",
    "    num_classes = 10\n",
    "    \n",
    "unlearn_classes = [0] # Use same unlearn classes with excluded labels\n",
    "remain_classes = [cls for cls in all_classes if cls not in unlearn_classes]\n",
    "print(\"unlearn classes:\", unlearn_classes)\n",
    "args.unlearn_class = unlearn_classes\n",
    "\n",
    "attack_subset, unlearn_attack_loader, remain_attack_loader, counts = create_attack_loaders(trainset, 100, unlearn_classes, remain_classes, args.batch_size, shuffle=True, seed=42)\n",
    "\n",
    "unlearn_indices = [i for i, target in enumerate(trainset.targets) if target in unlearn_classes]\n",
    "remain_indices  = [i for i, target in enumerate(trainset.targets) if target not in unlearn_classes]\n",
    "test_unlearn_indices = [i for i, target in enumerate(testset.targets) if target in unlearn_classes]\n",
    "test_remain_indices  = [i for i, target in enumerate(testset.targets) if target not in unlearn_classes]\n",
    "\n",
    "unlearn_trainset = Subset(trainset, unlearn_indices)\n",
    "remain_trainset = Subset(trainset, remain_indices)\n",
    "unlearn_testset = Subset(testset, test_unlearn_indices)\n",
    "remain_testset = Subset(testset, test_remain_indices)\n",
    "\n",
    "unlearn_train_loader = DataLoader(unlearn_trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)\n",
    "remain_train_loader  = DataLoader(remain_trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)\n",
    "unlearn_test_loader = DataLoader(unlearn_testset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)\n",
    "remain_test_loader  = DataLoader(remain_testset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcd51eeb-5bf5-430e-bc41-a275d2d2a159",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Original Model Load\n",
    "model = resnet18(num_classes=num_classes)\n",
    "model.load_state_dict(torch.load(f'{exp_path}/resnet18_{args.dataset}_best.pth'))\n",
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f16e6c43-8c2a-4f9d-bf2d-13c1f2ae88dd",
   "metadata": {},
   "source": [
    "## Perturbed Set Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "269a997c-8e3f-43c8-9afa-0ab8ab30f3fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "adv_test_soft_loader = create_adversarial_loader_pgd_multi(\n",
    "            model,\n",
    "            unlearn_train_loader,\n",
    "            epsilon=0.03,\n",
    "            step_size=0.01,\n",
    "            num_steps=3,\n",
    "            num_adv_per_sample=1,\n",
    "            device=device\n",
    ")\n",
    "\n",
    "gaussian_test_soft_loader = create_adversarial_loader_gaussian(\n",
    "            model,\n",
    "            unlearn_train_loader,\n",
    "            noise_std=0.01,\n",
    "            num_adv_per_sample=1,\n",
    "            device='cuda'\n",
    ")\n",
    "\n",
    "adv_train_loader = create_adversarial_loader_pgd_multi(\n",
    "            model,\n",
    "            unlearn_train_loader,\n",
    "            epsilon=0.03,\n",
    "            step_size=0.01,\n",
    "            num_steps=3,\n",
    "            num_adv_per_sample=1,\n",
    "            device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e24ef98e-6055-43db-9246-5a8960ff55c4",
   "metadata": {},
   "source": [
    "## Original-Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf8d6907-7a6f-45ab-b7a2-99845fba8813",
   "metadata": {},
   "outputs": [],
   "source": [
    "ul = 'Original'\n",
    "model.eval()\n",
    "evaluate(model, unlearn_train_loader, \"Forget Acc.\", device, criterion)\n",
    "evaluate(model, remain_train_loader, \"Retain Acc.\", device, criterion)\n",
    "evaluate(model, unlearn_test_loader, \"Test-Forget Acc.\", device, criterion)\n",
    "evaluate(model, remain_test_loader, \"Test-Retain Acc.\", device, criterion)\n",
    "\n",
    "# Prototypical Relearning Attack\n",
    "eval_model = update_fc_with_prototypes(\n",
    "    model,\n",
    "    unlearn_attack_loader,\n",
    "    unlearn_classes,\n",
    "    num_samples_per_class=5,\n",
    "    metric='cosine',\n",
    "    device='cuda'\n",
    ")\n",
    "\n",
    "evaluate(eval_model, unlearn_train_loader, \"Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_train_loader, \"Retain Acc.\", device, criterion)\n",
    "\n",
    "print(compute_dispersion_loss(eval_model, unlearn_train_loader))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df5d208b-63d8-4598-8453-b953285f3d9b",
   "metadata": {},
   "source": [
    "## Retrain-Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb70e8bb-deeb-4dbe-9e52-e00fe2d554ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "ul = 'retrain'\n",
    "eval_model = resnet18(num_classes=num_classes).to(device)\n",
    "eval_model.load_state_dict(torch.load(f'{exp_path}/resnet18_{args.dataset}_{ul}_best.pth'))\n",
    "\n",
    "evaluate(eval_model, unlearn_train_loader, \"Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_train_loader, \"Retain Acc.\", device, criterion)\n",
    "evaluate(eval_model, unlearn_test_loader, \"Test-Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_test_loader, \"Test-Retain Acc.\", device, criterion)\n",
    "\n",
    "# Prototypical Relearning Attack\n",
    "eval_model = update_fc_with_prototypes(\n",
    "        eval_model,\n",
    "        unlearn_attack_loader,\n",
    "        unlearn_classes,\n",
    "        num_samples_per_class=5,\n",
    "        metric='cosine',\n",
    "        device='cuda',\n",
    "        alpha=1\n",
    ")\n",
    "\n",
    "evaluate(eval_model, unlearn_train_loader, \"Proto-Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_train_loader, \"Retain^* Acc.\", device, criterion)\n",
    "\n",
    "print(compute_dispersion_loss(eval_model, unlearn_train_loader))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c95d7310-0cd3-499b-bfd0-4b00c1324eae",
   "metadata": {},
   "source": [
    "## Spotter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9ecca1d8-459c-4fc9-99b4-5d43f42c33ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10 - Total Loss: 64.7059 - UL Loss: 30.9605 - Adv Loss: 33.1906 - Disp Loss: 33.0763\n",
      "Epoch 2/10 - Total Loss: 34.4623 - UL Loss: 10.3102 - Adv Loss: 9.8214 - Disp Loss: 24.2988\n",
      "Epoch 3/10 - Total Loss: 25.3921 - UL Loss: 10.0319 - Adv Loss: 7.4803 - Disp Loss: 16.1257\n",
      "Epoch 4/10 - Total Loss: 21.1686 - UL Loss: 9.4888 - Adv Loss: 6.1974 - Disp Loss: 12.6672\n",
      "Epoch 5/10 - Total Loss: 18.9043 - UL Loss: 9.0272 - Adv Loss: 5.5101 - Disp Loss: 10.9322\n",
      "Epoch 6/10 - Total Loss: 17.3590 - UL Loss: 8.5898 - Adv Loss: 5.0265 - Disp Loss: 9.8383\n",
      "Epoch 7/10 - Total Loss: 16.3054 - UL Loss: 8.2322 - Adv Loss: 4.7764 - Disp Loss: 9.1100\n",
      "Epoch 8/10 - Total Loss: 15.4909 - UL Loss: 7.9031 - Adv Loss: 4.5481 - Disp Loss: 8.5943\n",
      "Epoch 9/10 - Total Loss: 14.9105 - UL Loss: 7.7533 - Adv Loss: 4.4172 - Disp Loss: 8.1580\n",
      "Epoch 10/10 - Total Loss: 14.3139 - UL Loss: 7.4897 - Adv Loss: 4.2707 - Disp Loss: 7.7899\n"
     ]
    }
   ],
   "source": [
    "ul = 'Spotter'\n",
    "model.eval()\n",
    "mu_model = copy.deepcopy(model)\n",
    "\n",
    "optimizer = torch.optim.SGD(mu_model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)\n",
    "num_epochs = 10\n",
    "lambda_1 = 0.7\n",
    "lambda_2 = 1\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    mu_model.train()\n",
    "\n",
    "    for m in mu_model.modules():\n",
    "        if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):\n",
    "            m.eval()\n",
    "        \n",
    "    total_loss = 0.0\n",
    "    ul_loss = 0.0\n",
    "    adv_loss = 0.0\n",
    "    disp_loss = 0.0\n",
    "\n",
    "    for (images, labels), (adv_images, adv_labels) in zip(unlearn_train_loader, adv_train_loader):\n",
    "        images = images.to(device)\n",
    "        adv_images = adv_images.to(device)\n",
    "        labels = labels.to(device)  # 잊어야 할 클래스 레이블\n",
    "        adv_labels = adv_labels.to(device)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            teacher_logits = model(images)\n",
    "            adv_teacher_logits = model(adv_images)\n",
    "            teacher_probs = F.softmax(teacher_logits, dim=1)\n",
    "            adv_teacher_probs = F.softmax(adv_teacher_logits, dim=1)\n",
    "            teacher_target = mask_full_target_soft(teacher_probs, unlearn_classes)  # \\tilde{p}(x)\n",
    "            adv_teacher_target = mask_full_target_soft(adv_teacher_probs, unlearn_classes)\n",
    "        \n",
    "        # mu_model (student)의 출력\n",
    "        mu_logits = mu_model(images)\n",
    "        adv_mu_logits = mu_model(adv_images)\n",
    "        mu_log_probs = F.log_softmax(mu_logits, dim=1)\n",
    "        adv_mu_log_probs = F.log_softmax(adv_mu_logits, dim=1)\n",
    "\n",
    "        loss_unlearn = F.kl_div(mu_log_probs, teacher_target, reduction='batchmean')\n",
    "        adv_loss_unlearn = F.kl_div(adv_mu_log_probs, adv_teacher_target, reduction='batchmean')\n",
    "        \n",
    "        features = get_features(mu_model, images)\n",
    "        loss_disp = dispersion_loss(features, labels, metric='cosine')\n",
    "\n",
    "        loss = lambda_1 * loss_unlearn + (1 - lambda_1) * adv_loss_unlearn + lambda_2 * loss_disp\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        ul_loss += loss_unlearn.item()\n",
    "        adv_loss += adv_loss_unlearn.item()\n",
    "        disp_loss += loss_disp.item()\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{num_epochs} - Total Loss: {total_loss:.4f} - UL Loss: {ul_loss:.4f} - Adv Loss: {adv_loss:.4f} - Disp Loss: {disp_loss:.4f}\")\n",
    "    \n",
    "torch.save(mu_model.state_dict(), f\"{exp_path}/{args.model}_{args.dataset}_{ul}.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "99d608b0-7f83-40d8-84aa-721c79274b4b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Forget Acc.] Loss: 3.103, Accuracy: 0.00% (0/5000)\n",
      "[Retain Acc.] Loss: 0.009, Accuracy: 99.98% (44989/45000)\n",
      "[Test-Forget Acc.] Loss: 3.551, Accuracy: 0.00% (0/1000)\n",
      "[Test-Retain Acc.] Loss: 0.210, Accuracy: 93.82% (8444/9000)\n",
      "PGD OU@epsilon : 0.02864329786300659\n",
      "Gaussian OU@epsilon : 0.026672714617848398\n",
      "[Proto-Forget Acc.] Loss: 2.253, Accuracy: 0.10% (5/5000)\n",
      "[Retain^* Acc.] Loss: 0.013, Accuracy: 99.98% (44989/45000)\n",
      "tensor(0.1901, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "ul = 'Spotter'\n",
    "eval_model = resnet18(num_classes=num_classes).to(device)\n",
    "eval_model.load_state_dict(torch.load(f\"{exp_path}/{args.model}_{args.dataset}_{ul}.pth\"))\n",
    "evaluate(eval_model, unlearn_train_loader, \"Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_train_loader, \"Retain Acc.\", device, criterion)\n",
    "evaluate(eval_model, unlearn_test_loader, \"Test-Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_test_loader, \"Test-Retain Acc.\", device, criterion)\n",
    "models_Fu = {}\n",
    "models_Fu[ul] = eval_model\n",
    "\n",
    "overunlearning_scores = compare_unlearning_methods(model, models_Fu, adv_test_soft_loader, unlearn_classes, 'JS', device=device)\n",
    "print(\"PGD OU@epsilon :\", overunlearning_scores[ul])\n",
    "overunlearning_scores = compare_unlearning_methods(model, models_Fu, gaussian_test_soft_loader, unlearn_classes, 'JS', device=device)\n",
    "print(\"Gaussian OU@epsilon :\", overunlearning_scores[ul])\n",
    "\n",
    "# Prototypical Relearning Attack\n",
    "eval_model = update_fc_with_prototypes(\n",
    "        eval_model,\n",
    "        unlearn_attack_loader,\n",
    "        unlearn_classes,\n",
    "        num_samples_per_class=5,\n",
    "        metric='cosine',\n",
    "        device='cuda',\n",
    "        alpha=0.8\n",
    ")\n",
    "\n",
    "evaluate(eval_model, unlearn_train_loader, \"Proto-Forget Acc.\", device, criterion)\n",
    "evaluate(eval_model, remain_train_loader, \"Retain^* Acc.\", device, criterion)\n",
    "\n",
    "print(compute_dispersion_loss(eval_model, unlearn_train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81ee7239-4054-4392-a5c2-680f776cd22c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
