{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5797832a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "import argparse\n",
    "import matplotlib\n",
    "\n",
    "from tqdm import tqdm\n",
    "import glob\n",
    "from PIL import Image\n",
    "import os\n",
    "from datetime import datetime\n",
    "import time\n",
    "import math\n",
    "from torch.autograd import Variable\n",
    "from biotorch.module.biomodule import BioModule\n",
    "\n",
    "from ANN import *\n",
    "from visualization import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dcc4b91c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2a9edadb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), \n",
    "                                            torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), \n",
    "                                            std=(3*0.2023, 3*0.1994, 3*0.2010))])\n",
    "\n",
    "cifar_dset_train = torchvision.datasets.CIFAR100('../../../data', train=True, transform=transform, target_transform=None, download=True)\n",
    "train_loader = torch.utils.data.DataLoader(cifar_dset_train, batch_size=20, shuffle=True, num_workers=0)\n",
    "\n",
    "cifar_dset_test = torchvision.datasets.CIFAR100('../../../data', train=False, transform=transform, target_transform=None, download=True)\n",
    "test_loader = torch.utils.data.DataLoader(cifar_dset_test, batch_size=20, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "eb51868d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Module has been converted to fa mode:\n",
      "\n",
      "The layer configuration was:  {'type': 'fa', 'options': {'init': 'kaiming'}}\n",
      "- All the 3 <class 'torch.nn.modules.linear.Linear'> layers were converted successfully.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "BioModule(\n",
       "  (module): MLP(\n",
       "    (linear_layers): ModuleList(\n",
       "      (0): Linear(in_features=3072, out_features=1000, bias=True)\n",
       "      (1): Linear(in_features=1000, out_features=500, bias=True)\n",
       "      (2): Linear(in_features=500, out_features=100, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = \"cuda\"\n",
    "criterion = torch.nn.MSELoss().to(device)\n",
    "# criterion = torch.nn.CrossEntropyLoss()\n",
    "activation = F.relu\n",
    "architecture = [int(32*32*3), 1000, 500, 100]\n",
    "model = BioModule(MLP(architecture, activation = activation, final_layer_activation = activation), \n",
    "                  mode = \"fa\", layer_config = {\"type\": \"fa\", \"options\": {\"init\": \"kaiming\"}}).to(device)\n",
    "# model =MLP(architecture, activation = activation, final_layer_activation = False).to(device)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "e16c84a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx in range(len(model.module.linear_layers)):\n",
    "    model.module.linear_layers[idx].scaling_factor *= 50\n",
    "    model.module.linear_layers[idx].init_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a5556d5a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy :\t 0.01068\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.01068"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluateClassification(model, train_loader, \"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "49f2c777",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:28, 87.24it/s]\n",
      "9it [00:00, 85.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 1, Train Accuracy : 0.03152, Test Accuracy : 0.0301\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:29, 85.87it/s]\n",
      "9it [00:00, 84.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 2, Train Accuracy : 0.03786, Test Accuracy : 0.0384\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:28, 87.77it/s]\n",
      "8it [00:00, 78.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 3, Train Accuracy : 0.04288, Test Accuracy : 0.0414\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:20, 123.70it/s]\n",
      "10it [00:00, 97.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 4, Train Accuracy : 0.04598, Test Accuracy : 0.0449\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:25, 97.24it/s] \n",
      "10it [00:00, 93.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 5, Train Accuracy : 0.0481, Test Accuracy : 0.0463\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:25, 97.38it/s] \n",
      "9it [00:00, 85.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 6, Train Accuracy : 0.04912, Test Accuracy : 0.0454\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:27, 92.25it/s] \n",
      "20it [00:00, 195.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 7, Train Accuracy : 0.05252, Test Accuracy : 0.0483\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:12, 197.53it/s]\n",
      "19it [00:00, 189.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 8, Train Accuracy : 0.05512, Test Accuracy : 0.0508\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:12, 194.03it/s]\n",
      "19it [00:00, 185.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 9, Train Accuracy : 0.05594, Test Accuracy : 0.0517\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:12, 194.11it/s]\n",
      "8it [00:00, 77.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 10, Train Accuracy : 0.0558, Test Accuracy : 0.05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:17, 143.32it/s]\n",
      "20it [00:00, 198.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 11, Train Accuracy : 0.05644, Test Accuracy : 0.0513\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:24, 102.48it/s]\n",
      "10it [00:00, 81.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 12, Train Accuracy : 0.0585, Test Accuracy : 0.0529\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:29, 86.09it/s]\n",
      "9it [00:00, 87.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 13, Train Accuracy : 0.05934, Test Accuracy : 0.0545\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:29, 83.66it/s]\n",
      "10it [00:00, 86.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 14, Train Accuracy : 0.05954, Test Accuracy : 0.0566\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:30, 82.67it/s]\n",
      "8it [00:00, 78.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 15, Train Accuracy : 0.06152, Test Accuracy : 0.0578\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:30, 82.62it/s]\n",
      "9it [00:00, 89.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 16, Train Accuracy : 0.0606, Test Accuracy : 0.0559\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:30, 82.28it/s]\n",
      "9it [00:00, 86.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 17, Train Accuracy : 0.06082, Test Accuracy : 0.055\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:30, 82.95it/s]\n",
      "9it [00:00, 89.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 18, Train Accuracy : 0.0609, Test Accuracy : 0.0547\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:30, 81.90it/s]\n",
      "9it [00:00, 86.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 19, Train Accuracy : 0.06208, Test Accuracy : 0.0527\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [00:30, 81.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 20, Train Accuracy : 0.0618, Test Accuracy : 0.0538\n"
     ]
    }
   ],
   "source": [
    "# # specify optimizer (stochastic gradient descent) and learning rate\n",
    "optimizer = torch.optim.Adam(model.parameters(),lr = 1e-5)\n",
    "# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.001, nesterov=False)\n",
    "\n",
    "lr_decay_step = 50\n",
    "lr_decay = 0.9\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=lr_decay)\n",
    "trn_acc_list = []\n",
    "tst_acc_list = []\n",
    "\n",
    "n_epochs = 20\n",
    "for epoch_ in range(n_epochs):\n",
    "    model.train()\n",
    "    for idx, (x, y) in tqdm(enumerate(train_loader)):\n",
    "        x, y = Variable(x.to(device)), Variable(y.to(device))\n",
    "        y_one_hot = F.one_hot(y, num_classes=100)\n",
    "        optimizer.zero_grad()\n",
    "        y_hat = model(x)\n",
    "#         loss = criterion(y_hat,y) # Use this if criterion = torch.nn.CrossEntropyLoss()\n",
    "        loss = criterion(y_hat,y_one_hot.to(torch.float32)) # Use this if criterion = torch.nn.MSELoss().to(device)\n",
    "        # backward pass: compute gradient of the loss with respect to model parameters\n",
    "        model.zero_grad()\n",
    "        loss.backward()\n",
    "        # perform a single optimization step (parameter update)\n",
    "        optimizer.step()\n",
    "    \n",
    "    scheduler.step()\n",
    "    trn_acc = evaluateClassification(model, train_loader, device, False)\n",
    "    tst_acc = evaluateClassification(model, test_loader, device, False)\n",
    "    trn_acc_list.append(trn_acc)\n",
    "    tst_acc_list.append(tst_acc)\n",
    "    \n",
    "    print(\"Epoch : {}, Train Accuracy : {}, Test Accuracy : {}\".format(epoch_+1, trn_acc, tst_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56bc7b00",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_convergence_plot(trn_acc_list, xlabel = 'Number of Epochs', ylabel = 'Accuracy %',\n",
    "                      title = 'MLP Train Accuracy w.r.t. Epochs', \n",
    "                      figsize = (12,8), fontsize = 25, linewidth = 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c1e4572",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_convergence_plot(tst_acc_list, xlabel = 'Number of Epochs', ylabel = 'Accuracy %',\n",
    "                      title = 'MLP Test Accuracy w.r.t. Epochs', \n",
    "                      figsize = (12,8), fontsize = 25, linewidth = 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1230778d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
