{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b492eab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "print(\"\\n\".join(timm.list_models('vi*')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e45583e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 1: Imports and Dummy Dataset\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import SwinModel, SwinConfig\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.datasets import ImageFolder\n",
    "import os\n",
    "\n",
    "# Dummy dataset loader (replace with ML-JET loader)\n",
    "class DummyDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, size=1000):\n",
    "        self.data = torch.rand(size, 1, 32, 32)\n",
    "        self.labels_energy = torch.randint(0, 2, (size, 1)).float()\n",
    "        self.labels_alpha = torch.randint(0, 3, (size,))\n",
    "        self.labels_q0 = torch.randint(0, 4, (size,))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx], {\n",
    "            'energy_loss_output': self.labels_energy[idx],\n",
    "            'alpha_output': self.labels_alpha[idx],\n",
    "            'q0_output': self.labels_q0[idx]\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45ad7bee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "from timm.models.vision_transformer import VisionTransformer\n",
    "\n",
    "class ViTTinyMultiHeadClassifier(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # 1) Build a ViT that takes 1×32×32 inputs directly:\n",
    "        #    - img_size=32, patch_size=4 → (32/4=8)^2 = 64 patches\n",
    "        #    - embed_dim small for speed, e.g. 192\n",
    "        #    - depth=4 layers, num_heads=3 (must divide 192)\n",
    "        self.backbone = VisionTransformer(\n",
    "            img_size=32,\n",
    "            patch_size=4,\n",
    "            in_chans=1,\n",
    "            embed_dim=192,\n",
    "            depth=4,\n",
    "            num_heads=3,\n",
    "            mlp_ratio=4.0,\n",
    "            qkv_bias=True,\n",
    "            drop_rate=0.0,\n",
    "            attn_drop_rate=0.0,\n",
    "            drop_path_rate=0.1,\n",
    "            norm_layer=nn.LayerNorm,\n",
    "            num_classes=0,        # returns features\n",
    "        )\n",
    "\n",
    "        # 2) Feature dimension from ViT’s output:\n",
    "        self.feature_dim = self.backbone.embed_dim  # 192\n",
    "\n",
    "        # 3) Multi‑head classification layers:\n",
    "        self.energy_head = nn.Linear(self.feature_dim, 1)\n",
    "        self.alpha_head  = nn.Linear(self.feature_dim, 3)\n",
    "        self.q0_head     = nn.Linear(self.feature_dim, 4)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: (B,1,32,32)\n",
    "        feats = self.backbone(x)    # → (B, feature_dim)\n",
    "        return {\n",
    "            'energy_loss_output': self.energy_head(feats),\n",
    "            'alpha_output':       self.alpha_head(feats),\n",
    "            'q0_output':          self.q0_head(feats)\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "781e6784",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 3: Loss Computation\n",
    "\n",
    "def compute_loss(outputs, targets):\n",
    "    bce = nn.BCEWithLogitsLoss()\n",
    "    ce = nn.CrossEntropyLoss()\n",
    "    loss_energy = bce(outputs['energy_loss_output'], targets['energy_loss_output'])\n",
    "    loss_alpha = ce(outputs['alpha_output'], targets['alpha_output'])\n",
    "    loss_q0 = ce(outputs['q0_output'], targets['q0_output'])\n",
    "    return loss_energy + loss_alpha + loss_q0\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00ef3fbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 4: Training Loop Function\n",
    "\n",
    "def train_one_epoch(model, loader, optimizer, device):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for x, y in loader:\n",
    "        x = x.to(device)\n",
    "        y = {k: v.to(device) for k, v in y.items()}\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(x)\n",
    "        loss = compute_loss(outputs, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "    return total_loss / len(loader)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac0c1e49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 5: Main Training Script\n",
    "\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = ViTTinyMultiHeadClassifier().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "train_loader = DataLoader(DummyDataset(), batch_size=32, shuffle=True)\n",
    "\n",
    "for epoch in range(5):\n",
    "    loss = train_one_epoch(model, train_loader, optimizer, device)\n",
    "    print(f\"Epoch {epoch+1}, Loss: {loss:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b435ea7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "# point to the parent directory\n",
    "sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))\n",
    "from models.model import create_model\n",
    "class ViTParamClassifier(nn.Module):\n",
    "    def __init__(self, vit_name='vit_base_patch16_224', pretrained=True, num_alpha=3, num_q0=4):\n",
    "        super().__init__()\n",
    "        # Load a ViT without its default head\n",
    "        self.vit = create_model(vit_name, pretrained=pretrained, num_classes=0)\n",
    "        embed_dim = self.vit.num_features  # e.g. 768\n",
    "\n",
    "        # Single binary head for Energy‑Loss Module\n",
    "        self.energy_head = nn.Linear(embed_dim, 1)\n",
    "\n",
    "        # Categorical heads for alpha_s and Q0\n",
    "        self.alpha_head  = nn.Linear(embed_dim, num_alpha)\n",
    "        self.q0_head     = nn.Linear(embed_dim, num_q0)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x shape: (B, 1, 32, 32)  → expand to 3 channels & resize if needed\n",
    "        if x.shape[1] == 1:\n",
    "            x = x.repeat(1,3,1,1)\n",
    "        x = nn.functional.interpolate(x, size=(224,224), mode='bilinear', align_corners=False)\n",
    "\n",
    "        features = self.vit(x)  # → (B, embed_dim)\n",
    "        return {\n",
    "            'energy': self.energy_head(features).squeeze(1),        # (B,)\n",
    "            'alpha' : self.alpha_head(features),                    # (B,3)\n",
    "            'q0'    : self.q0_head(features)                        # (B,4)\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d49ea3e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Example losses\n",
    "bce   = nn.BCEWithLogitsLoss()\n",
    "ce    = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "for epoch in range(start_epoch, cfg.epochs):\n",
    "    model.train()\n",
    "    for imgs, (y_energy, y_alpha, y_q0) in train_loader:\n",
    "        imgs = imgs.to(device)\n",
    "        out = model(imgs)\n",
    "        loss = (\n",
    "            bce(out['energy'], y_energy.to(device).float()) +\n",
    "            ce(out['alpha'],  y_alpha.to(device)) +\n",
    "            ce(out['q0'],     y_q0.to(device))\n",
    "        )\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    # … validation, logging, early‑stop checks …\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c373b9d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "\n",
    "# List all ConvNeXt model names\n",
    "all_models = timm.list_models('vit*')\n",
    "print(\"🔍 All ViT models:\")\n",
    "for name in all_models:\n",
    "    print(\"  -\", name)\n",
    "\n",
    "# List only pretrained ConvNeXt models\n",
    "pretrained_models = timm.list_models('vit*', pretrained=True)\n",
    "print(\"\\n✅ Pretrained ConvNeXt models:\")\n",
    "for name in pretrained_models:\n",
    "    print(\"  -\", name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
