{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97c49271",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import tarfile\n",
    "import random\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "from torchvision.datasets.utils import download_url\n",
    "from torchvision.datasets import ImageFolder, CIFAR100\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision.transforms as tt\n",
    "from torch.utils.data import random_split\n",
    "from torchvision.utils import make_grid\n",
    "import matplotlib\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "matplotlib.rcParams['figure.facecolor'] = '#ffffff'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec272629-71cb-4bb1-8a42-e45cb0dc3199",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 1971\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.use_deterministic_algorithms(True)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "import os\n",
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c082e82",
   "metadata": {},
   "outputs": [],
   "source": [
    "project_name='cifar100-resnet9'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83620cde-98e2-48ae-974f-00f56e4fbb4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = './cifar-100-images/CIFAR100'\n",
    "print(os.listdir(data_dir))\n",
    "classes = os.listdir(data_dir + \"/TRAIN\")\n",
    "print(classes)\n",
    "print(len(classes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e674e253-9df8-4dd7-80f9-e7c5e9a2450f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data transforms (normalization & data augmentation)\n",
    "stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))\n",
    "train_tfms = tt.Compose([tt.RandomCrop(32, padding=4, padding_mode='reflect'), \n",
    "                         tt.RandomHorizontalFlip(), \n",
    "                         tt.ToTensor(), \n",
    "                         tt.Normalize(*stats,inplace=True)])\n",
    "valid_tfms = tt.Compose([tt.ToTensor(), tt.Normalize(*stats)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae6112c7-90be-4f3c-a10f-cfdfc7dee495",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch datasets\n",
    "train_ds = ImageFolder(data_dir+'/TRAIN', train_tfms)\n",
    "check_ds = ImageFolder(data_dir+'/TRAIN', train_tfms)\n",
    "valid_ds = ImageFolder(data_dir+'/TEST', valid_tfms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d160dc-e128-4025-92ec-73c7f2f2ccea",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1870683-8cd9-442c-bb34-bf657fa9d586",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch data loaders\n",
    "train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True)\n",
    "check_dl = DataLoader(check_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True)\n",
    "valid_dl = DataLoader(valid_ds, batch_size*2, num_workers=3, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0e43ca4-69b6-441d-ac9d-d1039609e7b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def denormalize(images, means, stds):\n",
    "    means = torch.tensor(means).reshape(1, 3, 1, 1)\n",
    "    stds = torch.tensor(stds).reshape(1, 3, 1, 1)\n",
    "    return images * stds + means\n",
    "\n",
    "def show_batch(dl):\n",
    "    for images, labels in dl:\n",
    "        fig, ax = plt.subplots(figsize=(12, 12))\n",
    "        ax.set_xticks([]); ax.set_yticks([])\n",
    "        denorm_images = denormalize(images, *stats)\n",
    "        ax.imshow(make_grid(denorm_images[:64], nrow=8).permute(1, 2, 0).clamp(0,1))\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ebf3cfa-18df-4c12-97b3-39ec8ee6ad67",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_default_device(gpu):\n",
    "    \"\"\"Pick GPU if available, else CPU\"\"\"\n",
    "    if torch.cuda.is_available():\n",
    "        return torch.device(f'cuda:{gpu}')\n",
    "    else:\n",
    "        return torch.device('cpu')\n",
    "    \n",
    "def to_device(data, device):\n",
    "    \"\"\"Move tensor(s) to chosen device\"\"\"\n",
    "    if isinstance(data, (list,tuple)):\n",
    "        return [to_device(x, device) for x in data]\n",
    "    return data.to(device, non_blocking=True)\n",
    "\n",
    "class DeviceDataLoader():\n",
    "    \"\"\"Wrap a dataloader to move data to a device\"\"\"\n",
    "    def __init__(self, dl, device):\n",
    "        self.dl = dl\n",
    "        self.device = device\n",
    "        \n",
    "    def __iter__(self):\n",
    "        \"\"\"Yield a batch of data after moving it to device\"\"\"\n",
    "        for b in self.dl: \n",
    "            yield to_device(b, self.device)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of batches\"\"\"\n",
    "        return len(self.dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f22522-4145-4b5e-a939-661d44c2bdb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = get_default_device(1)\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "def0472d-b320-41e4-953c-df41e51b7c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dl = DeviceDataLoader(train_dl, device)\n",
    "check_dl = DeviceDataLoader(check_dl, device)\n",
    "valid_dl = DeviceDataLoader(valid_dl, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "925d9ba5-f0a8-4c07-8c8b-cf006ca3a415",
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(outputs, labels):\n",
    "    _, preds = torch.max(outputs, dim=1)\n",
    "    return torch.tensor(torch.sum(preds == labels).item() / len(preds))\n",
    "\n",
    "class ImageClassificationBase(nn.Module):\n",
    "    def training_step(self, batch):\n",
    "        images, labels = batch \n",
    "        out = self(images)                  # Generate predictions\n",
    "        loss = F.cross_entropy(out, labels) # Calculate loss\n",
    "        return loss\n",
    "    \n",
    "    def validation_step(self, batch):\n",
    "        images, labels = batch \n",
    "        out = self(images)                    # Generate predictions\n",
    "        loss = F.cross_entropy(out, labels)   # Calculate loss\n",
    "        acc = accuracy(out, labels)           # Calculate accuracy\n",
    "        return {'val_loss': loss.detach(), 'val_acc': acc}\n",
    "        \n",
    "    def validation_epoch_end(self, outputs):\n",
    "        batch_losses = [x['val_loss'] for x in outputs]\n",
    "        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses\n",
    "        batch_accs = [x['val_acc'] for x in outputs]\n",
    "        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies\n",
    "        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}\n",
    "    \n",
    "    def epoch_end(self, epoch, result):\n",
    "        print(\"Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}\".format(\n",
    "            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_acc']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e78e2a5-c53b-4cb7-ba6b-0b097c56e6a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_block(in_channels, out_channels, pool=False):\n",
    "    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), \n",
    "              nn.BatchNorm2d(out_channels), \n",
    "              nn.ReLU(inplace=True)]\n",
    "    if pool: layers.append(nn.MaxPool2d(2))\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "class ResNet9(ImageClassificationBase):\n",
    "    def __init__(self, in_channels, num_classes):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.conv1 = conv_block(in_channels, 64)\n",
    "        self.conv2 = conv_block(64, 128, pool=True)\n",
    "        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))\n",
    "        \n",
    "        self.conv3 = conv_block(128, 256, pool=True)\n",
    "        self.conv4 = conv_block(256, 512, pool=True)\n",
    "        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))\n",
    "        \n",
    "        self.classifier = nn.Sequential(nn.MaxPool2d(4), \n",
    "                                        nn.Flatten(), \n",
    "                                        nn.Dropout(0.2),\n",
    "                                        nn.Linear(512, num_classes))\n",
    "        \n",
    "    def forward(self, xb):\n",
    "        out = self.conv1(xb)\n",
    "        out = self.conv2(out)\n",
    "        out = self.res1(out) + out\n",
    "        out = self.conv3(out)\n",
    "        out = self.conv4(out)\n",
    "        out = self.res2(out) + out\n",
    "        out = self.classifier(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29880176-4c05-401d-9e7e-ce08c0bd4444",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 1000\n",
    "max_lr = 0.01\n",
    "grad_clip = 0\n",
    "weight_decay = 0\n",
    "opt_func = torch.optim.SGD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "095af88a-fb89-4dc0-8766-04e973b57760",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_random = np.random.randint(0,10000000)\n",
    "n_random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "588b77d6-ec8e-4fee-a50e-c0cda92ae5ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = to_device(ResNet9(3, 100), device)\n",
    "#torch.save(model, './save/model_star_resnet9_{}_{}_{}_{}_{}.pt'.format(n_random, max_lr, batch_size, seed, epochs))\n",
    "model_star = torch.load('./save/model_star_resnet9_{}_{}_{}_{}_{}.pt'.format(n_random, max_lr, batch_size, seed, epochs))\n",
    "model_star = to_device(model_star, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2eab92c-009d-480c-9516-b356a4a1a102",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def evaluate(model, val_loader):\n",
    "    model.eval()\n",
    "    outputs = [model.validation_step(batch) for batch in val_loader]\n",
    "    return model.validation_epoch_end(outputs)\n",
    "\n",
    "def get_lr(optimizer):\n",
    "    for param_group in optimizer.param_groups:\n",
    "        return param_group['lr']\n",
    "\n",
    "def fit_one_cycle(epochs, max_lr, model, model_star, train_loader, check_loader, val_loader, \n",
    "                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):\n",
    "    torch.cuda.empty_cache()\n",
    "    tracking = {}\n",
    "    \n",
    "    # Set up cutom optimizer with weight decay\n",
    "    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)\n",
    "    # Set up one-cycle learning rate scheduler\n",
    "    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs, \n",
    "                                                steps_per_epoch=len(train_loader))\n",
    "\n",
    "    \n",
    "    stoch_losses = []\n",
    "    stoch_corrs = []\n",
    "    stoch_grad_norms = []\n",
    "    dists = []\n",
    "    for epoch in range(epochs):\n",
    "        train_losses = []\n",
    "        lrs = []\n",
    "        # Training Phase \n",
    "        model.train()\n",
    "        \n",
    "        for n_batch, batch in enumerate(train_loader):\n",
    "            loss = model.training_step(batch)\n",
    "            train_losses.append(loss)\n",
    "            loss.backward()\n",
    "\n",
    "            if n_batch % len(train_loader) == 0:\n",
    "                model.eval()\n",
    "                stoch_grad_norm = 0.\n",
    "                dist = 0.\n",
    "                stoch_grad_corr = 0.\n",
    "                for p, p_star in zip(model.parameters(), model_star.parameters()):\n",
    "                    p_grad = p.grad.data.view(-1)\n",
    "                    p_param = p.data.view(-1)\n",
    "                    p_param_star = p_star.data.view(-1)\n",
    "                    stoch_grad_norm += p_grad.dot(p_grad).cpu()\n",
    "                    stoch_grad_corr += p_grad.dot(p_param - p_param_star).cpu()\n",
    "                    dist += (p_param - p_param_star).dot(p_param - p_param_star).cpu()\n",
    "        \n",
    "                stoch_grad_norm = np.sqrt(stoch_grad_norm)\n",
    "                dist = np.sqrt(dist)\n",
    "        \n",
    "                model.train()\n",
    "    \n",
    "                stoch_losses += [loss.item()]\n",
    "                stoch_corrs += [stoch_grad_corr.item()] \n",
    "                stoch_grad_norms += [stoch_grad_norm.item()]\n",
    "                dists += [dist.item()]\n",
    "\n",
    "                print('Epoch[{}], stoch_grad_angle {:.5f}, stoch_grad_corr {:.5f}, stoch_loss {:.5f}'.format(\n",
    "                    epoch,\n",
    "                    stoch_corrs[-1]/(stoch_grad_norms[-1]*dists[-1]+1e-10),\n",
    "                    stoch_corrs[-1],\n",
    "                    stoch_losses[-1])\n",
    "                     )\n",
    "\n",
    "            \n",
    "            \n",
    "            # Gradient clipping\n",
    "            if grad_clip: \n",
    "                nn.utils.clip_grad_value_(model.parameters(), grad_clip)\n",
    "            \n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "            \n",
    "            # Record & update learning rate\n",
    "            lrs.append(get_lr(optimizer))\n",
    "            sched.step()\n",
    "        \n",
    "        # Validation phase\n",
    "        result = evaluate(model, val_loader)\n",
    "        result['train_loss'] = torch.stack(train_losses).mean().item()\n",
    "        result['lrs'] = lrs\n",
    "        model.epoch_end(epoch, result)\n",
    "    \n",
    "    \n",
    "    tracking['stoch_loss'] = stoch_losses\n",
    "    tracking['stoch_corr'] = stoch_corrs\n",
    "    tracking['stoch_grad_norm'] = stoch_grad_norms\n",
    "    tracking['dist'] = dists\n",
    "    return tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "147bc458-128d-4723-8c40-c710a5bc93c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "history = fit_one_cycle(epochs, max_lr, model, model_star, train_dl, check_dl, valid_dl, \n",
    "                             grad_clip=grad_clip, \n",
    "                             weight_decay=weight_decay, \n",
    "                             opt_func=opt_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1067152-a383-44a8-a32d-125962755f05",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.semilogy(history['stoch_loss'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a0eb07f-d57a-4566-b2dd-2d729fa54b03",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model, './save/model_star_resnet9_{}_{}_{}_{}_{}.pt'.format(n_random, max_lr, batch_size, seed, epochs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2895f2db-12fb-415d-a052-7e87971e6bc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(f'./result/info_{n_random}_{seed}_{epochs}_{batch_size}_{max_lr}.pickle', 'wb') as handle:\n",
    "    pickle.dump(history, handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93963d2c-a6dd-4b43-a820-c69c559b0b65",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
