{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "35240287-c287-4d38-a714-009f65251791",
   "metadata": {},
   "source": [
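    "# Base Trainer\n",
    "\n",
    "Train ResNet-18 from scratch on the full CIFAR-10 training set (200 epochs of SGD with a multi-step learning-rate decay). The checkpoint with the best test accuracy and the final-epoch checkpoint are written to `checkpoints_cifar10/`."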
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de041be8-f0c0-4cc0-9f57-bedad3722871",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from utils import *\n",
    "from agents import *\n",
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import copy\n",
    "import torch.nn.functional as F\n",
    "from copy import deepcopy\n",
    "import argparse\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "import numpy as np\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from models import resnet18\n",
    "from agents.adv import FGSM\n",
    "import random\n",
    "import math\n",
    "from ov_utils import *\n",
    "\n",
    "# Seed the RNGs before anything random happens (weight init, shuffling, augmentation).\n",
    "# NOTE(review): seed_everything comes from a star import (utils / ov_utils); confirm which RNGs it seeds.\n",
    "seed_everything(42)\n",
    "\n",
    "os.makedirs('checkpoints_cifar10', exist_ok=True)\n",
    "\n",
    "# Use the GPU whenever PyTorch can see one.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "# NOTE(review): these mean/std values are the ImageNet channel statistics, not CIFAR-10's own.\n",
    "# Whatever code evaluates these checkpoints must normalize inputs the same way.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# CIFAR-10 is downloaded into ./data on first use.\n",
    "trainset = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=True, download=True, transform=transform_train)\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=128, shuffle=True, num_workers=4)\n",
    "testset = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=False, download=True, transform=transform_test)\n",
    "testloader = torch.utils.data.DataLoader(\n",
    "    testset, batch_size=100, shuffle=False, num_workers=4)\n",
    "\n",
    "model = resnet18(num_classes=10).to(device)\n",
    "\n",
    "# Recipe: 200 epochs of SGD at LR 0.1 with weight decay 5e-4; the LR is multiplied\n",
    "# by 0.2 at epochs 60, 120 and 160 (the scheduler is stepped once per epoch below).\n",
    "# NOTE(review): no momentum is passed, so SGD uses PyTorch's default momentum=0;\n",
    "# this schedule is usually paired with momentum=0.9. Confirm this is intended, and keep\n",
    "# the Exclude Trainer below on exactly the same settings.\n",
    "num_epochs = 200\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-4)\n",
    "scheduler = optim.lr_scheduler.MultiStepLR(\n",
    "    optimizer, milestones=[60, 120, 160], gamma=0.2)\n",
    "\n",
    "# Best test accuracy (percent) seen so far; test() raises it and saves a checkpoint.\n",
    "best_acc = 0.0\n",
    "\n",
    "def train(epoch):\n",
    "    \"\"\"Run one training epoch of the global `model` over the global `trainloader`.\n",
    "\n",
    "    Steps `optimizer` once per batch and prints the running mean batch loss and\n",
    "    running accuracy every 100 batches.\n",
    "\n",
    "    Args:\n",
    "        epoch: epoch index; only used in the log line.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    loss_sum = 0.0\n",
    "    n_correct = 0\n",
    "    n_seen = 0\n",
    "    for step, (images, labels) in enumerate(trainloader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        logits = model(images)\n",
    "        batch_loss = criterion(logits, labels)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Statistics use the logits computed before this batch's update.\n",
    "        loss_sum += batch_loss.item()\n",
    "        n_seen += labels.size(0)\n",
    "        n_correct += (logits.max(1)[1] == labels).sum().item()\n",
    "\n",
    "        if step % 100 == 0:\n",
    "            avg_loss = loss_sum / (step + 1)\n",
    "            acc = 100. * n_correct / n_seen\n",
    "            print(f'Epoch: {epoch} | Batch: {step}/{len(trainloader)} | Loss: {avg_loss:.3f} | Acc: {acc:.3f}% ({n_correct}/{n_seen})')\n",
    "\n",
    "def test(epoch):\n",
    "    \"\"\"Evaluate the global `model` on the global `testloader`.\n",
    "\n",
    "    Prints the mean per-batch loss and the accuracy. When the accuracy beats the\n",
    "    global `best_acc`, updates it and saves the weights to\n",
    "    checkpoints_cifar10/resnet18_cifar10_best.pth.\n",
    "\n",
    "    Args:\n",
    "        epoch: epoch index; only used in the log lines.\n",
    "\n",
    "    Returns:\n",
    "        Test accuracy in percent.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the best checkpoint is selected on the test set itself, so its\n",
    "    # accuracy is an optimistic estimate; a held-out validation split would avoid that.\n",
    "    global best_acc\n",
    "    model.eval()\n",
    "    loss_sum = 0.0\n",
    "    n_correct = 0\n",
    "    n_seen = 0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in testloader:\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device)\n",
    "            logits = model(images)\n",
    "            loss_sum += criterion(logits, labels).item()\n",
    "            n_seen += labels.size(0)\n",
    "            n_correct += (logits.max(1)[1] == labels).sum().item()\n",
    "\n",
    "    acc = 100. * n_correct / n_seen\n",
    "    print(f'Test Epoch: {epoch} | Loss: {loss_sum/len(testloader):.3f} | Acc: {acc:.3f}% ({n_correct}/{n_seen})')\n",
    "\n",
    "    # Nothing else to do unless this is a new best.\n",
    "    if acc <= best_acc:\n",
    "        return acc\n",
    "\n",
    "    print(f'New best accuracy: {acc:.3f}% (previous best: {best_acc:.3f}%), saving the model...')\n",
    "    best_acc = acc\n",
    "    torch.save(model.state_dict(), 'checkpoints_cifar10/resnet18_cifar10_best.pth')\n",
    "    return acc\n",
    "\n",
    "# The scheduler is stepped once per epoch, so its milestones are epoch numbers.\n",
    "for epoch in range(num_epochs):\n",
    "    train(epoch)\n",
    "    test(epoch)  # may save the best-so-far checkpoint as a side effect\n",
    "    scheduler.step()\n",
    "\n",
    "# Keep the last-epoch weights as well, regardless of their test accuracy.\n",
    "torch.save(model.state_dict(), 'checkpoints_cifar10/resnet18_cifar10_final.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39c5eacf-26ec-4c4f-b991-7ea4353fa1cc",
   "metadata": {},
   "source": [
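    "# Exclude Trainer\n",
    "\n",
    "Retrain ResNet-18 from scratch on the CIFAR-10 training set with `num_exclude` randomly chosen class(es) removed, using the same optimizer and learning-rate schedule as the Base Trainer. Evaluation still uses the full test set, so the reported accuracy includes the excluded class(es). Checkpoints are saved as `resnet18_cifar10_retrain_best.pth` and `resnet18_cifar10_retrain_final.pth` in `checkpoints_cifar10/`."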
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae39d800-cabc-4764-b902-60bf8b152a39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed so this cell is reproducible on its own (class choice, weight init,\n",
    "# shuffling) no matter how much RNG state the base training above consumed.\n",
    "# NOTE(review): assumes seed_everything seeds Python's `random` and torch, as its\n",
    "# name suggests; confirm in utils / ov_utils.\n",
    "seed_everything(42)\n",
    "\n",
    "num_exclude = 1\n",
    "all_classes = list(range(10))\n",
    "excluded_classes = random.sample(all_classes, num_exclude)\n",
    "print(f\"Excluded Labels: {excluded_classes}\")\n",
    "\n",
    "# Indices of the training examples whose label is not excluded (the retain set).\n",
    "# `trainset` (the full CIFAR-10 training set) is deliberately NOT rebound to the\n",
    "# Subset: doing so made a second run of this cell fail, because a Subset has no\n",
    "# `.targets`, and changed what `trainset` refers to for the rest of the notebook.\n",
    "excluded_set = set(excluded_classes)\n",
    "train_indices = [\n",
    "    idx for idx, label in enumerate(trainset.targets)\n",
    "    if label not in excluded_set\n",
    "]\n",
    "retain_trainset = torch.utils.data.Subset(trainset, train_indices)\n",
    "\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    retain_trainset,\n",
    "    batch_size=128,\n",
    "    shuffle=True,\n",
    "    num_workers=4\n",
    ")\n",
    "# Evaluation uses the full, unfiltered test set (excluded classes included).\n",
    "testloader = torch.utils.data.DataLoader(\n",
    "    testset,\n",
    "    batch_size=100,\n",
    "    shuffle=False,\n",
    "    num_workers=4\n",
    ")\n",
    "\n",
    "# Fresh model with the same optimizer / LR schedule as the Base Trainer, so the\n",
    "# retrained model differs from the base model in its training data, not its recipe.\n",
    "model = resnet18(num_classes=10)\n",
    "model = model.to(device)\n",
    "\n",
    "num_epochs = 200\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-4)\n",
    "scheduler = optim.lr_scheduler.MultiStepLR(optimizer,\n",
    "                                           milestones=[60, 120, 160],\n",
    "                                           gamma=0.2)\n",
    "\n",
    "best_acc = 0.0  # best test accuracy (percent) so far; updated by test()\n",
    "\n",
    "def train(epoch):\n",
    "    \"\"\"Run one training epoch of the global `model` over the global `trainloader`.\n",
    "\n",
    "    In this section `trainloader` iterates only the non-excluded training examples.\n",
    "    Steps `optimizer` once per batch and prints the running mean batch loss and\n",
    "    running accuracy every 100 batches.\n",
    "\n",
    "    Args:\n",
    "        epoch: epoch index; only used in the log line.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    loss_sum = 0.0\n",
    "    n_correct = 0\n",
    "    n_seen = 0\n",
    "    for step, (images, labels) in enumerate(trainloader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        logits = model(images)\n",
    "        batch_loss = criterion(logits, labels)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Statistics use the logits computed before this batch's update.\n",
    "        loss_sum += batch_loss.item()\n",
    "        n_seen += labels.size(0)\n",
    "        n_correct += (logits.max(1)[1] == labels).sum().item()\n",
    "\n",
    "        if step % 100 == 0:\n",
    "            avg_loss = loss_sum / (step + 1)\n",
    "            acc = 100. * n_correct / n_seen\n",
    "            print(f'Epoch: {epoch} | Batch: {step}/{len(trainloader)} '\n",
    "                  f'| Loss: {avg_loss:.3f} '\n",
    "                  f'| Acc: {acc:.3f}% ({n_correct}/{n_seen})')\n",
    "\n",
    "def test(epoch):\n",
    "    \"\"\"Evaluate the global `model` on the global `testloader`.\n",
    "\n",
    "    In this section `testloader` covers the full test set, so the accuracy also\n",
    "    counts the excluded class(es). Prints the mean per-batch loss and the accuracy;\n",
    "    when the accuracy beats the global `best_acc`, updates it and saves the weights\n",
    "    to checkpoints_cifar10/resnet18_cifar10_retrain_best.pth.\n",
    "\n",
    "    Args:\n",
    "        epoch: epoch index; only used in the log lines.\n",
    "\n",
    "    Returns:\n",
    "        Test accuracy in percent.\n",
    "    \"\"\"\n",
    "    global best_acc\n",
    "    model.eval()\n",
    "    loss_sum = 0.0\n",
    "    n_correct = 0\n",
    "    n_seen = 0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in testloader:\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device)\n",
    "            logits = model(images)\n",
    "            loss_sum += criterion(logits, labels).item()\n",
    "            n_seen += labels.size(0)\n",
    "            n_correct += (logits.max(1)[1] == labels).sum().item()\n",
    "\n",
    "    acc = 100. * n_correct / n_seen\n",
    "    print(f'Test Epoch: {epoch} | Loss: {loss_sum/len(testloader):.3f} '\n",
    "          f'| Acc: {acc:.3f}% ({n_correct}/{n_seen})')\n",
    "\n",
    "    # Save a checkpoint only when the best accuracy so far improves.\n",
    "    if acc <= best_acc:\n",
    "        return acc\n",
    "\n",
    "    print(f'New best accuracy: {acc:.3f}% '\n",
    "          f'(previous best: {best_acc:.3f}%), saving the model...')\n",
    "    best_acc = acc\n",
    "    torch.save(model.state_dict(), 'checkpoints_cifar10/resnet18_cifar10_retrain_best.pth')\n",
    "    return acc\n",
    "\n",
    "# Same loop as the Base Trainer: one scheduler step per epoch.\n",
    "for epoch in range(num_epochs):\n",
    "    train(epoch)\n",
    "    test(epoch)  # may save the best-so-far retrain checkpoint as a side effect\n",
    "    scheduler.step()\n",
    "\n",
    "# Keep the last-epoch retrained weights as well, regardless of their test accuracy.\n",
    "torch.save(model.state_dict(), 'checkpoints_cifar10/resnet18_cifar10_retrain_final.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "768bf032-27ac-4418-bb0c-25d8e5ccff26",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
