{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb6661cf-06e1-4f8f-8f3d-13acc2c9a5b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "import torchvision\n",
    "\n",
    "from mpl_toolkits.axes_grid1 import ImageGrid\n",
    "from PIL import Image\n",
    "from utils.training_utils import LSLogLoss, BinaryMixupLoss, MixupLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efb39914-9dce-49af-91d3-2adb23a9a7f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "if device != \"cpu\":\n",
    "    print(\"Device count: \", torch.cuda.device_count())\n",
    "    print(\"GPU being used: {}\".format(torch.cuda.get_device_name(0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "481935f9-f166-4775-8c08-0eb192dc9ec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix seeds for reproducibility.\n",
    "seed = 42\n",
    "random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad702466-9c67-412f-8b17-296a46f40b95",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(\n",
    "    model,\n",
    "    train_loader,\n",
    "    loss_fn,\n",
    "    optimizer,\n",
    "    device=\"cpu\",\n",
    "):\n",
    "    model.train()\n",
    "    mixup_train = isinstance(loss_fn, MixupLoss)\n",
    "    avg_batch_loss = 0\n",
    "    for data, target in train_loader:\n",
    "        data, target = data.to(device), target.to(device) # target.to(device).float().unsqueeze(dim=1)\n",
    "        optimizer.zero_grad()\n",
    "        if mixup_train:\n",
    "            loss = loss_fn(model, data, target)\n",
    "        else:\n",
    "            output = model(data)\n",
    "            loss = loss_fn(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        avg_batch_loss += loss.item() / len(train_loader)\n",
    "\n",
    "    return avg_batch_loss\n",
    "\n",
    "\n",
    "def test(model, test_loader, device=\"cpu\"):\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            if output.shape[1] > 1:\n",
    "                pred = output.argmax(\n",
    "                    dim=1, keepdim=True\n",
    "                )  # get the index of the max log-probability\n",
    "            else:\n",
    "                pred = output.round()\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    return 100 * (1 - (correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfb2fd5b-514f-404f-971a-a8b32412691c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ColoredMNIST(torch.utils.data.Dataset):\n",
    "\n",
    "    def __init__(self, train=True, to_tensor=True):\n",
    "        self.dataset = torchvision.datasets.MNIST(\"data\", train=train, download=True)\n",
    "        self.train = train\n",
    "        self.to_tensor = to_tensor\n",
    "        self.transform = torchvision.transforms.ToTensor()\n",
    "        #self.colors = [4 * i * np.ones(3) for i in range(1, 11)]\n",
    "        self.colors = np.random.randint(low=0, high=16, size=(10, 3))\n",
    "        # self.transform = torchvision.transforms.Compose(\n",
    "        #     [torchvision.transforms.ToTensor(), \n",
    "        #      torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
    "        # )\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img, target = self.dataset[index]\n",
    "        img = img.convert(\"RGB\")\n",
    "\n",
    "        # if target < 5:\n",
    "        #     target = 0\n",
    "        # else:\n",
    "        #     target = 1\n",
    "\n",
    "        if self.train:\n",
    "            img = np.array(img)\n",
    "            img[(img[:, :, 0] == 0), :] = self.colors[target]\n",
    "            img = Image.fromarray(img, mode=\"RGB\")\n",
    "        else:\n",
    "            img = np.array(img)\n",
    "            img[(img[:, :, 0] == 0), :] = self.colors[9 - target]\n",
    "            img = Image.fromarray(img, mode=\"RGB\")\n",
    "            \n",
    "        if self.to_tensor:\n",
    "            img = self.transform(img)\n",
    "        return img, target\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d503a5f5-2828-45ed-b722-8ab43d13c5a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_train = ColoredMNIST(train=True, to_tensor=True)\n",
    "mnist_test = ColoredMNIST(train=False, to_tensor=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f7b5fa8-cc6e-4eb7-bce0-7fdbc78b3ffb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow(img):\n",
    "    # img = img / 2 + 0.5 \n",
    "    npimg = img.numpy()\n",
    "    plt.figure(figsize=(3, 3))\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
    "    plt.axis('off')\n",
    "    plt.savefig(\"plots/mnist_grid.jpg\", bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d15aeb5d-7c98-457c-9377-fb61021f4f00",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dl = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ea1f219-62a6-4401-b721-cb2b1b9df931",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataiter = iter(train_dl)\n",
    "images, labels = next(dataiter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7ba08d2-81d7-4cdf-8e0c-43af87ef08eb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "imshow(torchvision.utils.make_grid(images))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee8198d1-275e-4da3-9e08-681d41d152e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "in_dim = mnist_train[0][0].shape[0] * mnist_train[0][0].shape[1] * mnist_train[0][0].shape[2]\n",
    "epochs = 25\n",
    "wd = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f8e3477-d343-4411-bc72-d5d06192d09b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Flatten(),\n",
    "    torch.nn.Linear(in_dim, 2048),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(2048, 10),\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9b8df90-a016-49d5-9fe2-7c97868eec19",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=wd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "496bd787-ebbe-4a72-bc41-04dd7d0ec946",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dl = torch.utils.data.DataLoader(mnist_train, batch_size=500, shuffle=True)\n",
    "test_dl = torch.utils.data.DataLoader(mnist_test, batch_size=500, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3169e11-0d2e-4b3d-ad9e-6a2c9b8fe9d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#loss_fn = torch.nn.CrossEntropyLoss()\n",
    "#loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)\n",
    "loss_fn = MixupLoss(mixup_alpha=1, criterion=torch.nn.CrossEntropyLoss())\n",
    "#loss_fn = torch.nn.BCELoss()\n",
    "#loss_fn = LSLogLoss(alpha=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37728d47-9d6e-45a1-94ed-6f20f4aaff8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(1, epochs + 1):\n",
    "    avg_batch_loss = train(model, train_dl, loss_fn, optimizer, device)\n",
    "    test_error = test(model, test_dl, device)\n",
    "    print(f\"[Epoch {epoch}] Batch Loss: {avg_batch_loss:.3f} \\t Test Error: {test_error:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
