{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "import torchvision\n",
    "from pathlib import Path\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import torch.nn as nn\n",
    "from optimizers.AdaFisher import AdaFisher\n",
    "from models.resnet_cifar import ResNet18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "H_bar_dir = Path(\"H_bar_resnet\").expanduser()\n",
    "S_dir = Path(\"S_resnet\").expanduser()\n",
    "if not H_bar_dir.exists():\n",
    "        print(f\"Info: Data dir {H_bar_dir} does not exist, building\")\n",
    "        H_bar_dir.mkdir(exist_ok=True, parents=True)\n",
    "if not S_dir.exists():\n",
    "        print(f\"Info: Data dir {S_dir} does not exist, building\")\n",
    "        S_dir.mkdir(exist_ok=True, parents=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_network(type):\n",
    "    if type == \"resnet18\":\n",
    "        net = ResNet18()\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "    return net\n",
    "\n",
    "def get_data(batch_size):\n",
    "    transform_train = transforms.Compose([\n",
    "        transforms.RandomCrop(32, padding=4),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4914, 0.4822, 0.4465),\n",
    "                                (0.2023, 0.1994, 0.2010)),\n",
    "    ])\n",
    "\n",
    "    transform_test = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4914, 0.4822, 0.4465),\n",
    "                                (0.2023, 0.1994, 0.2010)),\n",
    "    ])\n",
    "    trainset = torchvision.datasets.CIFAR10(\n",
    "        root='./data', train=True, download=True,\n",
    "        transform=transform_train)\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        trainset, batch_size=batch_size,\n",
    "        shuffle=True,\n",
    "        num_workers=4)\n",
    "\n",
    "    testset = torchvision.datasets.CIFAR10(\n",
    "        root='./data', train=False,\n",
    "        download=True, transform=transform_test)\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        testset, batch_size=100, shuffle=False,\n",
    "        num_workers=4)\n",
    "    return train_loader, test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(epoch, model, dataloader, optimizer, criterion, device):\n",
    "    model.train()\n",
    "    train_loss, correct, total = 0, 0, 0\n",
    "    steps = 0\n",
    "    for inputs, targets in tqdm(dataloader, desc=\"Training\"):\n",
    "        inputs = inputs.to(device); targets = targets.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output_train = model(inputs)\n",
    "        loss_train = criterion(output_train, targets)\n",
    "        if optimizer == \"AdaHessian\":\n",
    "            loss_train.backward(create_graph=True)\n",
    "        else:\n",
    "            loss_train.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss_train.item()\n",
    "        _, predicted = output_train.max(1)\n",
    "        total += targets.size(0)\n",
    "        correct += predicted.eq(targets).sum().item()\n",
    "        steps += 1\n",
    "\n",
    "    train_loss_epoch = train_loss / total\n",
    "    train_accuracy_epoch = correct / total\n",
    "    tqdm.write(f\"== [TRAIN] Epoch: {epoch}, Loss: {train_loss_epoch:.3f}, Accuracy: {train_accuracy_epoch:.3f} ==>\")\n",
    "    return train_loss_epoch, train_accuracy_epoch\n",
    "def valid_epoch(epoch, model, dataloader, criterion, device):\n",
    "    model.eval()\n",
    "    test_loss, correct, total = 0, 0, 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in tqdm(dataloader, desc=\"Validation\"):\n",
    "            inputs = inputs.to(device); targets = targets.to(device)\n",
    "            output_test = model(inputs)\n",
    "            loss_test = criterion(output_test, targets)\n",
    "            test_loss += loss_test.item()\n",
    "            _, predicted = output_test.max(1)\n",
    "            total += targets.size(0)\n",
    "            correct += predicted.eq(targets).sum().item()\n",
    "    \n",
    "    test_loss_epoch = test_loss / total\n",
    "    test_accuracy_epoch = correct / total\n",
    "    tqdm.write(f\"== [VALID] Epoch: {epoch}, Loss: {test_loss_epoch:.3f}, Accuracy: {test_accuracy_epoch:.3f} ==>\")\n",
    "    return test_loss_epoch, test_accuracy_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(num_epochs, model, train_dataloader, test_dataloader, criterion, optimizer, device):\n",
    "    train_loss, test_loss, train_accuracy, test_accuracy = [], [], [], []\n",
    "    for epoch in range(num_epochs):\n",
    "        train_loss_epoch, train_accuracy_epoch = train_epoch(epoch, model, train_dataloader, optimizer, criterion, device)\n",
    "        train_loss.append(train_loss_epoch); train_accuracy.append(train_accuracy_epoch)\n",
    "        test_loss_epoch, test_accuracy_epoch = valid_epoch(epoch, model, test_dataloader, criterion, device)\n",
    "        test_loss.append(test_loss_epoch); test_accuracy.append(test_accuracy_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(num_epochs):\n",
    "    if torch.cuda.is_available():\n",
    "        device = 'cuda'\n",
    "    elif torch.backends.mps.is_available():\n",
    "        device = 'mps' \n",
    "    else:\n",
    "        device = 'cpu'\n",
    "        print(\"Warning: CPU will be slow when running\")\n",
    "    model = get_network(type=\"resnet18\").to(device)\n",
    "    train_dataloader, test_dataloader = get_data(batch_size=256)\n",
    "    optimizer = AdaFisher(model, lr=0.001, gammas=[0.92, 0.008], Lambda=1e-3, weight_decay=5e-4)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    train_model(num_epochs, model, train_dataloader, test_dataloader, criterion, optimizer, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 50\n",
    "main(EPOCHS)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
