{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import argparse\n",
    "\n",
    "class VGG_BN_Dropout(nn.Module):\n",
    "    \"\"\"VGG-style classifier (13 conv layers) with BatchNorm after every convolution plus Dropout.\n",
    "\n",
    "    ``features`` is 13 blocks of Conv3x3(padding=1) -> BatchNorm2d -> ReLU -> tail, where the tail is\n",
    "    either Dropout or a 2x2 max-pool. The five pools shrink the spatial size by 32x, so the 512-wide\n",
    "    classifier input assumes 32x32 images (e.g. CIFAR). ``classifier`` maps those 512 features to\n",
    "    ``num_classes`` logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=10):\n",
    "        super().__init__()\n",
    "        # One (out_channels, tail) pair per conv block; tail is a Dropout probability or 'M' (max-pool).\n",
    "        block_spec = [\n",
    "            (64, 0.3), (64, 'M'),\n",
    "            (128, 0.4), (128, 'M'),\n",
    "            (256, 0.4), (256, 0.4), (256, 'M'),\n",
    "            (512, 0.4), (512, 0.4), (512, 'M'),\n",
    "            (512, 0.4), (512, 0.4), (512, 'M'),\n",
    "        ]\n",
    "        layers = []\n",
    "        in_channels = 3\n",
    "        for out_channels, tail in block_spec:\n",
    "            layers += [\n",
    "                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),\n",
    "                nn.BatchNorm2d(out_channels),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.MaxPool2d(kernel_size=2, stride=2) if tail == 'M' else nn.Dropout(tail),\n",
    "            ]\n",
    "            in_channels = out_channels\n",
    "        self.features = nn.Sequential(*layers)\n",
    "\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(512, 512),\n",
    "            nn.BatchNorm1d(512),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(512, num_classes),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # For 32x32 inputs the feature map is (N, 512, 1, 1); flatten it to (N, 512) before classifying.\n",
    "        return self.classifier(torch.flatten(self.features(x), 1))\n",
    "\n",
    "\n",
    "class AdaptiveAdagradType2(optim.Optimizer):\n",
    "    \"\"\"Adaptive-step optimizer keeping one running sum of gradient norms per parameter tensor.\n",
    "\n",
    "    For each parameter tensor ``p`` with gradient ``g``, every step performs\n",
    "\n",
    "        S <- S + torch.norm(g, p=gamma)        # one scalar norm over all elements of g\n",
    "        p <- p - (1 / (b + S)) ** (0.5 + epsilon) * g\n",
    "\n",
    "    ``S`` starts at 0 and is kept in ``self.state[p]['sum_grad']`` as a tensor shaped like ``p``\n",
    "    (every element holds the same value).\n",
    "\n",
    "    Args:\n",
    "        params: iterable of parameters or parameter-group dicts.\n",
    "        b: non-negative offset added to ``S`` before inverting.\n",
    "        gamma: order of the norm, passed to ``torch.norm`` as ``p``.\n",
    "        epsilon: added to the 0.5 exponent of the step size.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``b`` is negative.\n",
    "\n",
    "    NOTE(review): gamma is the *order* of the norm, not an exponent on ||g||. For gamma < 1\n",
    "    (e.g. 0.001) the norm of a multi-element tensor typically overflows to inf, after which\n",
    "    that tensor's step size is 0 and it stops training. Confirm this is the intended rule.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params, b=0.01, gamma=2, epsilon=1e-6):\n",
    "        if b < 0:\n",
    "            raise ValueError(f\"Invalid b: {b} (must be >= 0)\")\n",
    "        defaults = dict(b=b, gamma=gamma, epsilon=epsilon)\n",
    "        super().__init__(params, defaults)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def step(self, closure=None):\n",
    "        \"\"\"Apply one update to every parameter that has a gradient.\n",
    "\n",
    "        Args:\n",
    "            closure: optional callable that re-evaluates the model and returns the loss\n",
    "                (the standard ``torch.optim`` protocol); it runs with autograd enabled.\n",
    "\n",
    "        Returns:\n",
    "            The closure's loss, or None when no closure is given.\n",
    "        \"\"\"\n",
    "        loss = None\n",
    "        if closure is not None:\n",
    "            with torch.enable_grad():\n",
    "                loss = closure()\n",
    "\n",
    "        for group in self.param_groups:\n",
    "            b = group['b']\n",
    "            gamma = group['gamma']\n",
    "            exponent = 0.5 + group['epsilon']\n",
    "\n",
    "            for p in group['params']:\n",
    "                if p.grad is None:\n",
    "                    continue\n",
    "                grad = p.grad\n",
    "                state = self.state[p]\n",
    "                if 'sum_grad' not in state:\n",
    "                    state['sum_grad'] = torch.zeros_like(p)\n",
    "\n",
    "                state['sum_grad'] += torch.norm(grad, p=gamma)\n",
    "                denom = b + state['sum_grad']\n",
    "                # denom is 0 only if b == 0 and every gradient so far was exactly zero; the right\n",
    "                # step is then 0, whereas (1 / 0) ** x * 0 would write NaN into p.\n",
    "                lr = torch.where(denom > 0, (1 / denom).pow(exponent), torch.zeros_like(denom))\n",
    "                p.sub_(lr * grad)\n",
    "\n",
    "        return loss\n",
    "\n",
    "\n",
    "def calculate_accuracy(model, dataloader, device=None):\n",
    "    \"\"\"Return the fraction of samples in ``dataloader`` whose argmax prediction matches the target.\n",
    "\n",
    "    Switches ``model`` to eval mode (and leaves it there) and runs without autograd.\n",
    "\n",
    "    Args:\n",
    "        model: module whose output has one score per class along dim 1.\n",
    "        dataloader: iterable of (inputs, targets) batches; targets are class indices.\n",
    "        device: where each batch is moved. Defaults to the device of the model's first\n",
    "            parameter (CUDA if the model has no parameters), so CPU models work too.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``dataloader`` yields no samples.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    if device is None:\n",
    "        first_param = next(model.parameters(), None)\n",
    "        device = first_param.device if first_param is not None else torch.device('cuda')\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in dataloader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = model(inputs)\n",
    "            _, predicted = outputs.max(1)\n",
    "            total += targets.size(0)\n",
    "            correct += predicted.eq(targets).sum().item()\n",
    "    if total == 0:\n",
    "        raise ValueError('calculate_accuracy: dataloader yielded no samples')\n",
    "    return correct / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter grid of (b, epsilon, gamma) triples: b varies slowest, gamma fastest, epsilon fixed at 0.\n",
    "# (An earlier sweep used b in (0.01, 0.1) and gamma in (0.01, 0.1, 1, 2).)\n",
    "b_eps_gam_tuple_2 = tuple(\n",
    "    (b_value, 0, gamma_value)\n",
    "    for b_value in (0, 0.01, 0.1)\n",
    "    for gamma_value in (0.001, 0.01, 0.1, 0.5, 2)\n",
    ")\n",
    "\n",
    "# One entry per hyperparameter key, filled by the training loop.\n",
    "results = {}\n",
    "num_epochs = 150"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset=cifar10, batchsize=128, epochs=1, b=0.01, eps=0, gam=0.01\n",
      "dataset=cifar10, batchsize=128, epochs=1, b=0.1, eps=0, gam=0.01\n",
      "dataset=cifar10, batchsize=128, epochs=1, b=0.01, eps=0, gam=0.1\n",
      "dataset=cifar10, batchsize=128, epochs=1, b=0.1, eps=0, gam=0.1\n",
      "dataset=cifar10, batchsize=128, epochs=1, b=0.01, eps=0, gam=1\n",
      "dataset=cifar10, batchsize=128, epochs=1, b=0.1, eps=0, gam=1\n",
      "dataset=cifar10, batchsize=128, epochs=1, b=0.01, eps=0, gam=2\n",
      "dataset=cifar10, batchsize=128, epochs=1, b=0.1, eps=0, gam=2\n"
     ]
    }
   ],
   "source": [
    "# Train a fresh VGG_BN_Dropout for every (b, epsilon, gamma) in b_eps_gam_tuple_2 and store its\n",
    "# per-epoch histories in results[key].\n",
    "SEED = 0  # re-applied per configuration so every run starts from the same weights and data order\n",
    "loader_cache = {}  # (dataset, batch_size) -> (train_loader, test_loader, num_classes)\n",
    "\n",
    "\n",
    "def make_loaders(dataset_name, batch_size):\n",
    "    \"\"\"Build CIFAR-10/100 train/test DataLoaders from ./data (no download); return them with the class count.\"\"\"\n",
    "    dataset_specs = {'cifar10': (datasets.CIFAR10, 10), 'cifar100': (datasets.CIFAR100, 100)}\n",
    "    if dataset_name not in dataset_specs:\n",
    "        raise ValueError(f\"Unsupported dataset {dataset_name!r}; expected 'cifar10' or 'cifar100'\")\n",
    "    dataset_cls, num_classes = dataset_specs[dataset_name]\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "    ])\n",
    "    train_dataset = dataset_cls(root='./data', train=True, download=False, transform=transform)\n",
    "    test_dataset = dataset_cls(root='./data', train=False, download=False, transform=transform)\n",
    "    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "    return train_loader, test_loader, num_classes\n",
    "\n",
    "\n",
    "for b, eps, gamma in b_eps_gam_tuple_2:\n",
    "    key = f\"b={b}_e={eps}_g={gamma}\"\n",
    "    p = argparse.ArgumentParser(description=\"VGG_BN_Dropout + AdaptiveAdagradType2 hyperparameter sweep\")\n",
    "    p.add_argument('--dataset', type=str, default='cifar10', help='cifar10 or cifar100')\n",
    "    p.add_argument('--batch_size', type=int, default=128, help='batch size')\n",
    "    p.add_argument('--epochs', type=int, default=num_epochs, help='total epochs')\n",
    "    p.add_argument('--b', type=float, default=b, help='b hyperparameter')\n",
    "    p.add_argument('--epsilon', type=float, default=eps, help='epsilon hyperparameter')\n",
    "    p.add_argument('--gamma', type=float, default=gamma, help='gamma hyperparameter')\n",
    "    args, unknown = p.parse_known_args()\n",
    "    print(f\"dataset={args.dataset}, batchsize={args.batch_size}, epochs={args.epochs}, b={args.b}, eps={args.epsilon}, gam={args.gamma}\")\n",
    "\n",
    "    # The data does not depend on (b, eps, gamma): load it once per (dataset, batch_size).\n",
    "    loader_key = (args.dataset, args.batch_size)\n",
    "    if loader_key not in loader_cache:\n",
    "        loader_cache[loader_key] = make_loaders(args.dataset, args.batch_size)\n",
    "    train_loader, test_loader, num_classes = loader_cache[loader_key]\n",
    "\n",
    "    torch.manual_seed(SEED)\n",
    "    model = VGG_BN_Dropout(num_classes=num_classes).cuda()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    # gamma must be passed explicitly; without it every run silently used the default gamma=2.\n",
    "    optimizer = AdaptiveAdagradType2(model.parameters(), b=args.b, gamma=args.gamma, epsilon=args.epsilon)\n",
    "\n",
    "    train_loss_history = []\n",
    "    test_loss_history = []\n",
    "    train_acc_history = []\n",
    "    test_acc_history = []\n",
    "\n",
    "    for epoch in range(args.epochs):\n",
    "        model.train()\n",
    "        epoch_loss_sum = 0.0  # becomes a device tensor after the first batch; avoids a host sync per batch\n",
    "        num_batches = 0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs, targets = inputs.cuda(), targets.cuda()\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            epoch_loss_sum += loss.detach()\n",
    "            num_batches += 1\n",
    "\n",
    "        train_acc = calculate_accuracy(model, train_loader)\n",
    "        test_acc = calculate_accuracy(model, test_loader)\n",
    "\n",
    "        # Record history; train loss is the mean over the epoch's batches, not just the last batch.\n",
    "        train_loss_history.append(float(epoch_loss_sum) / max(num_batches, 1))\n",
    "        train_acc_history.append(train_acc)\n",
    "        test_acc_history.append(test_acc)\n",
    "\n",
    "        # Evaluate test loss\n",
    "        model.eval()\n",
    "        test_loss = 0.0\n",
    "        with torch.no_grad():\n",
    "            for inputs, targets in test_loader:\n",
    "                inputs, targets = inputs.cuda(), targets.cuda()\n",
    "                outputs = model(inputs)\n",
    "                test_loss += criterion(outputs, targets).item()\n",
    "        test_loss /= len(test_loader)\n",
    "        test_loss_history.append(test_loss)\n",
    "\n",
    "    results[key] = {\n",
    "        'train_loss': train_loss_history,\n",
    "        'test_loss': test_loss_history,\n",
    "        'train_acc': train_acc_history,\n",
    "        'test_acc': test_acc_history}\n",
    "\n",
    "    # Print final summary for all epochs at current hyperparameter\n",
    "    print(\n",
    "        f\"{key} final result: \"\n",
    "        f\"train_loss={train_loss_history[-1]:.4f}, \"\n",
    "        f\"test_loss={test_loss_history[-1]:.4f}, \"\n",
    "        f\"train_acc={train_acc_history[-1]:.4f}, \"\n",
    "        f\"test_acc={test_acc_history[-1]:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable saved to ./results/hyp_results_20250514_221521.pkl\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "\n",
    "# Persist the sweep results under a per-second timestamped name so earlier runs are not overwritten.\n",
    "Path('./results').mkdir(parents=True, exist_ok=True)  # open() raises FileNotFoundError if the directory is missing\n",
    "timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "filename = f\"./results/hyp_results_{timestamp}.pkl\"\n",
    "with open(filename, 'wb') as file:\n",
    "    pickle.dump(results, file)\n",
    "print(f\"Variable saved to {filename}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'b=0.01_e=0_g=0.01': {'train_loss': [1.6842645406723022], 'test_loss': [2.1352227247214017], 'train_acc': [0.19176], 'test_acc': [0.1925]}, 'b=0.1_e=0_g=0.01': {'train_loss': [1.8834943771362305], 'test_loss': [2.478992863546444], 'train_acc': [0.11404], 'test_acc': [0.1119]}, 'b=0.01_e=0_g=0.1': {'train_loss': [1.8453700542449951], 'test_loss': [1.9831971729858011], 'train_acc': [0.21004], 'test_acc': [0.2098]}, 'b=0.1_e=0_g=0.1': {'train_loss': [1.8980433940887451], 'test_loss': [2.0331920204283316], 'train_acc': [0.21782], 'test_acc': [0.2236]}, 'b=0.01_e=0_g=1': {'train_loss': [1.8733949661254883], 'test_loss': [2.4101830615273006], 'train_acc': [0.18008], 'test_acc': [0.1791]}, 'b=0.1_e=0_g=1': {'train_loss': [1.8650909662246704], 'test_loss': [2.051742170430437], 'train_acc': [0.20886], 'test_acc': [0.2105]}, 'b=0.01_e=0_g=2': {'train_loss': [1.895500898361206], 'test_loss': [2.606310518482063], 'train_acc': [0.106], 'test_acc': [0.1082]}, 'b=0.1_e=0_g=2': {'train_loss': [1.8465735912322998], 'test_loss': [1.935309342191189], 'train_acc': [0.21564], 'test_acc': [0.2219]}}\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "# Reload one specific saved sweep. pickle.load can execute arbitrary code, so only open files this notebook wrote.\n",
    "results_path = './results/hyp_results_20250514_221521.pkl'\n",
    "with open(results_path, 'rb') as results_file:\n",
    "    data = pickle.load(results_file)\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['b=0.01_e=0_g=0.01', 'b=0.1_e=0_g=0.01', 'b=0.01_e=0_g=0.1', 'b=0.1_e=0_g=0.1', 'b=0.01_e=0_g=1', 'b=0.1_e=0_g=1', 'b=0.01_e=0_g=2', 'b=0.1_e=0_g=2'])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyperparameter configurations present in the loaded results\n",
    "loaded_keys = data.keys()\n",
    "loaded_keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def toy_test_sample():\n",
    "    \"\"\"Run one AdaptiveAdagradType2 step on a 3-weight linear model; print and return the new weights.\n",
    "\n",
    "    Weights [0.5, -0.5, 1.0], input [1, 2, 3] and target 1 give the MSE gradient [3, 6, 9]\n",
    "    (L2 norm sqrt(126) ~= 11.225). With b=0.01 and the optimizer's default gamma=2 the step size is\n",
    "    (1 / 11.235) ** 0.500001 ~= 0.2983, moving the weights to ~[-0.3950, -2.2900, -1.6851].\n",
    "\n",
    "    Returns:\n",
    "        A detached copy of the updated (1, 3) weight tensor.\n",
    "    \"\"\"\n",
    "    # Create a simple linear model\n",
    "    class SimpleModel(nn.Module):\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.fc = nn.Linear(3, 1, bias=False)\n",
    "\n",
    "        def forward(self, x):\n",
    "            return self.fc(x)\n",
    "\n",
    "    model = SimpleModel()\n",
    "    with torch.no_grad():\n",
    "        model.fc.weight.copy_(torch.tensor([[0.5, -0.5, 1.0]]))\n",
    "    optimizer = AdaptiveAdagradType2(model.parameters(), b=0.01, epsilon=1e-6)\n",
    "\n",
    "    # Toy data; only the weights need gradients, not the input\n",
    "    inputs = torch.tensor([[1.0, 2.0, 3.0]])\n",
    "    target = torch.tensor([[1.0]])\n",
    "    criterion = nn.MSELoss()\n",
    "\n",
    "    loss = criterion(model(inputs), target)\n",
    "    loss.backward()\n",
    "\n",
    "    print(\"Before update:\", model.fc.weight.detach())\n",
    "    optimizer.step()\n",
    "    print(\"After update:\", model.fc.weight.detach())\n",
    "    return model.fc.weight.detach().clone()\n",
    "\n",
    "\n",
    "# toy_test_sample()\n",
    "\n",
    "# Expected output:\n",
    "# Before update: tensor([[ 0.5000, -0.5000,  1.0000]])\n",
    "# After update: tensor([[-0.3950, -2.2900, -1.6851]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
