{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.func import functional_call, grad\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import random_split\n",
    "from torchvision.datasets import *\n",
    "from torchvision import transforms\n",
    "import pytorch_lightning as pl\n",
    "import os\n",
    "from threading import Thread\n",
    "\n",
    "from gc_module import ContNet\n",
    "\n",
    "# Sanity check: print whether PyTorch can see a CUDA device.\n",
    "print(torch.cuda.is_available())\n",
    "# Allocate a tensor on the GPU; this raises immediately if no CUDA device is usable.\n",
    "torch.zeros(1).cuda()\n",
    "# Set PyTorch's global float32 matmul precision mode to 'high'.\n",
    "torch.set_float32_matmul_precision('high')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class MNISTClass(ContNet):\n",
    "    \"\"\"30-layer fully connected MNIST classifier trained through ``ContNet``.\n",
    "\n",
    "    Architecture: Flatten -> Linear(784, 50) -> 28 x Linear(50, 50) ->\n",
    "    Linear(50, 10), with a ReLU after every linear layer except the last.\n",
    "    The network returns raw class logits: ``F.cross_entropy`` applies\n",
    "    ``log_softmax`` internally, so it must not be fed softmax probabilities.\n",
    "\n",
    "    Args:\n",
    "        loglambda0: forwarded unchanged to ``ContNet.__init__``.\n",
    "        cont_lr: forwarded unchanged to ``ContNet.__init__``; presumably stored\n",
    "            there as ``self.cont_lr``, which ``configure_optimizers`` reads.\n",
    "        cont_reg: forwarded unchanged to ``ContNet.__init__``.\n",
    "        warmup_epochs: forwarded unchanged to ``ContNet.__init__``.\n",
    "    \"\"\"\n",
    "\n",
    "    # Architecture and data-loading settings, named once instead of repeated inline.\n",
    "    INPUT_DIM = 784\n",
    "    HIDDEN_DIM = 50\n",
    "    N_CLASSES = 10\n",
    "    N_HIDDEN_LINEAR = 28  # number of Linear(50, 50) layers between the first and last layer\n",
    "    BATCH_SIZE = 250\n",
    "    NUM_WORKERS = 16\n",
    "\n",
    "    def __init__(self, loglambda0: float, cont_lr: float, cont_reg: float, warmup_epochs: int):\n",
    "        super().__init__(loglambda0, cont_lr, cont_reg, warmup_epochs)\n",
    "        self.lossfunc = F.cross_entropy\n",
    "\n",
    "        # Download MNIST into the working directory and split it 90% / 10%.\n",
    "        self.dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())\n",
    "        size_train = int(len(self.dataset)*0.9)\n",
    "        self.data_train, self.data_val = random_split(self.dataset, [size_train, len(self.dataset) - size_train])\n",
    "\n",
    "        # A single stateless ReLU instance is shared by all hidden layers.\n",
    "        activ = nn.ReLU()\n",
    "        layers = [nn.Flatten(start_dim=1), nn.Linear(self.INPUT_DIM, self.HIDDEN_DIM), activ]\n",
    "        for _ in range(self.N_HIDDEN_LINEAR):\n",
    "            layers += [nn.Linear(self.HIDDEN_DIM, self.HIDDEN_DIM), activ]\n",
    "        # Output layer emits logits. No Softmax here: cross_entropy already\n",
    "        # normalises its input, and a second softmax would squash the gradients.\n",
    "        layers.append(nn.Linear(self.HIDDEN_DIM, self.N_CLASSES))\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "        # Xavier-uniform weights and zero biases for every linear/conv layer.\n",
    "        for layer in self.net.modules():\n",
    "            if isinstance(layer, (nn.Linear, nn.Conv2d)):\n",
    "                nn.init.xavier_uniform_(layer.weight)\n",
    "                nn.init.zeros_(layer.bias)\n",
    "\n",
    "    @staticmethod\n",
    "    def _accuracy(logits, targets):\n",
    "        \"\"\"Return the fraction of rows whose arg-max over dim 1 equals the target.\"\"\"\n",
    "        match = torch.eq(targets, torch.argmax(logits, dim=1))\n",
    "        return torch.sum(match.detach())/targets.shape[0]\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run ``x`` through the network and return the class logits.\"\"\"\n",
    "        return self.net(x)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Build one Adam optimizer with two parameter groups.\n",
    "\n",
    "        The network weights use the default learning rate of 1e-4, while\n",
    "        ``self.logcontvar`` gets its own learning rate ``self.cont_lr``.\n",
    "        \"\"\"\n",
    "        optimizer = torch.optim.Adam([{'params': self.net.parameters()},\n",
    "                                      {'params': (self.logcontvar,), 'lr': self.cont_lr}],\n",
    "                                     lr=1e-4)\n",
    "        return optimizer\n",
    "\n",
    "    def training_step(self, train_batch, batch_idx):\n",
    "        \"\"\"Perform one manually optimised training step on perturbed parameters.\n",
    "\n",
    "        NOTE(review): the explicit ``zero_grad`` / ``manual_backward`` / ``step``\n",
    "        calls assume ``ContNet`` disables Lightning's automatic optimization -\n",
    "        confirm in ``gc_module``.\n",
    "\n",
    "        Args:\n",
    "            train_batch: an ``(x, y)`` pair of inputs and integer class labels.\n",
    "            batch_idx: index of the batch within the epoch (unused).\n",
    "        \"\"\"\n",
    "        # Despite the key name this is the *squared* L2 norm of all parameters;\n",
    "        # the key is kept so existing TensorBoard runs remain comparable.\n",
    "        self.log('param_norm', sum(p.pow(2.0).sum() for p in self.parameters()))\n",
    "\n",
    "        opt = self.optimizers()\n",
    "        opt.zero_grad()\n",
    "\n",
    "        # add gaussian noise to parameters (done by ContNet.perturb_params);\n",
    "        # ref_params holds the state to restore afterwards\n",
    "        rand_samp, ref_params = self.perturb_params()\n",
    "\n",
    "        # loss and gradients are evaluated at the perturbed parameters\n",
    "        x, y = train_batch\n",
    "        ypred = self.forward(x)\n",
    "        loss = self.lossfunc(ypred, y)\n",
    "        self.manual_backward(loss)\n",
    "\n",
    "        acc = self._accuracy(ypred, y)\n",
    "\n",
    "        # compute contvar gradient\n",
    "        self.contvar_grad(rand_samp, loss)\n",
    "\n",
    "        # reload reference parameters before the optimizer update is applied\n",
    "        self.load_state_dict(ref_params)\n",
    "\n",
    "        opt.step()\n",
    "        self.log('train_loss', loss.detach())\n",
    "        self.log('train_acc', acc)\n",
    "\n",
    "    def validation_step(self, val_batch, batch_idx):\n",
    "        \"\"\"Log loss and accuracy of the unperturbed network on one validation batch.\n",
    "\n",
    "        Args:\n",
    "            val_batch: an ``(x, y)`` pair of inputs and integer class labels.\n",
    "            batch_idx: index of the batch within the epoch (unused).\n",
    "        \"\"\"\n",
    "        x, y = val_batch\n",
    "        ypred = self.forward(x)\n",
    "\n",
    "        self.log('val_loss', self.lossfunc(ypred, y))\n",
    "        self.log('val_acc', self._accuracy(ypred, y))\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        \"\"\"Return a shuffling loader over the training split.\"\"\"\n",
    "        return DataLoader(self.data_train, batch_size=self.BATCH_SIZE, shuffle=True,\n",
    "                          num_workers=self.NUM_WORKERS, persistent_workers=True)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        \"\"\"Return a non-shuffling loader over the validation split.\"\"\"\n",
    "        return DataLoader(self.data_val, batch_size=self.BATCH_SIZE,\n",
    "                          num_workers=self.NUM_WORKERS, persistent_workers=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_case(hyperparams: list) -> None:\n",
    "    \"\"\"Train one freshly constructed ``MNISTClass`` for a single hyperparameter set.\n",
    "\n",
    "    Args:\n",
    "        hyperparams: positional constructor arguments for ``MNISTClass``; their\n",
    "            string forms are also joined with underscores into the run name.\n",
    "    \"\"\"\n",
    "    model = MNISTClass(*hyperparams)\n",
    "\n",
    "    # One TensorBoard log directory per hyperparameter combination, created\n",
    "    # under the current directory.\n",
    "    epochs = 200\n",
    "    suffix = '_'.join(f'{value}' for value in hyperparams)\n",
    "    run_name = f'l30_dcont_mnist_{epochs}_' + suffix\n",
    "    tb_logger = pl.loggers.tensorboard.TensorBoardLogger('.', name=run_name)\n",
    "\n",
    "    pl.Trainer(max_epochs=epochs, accelerator='gpu', logger=tb_logger).fit(model)\n",
    "\n",
    "def sweep_hyperparams(hyperparam_list: list, n_runs: int, hyperparams: list = None) -> None:\n",
    "    \"\"\"Call ``run_case`` ``n_runs`` times for every combination of a hyperparameter grid.\n",
    "\n",
    "    Combinations are visited in Cartesian-product order: the first\n",
    "    hyperparameter varies slowest and the last one fastest.\n",
    "\n",
    "    Args:\n",
    "        hyperparam_list: one iterable of candidate values per hyperparameter,\n",
    "            in the positional order that ``run_case`` expects.\n",
    "        n_runs: number of repeated runs per combination.\n",
    "        hyperparams: optional already-fixed leading values; only the remaining\n",
    "            entries of ``hyperparam_list`` are swept. ``None`` (the default)\n",
    "            means nothing is fixed.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if ``hyperparams`` has more entries than ``hyperparam_list``.\n",
    "    \"\"\"\n",
    "    from itertools import product\n",
    "\n",
    "    # None replaces the former mutable default argument ([]).\n",
    "    prefix = [] if hyperparams is None else list(hyperparams)\n",
    "    if len(prefix) > len(hyperparam_list):\n",
    "        raise IndexError('more fixed hyperparams than entries in hyperparam_list')\n",
    "\n",
    "    for combo in product(*hyperparam_list[len(prefix):]):\n",
    "        case = prefix + list(combo)\n",
    "        for _ in range(n_runs):\n",
    "            print('-'*80)\n",
    "            print(f'Running case with hyperparams {case}')\n",
    "            print('-'*80)\n",
    "\n",
    "            run_case(case)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "Running case with hyperparams [nan, 0.01, 0.001, 19]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name | Type       | Params\n",
      "------------------------------------\n",
      "0 | net  | Sequential | 111 K \n",
      "------------------------------------\n",
      "111 K     Trainable params\n",
      "0         Non-trainable params\n",
      "111 K     Total params\n",
      "0.445     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:04<00:00, 46.10it/s, v_num=5]      "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:04<00:00, 45.85it/s, v_num=5]\n",
      "--------------------------------------------------------------------------------\n",
      "Running case with hyperparams [nan, 0.01, 0.001, 19]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name | Type       | Params\n",
      "------------------------------------\n",
      "0 | net  | Sequential | 111 K \n",
      "------------------------------------\n",
      "111 K     Trainable params\n",
      "0         Non-trainable params\n",
      "111 K     Total params\n",
      "0.445     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:04<00:00, 46.63it/s, v_num=6]       "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:04<00:00, 46.18it/s, v_num=6]\n",
      "--------------------------------------------------------------------------------\n",
      "Running case with hyperparams [nan, 0.01, 0.001, 19]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name | Type       | Params\n",
      "------------------------------------\n",
      "0 | net  | Sequential | 111 K \n",
      "------------------------------------\n",
      "111 K     Trainable params\n",
      "0         Non-trainable params\n",
      "111 K     Total params\n",
      "0.445     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:04<00:00, 47.23it/s, v_num=7]      "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:04<00:00, 46.98it/s, v_num=7]\n",
      "--------------------------------------------------------------------------------\n",
      "Running case with hyperparams [nan, 0.01, 0.001, 19]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name | Type       | Params\n",
      "------------------------------------\n",
      "0 | net  | Sequential | 111 K \n",
      "------------------------------------\n",
      "111 K     Trainable params\n",
      "0         Non-trainable params\n",
      "111 K     Total params\n",
      "0.445     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:04<00:00, 45.48it/s, v_num=8]       "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:04<00:00, 45.17it/s, v_num=8]\n",
      "--------------------------------------------------------------------------------\n",
      "Running case with hyperparams [nan, 0.01, 0.001, 19]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name | Type       | Params\n",
      "------------------------------------\n",
      "0 | net  | Sequential | 111 K \n",
      "------------------------------------\n",
      "111 K     Trainable params\n",
      "0         Non-trainable params\n",
      "111 K     Total params\n",
      "0.445     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:04<00:00, 47.24it/s, v_num=9]      "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:04<00:00, 46.98it/s, v_num=9]\n",
      "--------------------------------------------------------------------------------\n",
      "Running case with hyperparams [-13.0, 0.01, 0.001, 19]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name | Type       | Params\n",
      "------------------------------------\n",
      "0 | net  | Sequential | 111 K \n",
      "------------------------------------\n",
      "111 K     Trainable params\n",
      "0         Non-trainable params\n",
      "111 K     Total params\n",
      "0.445     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:08<00:00, 24.65it/s, v_num=5]       "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:08<00:00, 24.55it/s, v_num=5]\n",
      "--------------------------------------------------------------------------------\n",
      "Running case with hyperparams [-13.0, 0.01, 0.001, 19]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name | Type       | Params\n",
      "------------------------------------\n",
      "0 | net  | Sequential | 111 K \n",
      "------------------------------------\n",
      "111 K     Trainable params\n",
      "0         Non-trainable params\n",
      "111 K     Total params\n",
      "0.445     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:07<00:00, 28.13it/s, v_num=6]       "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:07<00:00, 27.99it/s, v_num=6]\n",
      "--------------------------------------------------------------------------------\n",
      "Running case with hyperparams [-13.0, 0.01, 0.001, 19]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name | Type       | Params\n",
      "------------------------------------\n",
      "0 | net  | Sequential | 111 K \n",
      "------------------------------------\n",
      "111 K     Trainable params\n",
      "0         Non-trainable params\n",
      "111 K     Total params\n",
      "0.445     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:07<00:00, 28.28it/s, v_num=7]       "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:07<00:00, 28.19it/s, v_num=7]\n",
      "--------------------------------------------------------------------------------\n",
      "Running case with hyperparams [-13.0, 0.01, 0.001, 19]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name | Type       | Params\n",
      "------------------------------------\n",
      "0 | net  | Sequential | 111 K \n",
      "------------------------------------\n",
      "111 K     Trainable params\n",
      "0         Non-trainable params\n",
      "111 K     Total params\n",
      "0.445     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:07<00:00, 28.20it/s, v_num=8]      "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 199: 100%|██████████| 216/216 [00:07<00:00, 28.09it/s, v_num=8]\n",
      "--------------------------------------------------------------------------------\n",
      "Running case with hyperparams [-13.0, 0.01, 0.001, 19]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name | Type       | Params\n",
      "------------------------------------\n",
      "0 | net  | Sequential | 111 K \n",
      "------------------------------------\n",
      "111 K     Trainable params\n",
      "0         Non-trainable params\n",
      "111 K     Total params\n",
      "0.445     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 24:   2%|▏         | 5/216 [00:00<00:09, 21.73it/s, v_num=9]         "
     ]
    }
   ],
   "source": [
    "def main() -> None:\n",
    "    \"\"\"Define the hyperparameter grid and sweep it with five runs per combination.\"\"\"\n",
    "    n_runs = 5\n",
    "\n",
    "    # Candidate values per hyperparameter; insertion order matches the\n",
    "    # positional argument order passed on to sweep_hyperparams.\n",
    "    # NOTE(review): the nan entry for loglambda0 presumably selects a baseline\n",
    "    # mode inside ContNet - confirm in gc_module.\n",
    "    grid = {\n",
    "        'loglambda0': [float('nan'), -13.0],\n",
    "        'cont_lr': [1e-2],\n",
    "        'cont_reg': [1e-3],\n",
    "        'warmup_epochs': [19],\n",
    "    }\n",
    "\n",
    "    sweep_hyperparams(list(grid.values()), n_runs)\n",
    "\n",
    "# Start the sweep when this cell is executed.\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gcenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
