{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdb1aa70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cuda\n"
     ]
    }
   ],
   "source": [
    "import os, math, time, random\n",
    "from dataclasses import dataclass\n",
    "from typing import Tuple\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR, ConstantLR\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms, datasets, models\n",
    "from torchvision.datasets.utils import download_and_extract_archive\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(\"Device:\", device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14016bd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Opt-in switch, presumably for fetching TinyImageNet into ./data when it is missing\n",
    "# (download_and_extract_archive is imported in the first cell for that purpose).\n",
    "# NOTE(review): confirm the TinyImageNet branch of get_dataloaders actually reads this flag.\n",
    "AUTO_DOWNLOAD_TINYIN = True\n",
    "\n",
    "@dataclass\n",
    "class TrainCfg:\n",
    "    name: str\n",
    "    dataset: str\n",
    "    num_classes: int\n",
    "    input_size: int\n",
    "    batch_size: int = 128\n",
    "    epochs: int = 200\n",
    "    optimizer: str = \"SGD\"\n",
    "    lr: float = 0.01\n",
    "    momentum: float = 0.9\n",
    "    weight_decay: float = 0.0\n",
    "    cosine: bool = True\n",
    "    warmup_epochs: int = 5\n",
    "    cifar_style_resnet: bool = False\n",
    "\n",
    "# Experiment registry (key -> TrainCfg); the final cell trains them in this insertion order.\n",
    "cfgs = {\n",
    "    # MNIST (1×28×28), FC baselines\n",
    "    \"FC2-M\": TrainCfg(name=\"FC2\", dataset=\"MNIST\", num_classes=10, input_size=28,\n",
    "                      epochs=5, optimizer=\"ADAM\", lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"FC5-M\": TrainCfg(name=\"FC5\", dataset=\"MNIST\", num_classes=10, input_size=28,\n",
    "                      epochs=5, optimizer=\"ADAM\", lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"FC12-M\": TrainCfg(name=\"FC12\", dataset=\"MNIST\", num_classes=10, input_size=28,\n",
    "                       epochs=10, optimizer=\"ADAM\", lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "\n",
    "    # Fashion-MNIST (1×28×28), FC baselines\n",
    "    \"FC2-FM\": TrainCfg(name=\"FC2\", dataset=\"FashionMNIST\", num_classes=10, input_size=28,\n",
    "                       epochs=5, optimizer=\"ADAM\", lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"FC5-FM\": TrainCfg(name=\"FC5\", dataset=\"FashionMNIST\", num_classes=10, input_size=28,\n",
    "                       epochs=5, optimizer=\"ADAM\", lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"FC12-FM\": TrainCfg(name=\"FC12\", dataset=\"FashionMNIST\", num_classes=10, input_size=28,\n",
    "                        epochs=10, optimizer=\"ADAM\", lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "\n",
    "    # CIFAR-10 (AlexNet is upsampled to 224)\n",
    "    \"FC5\": TrainCfg(name=\"FC5\", dataset=\"CIFAR10\", num_classes=10, input_size=32,\n",
    "                    lr=0.01, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"FC12\": TrainCfg(name=\"FC12\", dataset=\"CIFAR10\", num_classes=10, input_size=32,\n",
    "                     lr=0.01, cosine=False, warmup_epochs=0, weight_decay=0.0),\n",
    "    \"VGG16\": TrainCfg(name=\"VGG16\", dataset=\"CIFAR10\", num_classes=10, input_size=32,\n",
    "                      lr=0.01, cosine=True, warmup_epochs=5, weight_decay=5e-4),\n",
    "    \"AlexNet\": TrainCfg(name=\"AlexNet\", dataset=\"CIFAR10\", num_classes=10, input_size=224,\n",
    "                        lr=0.01, cosine=True, warmup_epochs=5, weight_decay=5e-4),\n",
    "\n",
    "    # CIFAR-100, ResNets with the CIFAR-style stem\n",
    "    \"ResNet18_C100\": TrainCfg(name=\"ResNet18_C100\", dataset=\"CIFAR100\", num_classes=100, input_size=32,\n",
    "                              lr=0.1, cosine=True, warmup_epochs=5, weight_decay=5e-4, cifar_style_resnet=True),\n",
    "    \"ResNet50_C100\": TrainCfg(name=\"ResNet50_C100\", dataset=\"CIFAR100\", num_classes=100, input_size=32,\n",
    "                              lr=0.1, cosine=True, warmup_epochs=5, weight_decay=5e-4, cifar_style_resnet=True),\n",
    "\n",
    "    # TinyImageNet (optional; place dataset under ./data/tiny-imagenet-200)\n",
    "    \"ResNet18_TinyIN\": TrainCfg(name=\"ResNet18_TinyIN\", dataset=\"TinyImageNet\", num_classes=200, input_size=64,\n",
    "                                lr=0.01, cosine=True, warmup_epochs=5, weight_decay=5e-4),\n",
    "    \"ResNet50_TinyIN\": TrainCfg(name=\"ResNet50_TinyIN\", dataset=\"TinyImageNet\", num_classes=200, input_size=64,\n",
    "                                lr=0.01, cosine=True, warmup_epochs=5, weight_decay=5e-4),\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b70f1c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset statistics\n",
    "MNIST_MEAN,   MNIST_STD   = (0.1307,), (0.3081,)\n",
    "FASHION_MEAN, FASHION_STD = (0.2860,), (0.3530,)\n",
    "CIFAR10_MEAN, CIFAR10_STD   = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)\n",
    "CIFAR100_MEAN, CIFAR100_STD = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)\n",
    "IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406),    (0.229, 0.224, 0.225)\n",
    "\n",
    "def _cifar_transforms(size: int, mean, std, resize_to_224: bool = False):\n",
    "    \"\"\"Build the (train, test) torchvision pipelines for CIFAR-style RGB images.\n",
    "\n",
    "    At the native resolution (size == 32 and resize_to_224 False) training uses pad-4\n",
    "    random crops plus horizontal flips, and testing is un-augmented. Otherwise both\n",
    "    splits go through transforms.Resize(size) and only a horizontal flip is added at\n",
    "    train time. Note that the target side is always `size`; `resize_to_224` merely\n",
    "    forces the non-native branch. Every pipeline ends with ToTensor -> Normalize(mean, std).\n",
    "    \"\"\"\n",
    "    def _to_normalized_tensor():\n",
    "        # Fresh instances per pipeline, exactly like two independently written Compose lists.\n",
    "        return [transforms.ToTensor(), transforms.Normalize(mean, std)]\n",
    "\n",
    "    native_resolution = size == 32 and not resize_to_224\n",
    "    if native_resolution:\n",
    "        train_steps = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]\n",
    "        test_steps = []\n",
    "    else:\n",
    "        train_steps = [transforms.Resize(size), transforms.RandomHorizontalFlip()]\n",
    "        test_steps = [transforms.Resize(size)]\n",
    "\n",
    "    train_tfms = transforms.Compose(train_steps + _to_normalized_tensor())\n",
    "    test_tfms = transforms.Compose(test_steps + _to_normalized_tensor())\n",
    "    return train_tfms, test_tfms\n",
    "\n",
    "def _gray_transforms(size: int, mean, std):\n",
    "    \"\"\"Build identical (train, test) pipelines for the single-channel datasets.\n",
    "\n",
    "    Both are ToTensor -> Normalize(mean, std) with no augmentation. `size` is\n",
    "    currently unused: no resizing is performed.\n",
    "    \"\"\"\n",
    "    def _pipeline():\n",
    "        return transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])\n",
    "\n",
    "    return _pipeline(), _pipeline()\n",
    "\n",
    "def _tinyimagenet_datasets(cfg: TrainCfg, data_root: str):\n",
    "    \"\"\"Return (train_set, val_set) for TinyImageNet-200 stored in <data_root>/tiny-imagenet-200.\n",
    "\n",
    "    A missing dataset folder is downloaded and extracted when AUTO_DOWNLOAD_TINYIN is\n",
    "    True; otherwise FileNotFoundError is raised with a hint.\n",
    "\n",
    "    The archive's `val` split is not laid out in per-class folders: the images sit in\n",
    "    `val/images/` and their labels are listed in `val/val_annotations.txt` (columns:\n",
    "    file name, wnid, bounding box). Pointing ImageFolder straight at `val` therefore\n",
    "    finds a single class called \"images\" and labels every sample 0, which makes the\n",
    "    reported validation accuracy meaningless. When that layout is detected the labels\n",
    "    are rebuilt from the annotation file with the train split's class_to_idx; a `val`\n",
    "    folder that was already rearranged into per-class folders is used as-is.\n",
    "    \"\"\"\n",
    "    url = \"http://cs231n.stanford.edu/tiny-imagenet-200.zip\"\n",
    "    root = os.path.join(data_root, \"tiny-imagenet-200\")\n",
    "    if not os.path.isdir(root):\n",
    "        if not AUTO_DOWNLOAD_TINYIN:\n",
    "            raise FileNotFoundError(\n",
    "                f\"TinyImageNet not found at {root}; place it there or set AUTO_DOWNLOAD_TINYIN = True\")\n",
    "        # The zip holds a top-level tiny-imagenet-200/ folder, so extract into data_root.\n",
    "        download_and_extract_archive(url, download_root=data_root)\n",
    "    train_dir = os.path.join(root, \"train\")\n",
    "    val_dir = os.path.join(root, \"val\")\n",
    "\n",
    "    train_tfms = transforms.Compose([\n",
    "        transforms.RandomResizedCrop(cfg.input_size),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),\n",
    "    ])\n",
    "    test_tfms = transforms.Compose([\n",
    "        transforms.Resize(cfg.input_size + 8),\n",
    "        transforms.CenterCrop(cfg.input_size),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),\n",
    "    ])\n",
    "    train_set = datasets.ImageFolder(train_dir, transform=train_tfms)\n",
    "    test_set = datasets.ImageFolder(val_dir, transform=test_tfms)\n",
    "\n",
    "    ann_path = os.path.join(val_dir, \"val_annotations.txt\")\n",
    "    img_dir = os.path.join(val_dir, \"images\")\n",
    "    if os.path.isfile(ann_path) and os.path.isdir(img_dir):\n",
    "        with open(ann_path) as fh:\n",
    "            rows = [line.split() for line in fh if line.strip()]\n",
    "        samples = [(os.path.join(img_dir, row[0]), train_set.class_to_idx[row[1]]) for row in rows]\n",
    "        # Replace the bogus single \"images\" class in place. Staying an ImageFolder\n",
    "        # (rather than a notebook-defined Dataset class) keeps the dataset picklable\n",
    "        # for DataLoader worker processes on spawn-based platforms.\n",
    "        test_set.samples = samples\n",
    "        test_set.imgs = samples\n",
    "        test_set.targets = [target for _, target in samples]\n",
    "        test_set.classes = train_set.classes\n",
    "        test_set.class_to_idx = train_set.class_to_idx\n",
    "    return train_set, test_set\n",
    "\n",
    "def get_dataloaders(cfg: TrainCfg, data_root: str = \"./data\", num_workers: int = 4) -> Tuple[DataLoader, DataLoader]:\n",
    "    \"\"\"Build (train_loader, test_loader) for `cfg.dataset`.\n",
    "\n",
    "    MNIST, FashionMNIST, CIFAR10 and CIFAR100 are fetched by torchvision with\n",
    "    download=True; TinyImageNet is handled by _tinyimagenet_datasets. The train\n",
    "    loader shuffles, the test loader does not, and host memory is pinned when CUDA\n",
    "    is available.\n",
    "\n",
    "    Args:\n",
    "        cfg: experiment config (dataset, input_size and batch_size are read).\n",
    "        data_root: directory holding / receiving the datasets.\n",
    "        num_workers: DataLoader worker processes per loader (default 4).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `cfg.dataset` is not one of the supported names.\n",
    "    \"\"\"\n",
    "    if cfg.dataset == \"MNIST\":\n",
    "        train_tfms, test_tfms = _gray_transforms(28, MNIST_MEAN, MNIST_STD)\n",
    "        train_set = datasets.MNIST(root=data_root, train=True,  download=True, transform=train_tfms)\n",
    "        test_set  = datasets.MNIST(root=data_root, train=False, download=True, transform=test_tfms)\n",
    "\n",
    "    elif cfg.dataset == \"FashionMNIST\":\n",
    "        train_tfms, test_tfms = _gray_transforms(28, FASHION_MEAN, FASHION_STD)\n",
    "        train_set = datasets.FashionMNIST(root=data_root, train=True,  download=True, transform=train_tfms)\n",
    "        test_set  = datasets.FashionMNIST(root=data_root, train=False, download=True, transform=test_tfms)\n",
    "\n",
    "    elif cfg.dataset == \"CIFAR10\":\n",
    "        resize_to_224 = (cfg.input_size == 224)\n",
    "        train_tfms, test_tfms = _cifar_transforms(cfg.input_size, CIFAR10_MEAN, CIFAR10_STD, resize_to_224=resize_to_224)\n",
    "        train_set = datasets.CIFAR10(root=data_root, train=True,  download=True, transform=train_tfms)\n",
    "        test_set  = datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_tfms)\n",
    "\n",
    "    elif cfg.dataset == \"CIFAR100\":\n",
    "        train_tfms, test_tfms = _cifar_transforms(cfg.input_size, CIFAR100_MEAN, CIFAR100_STD, resize_to_224=False)\n",
    "        train_set = datasets.CIFAR100(root=data_root, train=True,  download=True, transform=train_tfms)\n",
    "        test_set  = datasets.CIFAR100(root=data_root, train=False, download=True, transform=test_tfms)\n",
    "\n",
    "    elif cfg.dataset == \"TinyImageNet\":\n",
    "        train_set, test_set = _tinyimagenet_datasets(cfg, data_root)\n",
    "\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown dataset {cfg.dataset}\")\n",
    "\n",
    "    loader_kwargs = dict(batch_size=cfg.batch_size, num_workers=num_workers,\n",
    "                         pin_memory=torch.cuda.is_available())\n",
    "    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)\n",
    "    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)\n",
    "    return train_loader, test_loader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf271163",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FCNet(nn.Module):\n",
    "    def __init__(self, in_dim: int, widths):\n",
    "        super().__init__()\n",
    "        dims = [in_dim] + list(widths)\n",
    "        layers = []\n",
    "        for i in range(len(dims) - 2):\n",
    "            layers += [nn.Linear(dims[i], dims[i + 1]), nn.ReLU(inplace=True)]\n",
    "        layers += [nn.Linear(dims[-2], dims[-1])]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.flatten(x, 1)\n",
    "        return self.net(x)\n",
    "\n",
    "def _infer_in_dim(cfg: TrainCfg) -> int:\n",
    "    \"\"\"Flattened input length for the FC models: channels * input_size**2.\"\"\"\n",
    "    grayscale = cfg.dataset in (\"MNIST\", \"FashionMNIST\")\n",
    "    channels = 1 if grayscale else 3\n",
    "    return channels * cfg.input_size ** 2\n",
    "\n",
    "def build_model(cfg: TrainCfg) -> nn.Module:\n",
    "    \"\"\"Instantiate the randomly initialised architecture named by `cfg.name` on `device`.\n",
    "\n",
    "    FC2/FC5/FC12 are FCNet MLPs over the flattened input; AlexNet and VGG16 come from\n",
    "    torchvision with their classifier heads resized to `cfg.num_classes`; any name\n",
    "    containing \"ResNet18\" or \"ResNet50\" builds that torchvision ResNet, optionally with\n",
    "    a 3x3 stride-1 stem and no max-pool (cfg.cifar_style_resnet).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unrecognised model name.\n",
    "    \"\"\"\n",
    "    fc_hidden_widths = {\n",
    "        \"FC2\": [100],\n",
    "        \"FC5\": [1000, 600, 300, 100],\n",
    "        \"FC12\": [1000, 900, 800, 750, 700, 650, 600, 500, 400, 200, 100],\n",
    "    }\n",
    "    in_dim = _infer_in_dim(cfg)\n",
    "    arch = cfg.name\n",
    "\n",
    "    if arch in fc_hidden_widths:\n",
    "        net = FCNet(in_dim, fc_hidden_widths[arch] + [cfg.num_classes])\n",
    "    elif arch == \"AlexNet\":\n",
    "        net = models.alexnet(weights=None)\n",
    "        if isinstance(net.classifier, nn.Sequential):\n",
    "            old_head = net.classifier[-1]\n",
    "            net.classifier[-1] = nn.Linear(old_head.in_features, cfg.num_classes)\n",
    "    elif arch == \"VGG16\":\n",
    "        net = models.vgg16(weights=None)\n",
    "        # Pool the 512-channel feature map to 1x1 so one Linear(512, classes) head suffices.\n",
    "        net.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        net.classifier = nn.Linear(512, cfg.num_classes)\n",
    "    else:\n",
    "        if \"ResNet18\" in arch:\n",
    "            resnet_factory = models.resnet18\n",
    "        elif \"ResNet50\" in arch:\n",
    "            resnet_factory = models.resnet50\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown model name {cfg.name}\")\n",
    "        net = resnet_factory(weights=None, num_classes=cfg.num_classes)\n",
    "        if cfg.cifar_style_resnet:\n",
    "            # 3x3 stride-1 stem and no max-pool keep spatial resolution on small inputs.\n",
    "            net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "            net.maxpool = nn.Identity()\n",
    "\n",
    "    return net.to(device)\n",
    "\n",
    "def count_params(model: nn.Module):\n",
    "    total = sum(p.numel() for p in model.parameters())\n",
    "    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    return total, trainable\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27373cf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_optimizer_and_scheduler(model: nn.Module, cfg: TrainCfg):\n",
    "    \"\"\"Create the optimizer and the per-epoch LR scheduler described by `cfg`.\n",
    "\n",
    "    SGD (with momentum) and Adam share one scheduling scheme, meant to be stepped once\n",
    "    per epoch:\n",
    "      * cfg.warmup_epochs > 0 -> LinearLR warmup for that many epochs, then the main phase;\n",
    "      * main phase            -> CosineAnnealingLR if cfg.cosine, else a constant LR.\n",
    "\n",
    "    Returns:\n",
    "        (optimizer, scheduler)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if cfg.optimizer is not \"SGD\" or \"ADAM\" (case-insensitive).\n",
    "    \"\"\"\n",
    "    opt_name = cfg.optimizer.upper()\n",
    "    if opt_name == \"SGD\":\n",
    "        opt = optim.SGD(model.parameters(), lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)\n",
    "    elif opt_name == \"ADAM\":\n",
    "        opt = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)\n",
    "    else:\n",
    "        # Fail loudly instead of leaving `opt` unbound for the scheduler code below.\n",
    "        raise ValueError(f\"Unknown optimizer {cfg.optimizer!r}; expected 'SGD' or 'ADAM'\")\n",
    "\n",
    "    warmup = max(0, cfg.warmup_epochs)\n",
    "    main_epochs = max(1, cfg.epochs - warmup)\n",
    "\n",
    "    if warmup == 0:\n",
    "        # No warmup: use the main schedule directly. Wrapping it in SequentialLR with\n",
    "        # milestone 0 would run training at the warmup start LR of ~0.\n",
    "        if cfg.cosine:\n",
    "            sch = CosineAnnealingLR(opt, T_max=main_epochs)\n",
    "        else:\n",
    "            sch = ConstantLR(opt, factor=1.0, total_iters=cfg.epochs)\n",
    "        return opt, sch\n",
    "\n",
    "    # NOTE: start_factor=1e-8 means the first warmup epoch trains at an LR of ~0.\n",
    "    warm = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=warmup)\n",
    "    if cfg.cosine:\n",
    "        main = CosineAnnealingLR(opt, T_max=main_epochs)\n",
    "    else:\n",
    "        main = ConstantLR(opt, factor=1.0, total_iters=main_epochs)\n",
    "    sch = SequentialLR(opt, schedulers=[warm, main], milestones=[warmup])\n",
    "    return opt, sch\n",
    "\n",
    "def train(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader, cfg: TrainCfg, exp_tag: str):\n",
    "    \"\"\"Train `model` for cfg.epochs epochs, evaluating on `test_loader` after every epoch.\n",
    "\n",
    "    Uses cross-entropy loss and the optimizer/scheduler from make_optimizer_and_scheduler,\n",
    "    stepping the scheduler once per epoch. Whenever validation accuracy improves, a\n",
    "    checkpoint holding the model state_dict, the cfg fields and the epoch index is\n",
    "    written to checkpoints/{exp_tag}_best.pth.\n",
    "\n",
    "    Returns:\n",
    "        float: the best validation accuracy (in percent) reached.\n",
    "    \"\"\"\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer, scheduler = make_optimizer_and_scheduler(model, cfg)\n",
    "\n",
    "    best_acc = 0.0\n",
    "    for epoch in range(cfg.epochs):\n",
    "        model.train()\n",
    "        running_loss, running_correct, seen = 0.0, 0, 0\n",
    "        # Record the LR this epoch trains with: scheduler.step() below already moves it\n",
    "        # to next epoch's value, so reading it at print time would be off by one epoch.\n",
    "        epoch_lr = optimizer.param_groups[0][\"lr\"]\n",
    "        start = time.time()\n",
    "\n",
    "        for x, y in train_loader:\n",
    "            # non_blocking lets copies from pinned host memory overlap with GPU compute.\n",
    "            x = x.to(device, non_blocking=True)\n",
    "            y = y.to(device, non_blocking=True)\n",
    "            optimizer.zero_grad(set_to_none=True)\n",
    "            logits = model(x)\n",
    "            loss = criterion(logits, y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            running_loss += loss.item() * y.size(0)\n",
    "            running_correct += (logits.argmax(1) == y).sum().item()\n",
    "            seen += y.size(0)\n",
    "\n",
    "        scheduler.step()\n",
    "        train_loss, train_acc = running_loss / seen, 100.0 * running_correct / seen\n",
    "        val_loss, val_acc = evaluate(model, test_loader, criterion)\n",
    "\n",
    "        if val_acc > best_acc:\n",
    "            best_acc = val_acc\n",
    "            os.makedirs(\"checkpoints\", exist_ok=True)\n",
    "            torch.save({\"model\": model.state_dict(), \"cfg\": cfg.__dict__, \"epoch\": epoch}, f\"checkpoints/{exp_tag}_best.pth\")\n",
    "\n",
    "        print(f\"[{epoch+1:03d}/{cfg.epochs}] \"\n",
    "              f\"train_loss={train_loss:.4f} acc={train_acc:.2f}% | \"\n",
    "              f\"val_loss={val_loss:.4f} acc={val_acc:.2f}% | \"\n",
    "              f\"best={best_acc:.2f}% | lr={epoch_lr:.2e} | \"\n",
    "              f\"time={time.time()-start:.1f}s\")\n",
    "\n",
    "    print(\"Best val acc:\", best_acc)\n",
    "    return best_acc\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate(model: nn.Module, loader: DataLoader, criterion):\n",
    "    \"\"\"Return (mean per-sample loss, accuracy in %) of `model` over all of `loader`.\n",
    "\n",
    "    Runs without autograd and switches the model to eval mode; the model is left in\n",
    "    eval mode, so callers that keep training must call model.train() again.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    loss_sum, correct, count = 0.0, 0, 0\n",
    "    for inputs, targets in loader:\n",
    "        inputs = inputs.to(device)\n",
    "        targets = targets.to(device)\n",
    "        outputs = model(inputs)\n",
    "        batch = targets.size(0)\n",
    "        # Assumes a mean-reduced criterion (nn.CrossEntropyLoss() default), so re-weight\n",
    "        # by batch size to get an exact mean over the whole dataset.\n",
    "        loss_sum += criterion(outputs, targets).item() * batch\n",
    "        correct += (outputs.argmax(1) == targets).sum().item()\n",
    "        count += batch\n",
    "    return loss_sum / count, 100.0 * correct / count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e024aeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train every experiment in `cfgs`, in insertion order.\n",
    "SEED = 0\n",
    "for key, model_config in cfgs.items():\n",
    "    # Re-seed before each run so every experiment is reproducible on its own,\n",
    "    # independent of which experiments ran before it.\n",
    "    random.seed(SEED)\n",
    "    torch.manual_seed(SEED)\n",
    "    train_loader, test_loader = get_dataloaders(model_config)\n",
    "    model = build_model(model_config)\n",
    "    total_params, trainable_params = count_params(model)\n",
    "    print(f\"=== {key}: {model_config.name} on {model_config.dataset} | \"\n",
    "          f\"params={total_params:,} (trainable={trainable_params:,}) ===\")\n",
    "    tag = f\"{model_config.name}_{model_config.dataset}\"\n",
    "    train(model, train_loader, test_loader, model_config, tag)\n",
    "    # Drop this run's model and loaders before the next experiment allocates its own.\n",
    "    del model, train_loader, test_loader\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
