{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b76de58",
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "# Print every model name returned for the '*' filter, one name per line.\n",
    "print(*timm.list_models('*'), sep=\"\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40a78ffd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 1: Imports and Dummy Dataset\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import SwinModel, SwinConfig\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.datasets import ImageFolder\n",
    "import os\n",
    "\n",
    "# Dummy dataset loader (replace with ML-JET loader)\n",
    "class DummyDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Random stand-in dataset: single-channel images with three label sets.\n",
    "\n",
    "    Each item is ``(image, targets)``. ``image`` is a float tensor of shape\n",
    "    ``(1, image_size, image_size)`` drawn uniformly from [0, 1). ``targets``\n",
    "    is a dict with:\n",
    "\n",
    "    * ``'energy_loss_output'`` - float tensor of shape ``(1,)`` holding 0.0 or 1.0\n",
    "    * ``'alpha_output'``       - integer class index in ``[0, 3)``\n",
    "    * ``'q0_output'``          - integer class index in ``[0, 4)``\n",
    "\n",
    "    Args:\n",
    "        size: number of samples to generate.\n",
    "        image_size: side length of the square images (default 32).\n",
    "        seed: optional integer seed. When given, every tensor is drawn from a\n",
    "            private ``torch.Generator`` so the data is reproducible and the\n",
    "            global RNG state is not consumed. ``None`` (default) draws from\n",
    "            the global RNG, as before.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, size=1000, image_size=32, seed=None):\n",
    "        # generator=None makes torch fall back to the global RNG.\n",
    "        generator = None\n",
    "        if seed is not None:\n",
    "            generator = torch.Generator().manual_seed(seed)\n",
    "\n",
    "        self.data = torch.rand(size, 1, image_size, image_size, generator=generator)\n",
    "        # Binary label kept as float with a trailing dim of 1.\n",
    "        self.labels_energy = torch.randint(0, 2, (size, 1), generator=generator).float()\n",
    "        self.labels_alpha = torch.randint(0, 3, (size,), generator=generator)\n",
    "        self.labels_q0 = torch.randint(0, 4, (size,), generator=generator)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of samples.\"\"\"\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(image, targets)`` for sample ``idx``.\"\"\"\n",
    "        return self.data[idx], {\n",
    "            'energy_loss_output': self.labels_energy[idx],\n",
    "            'alpha_output': self.labels_alpha[idx],\n",
    "            'q0_output': self.labels_q0[idx]\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "758c42b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 2: Custom SwinTiny for 32×32 without upsampling\n",
    "\n",
    "import torch.nn as nn\n",
    "from timm.models.swin_transformer import SwinTransformer\n",
    "\n",
    "class SwinMultiHeadClassifier(nn.Module):\n",
    "    \"\"\"Swin backbone for 1x32x32 inputs feeding three linear task heads.\n",
    "\n",
    "    ``forward`` returns a dict with the keys ``'energy_loss_output'`` (1 logit),\n",
    "    ``'alpha_output'`` (3 logits) and ``'q0_output'`` (4 logits).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # Geometry: 32 / patch_size (4) gives an 8x8 patch grid, and\n",
    "        # window_size (4) divides that grid evenly.\n",
    "        # num_classes=0 asks the backbone for features instead of class logits.\n",
    "        swin_kwargs = dict(\n",
    "            img_size=32,\n",
    "            patch_size=4,\n",
    "            in_chans=1,\n",
    "            embed_dim=96,\n",
    "            depths=[2, 2, 6, 2],\n",
    "            num_heads=[3, 6, 12, 24],\n",
    "            window_size=4,\n",
    "            mlp_ratio=4.,\n",
    "            qkv_bias=True,\n",
    "            drop_rate=0.,\n",
    "            attn_drop_rate=0.,\n",
    "            drop_path_rate=0.1,\n",
    "            norm_layer=nn.LayerNorm,\n",
    "            patch_norm=True,\n",
    "            use_checkpoint=False,\n",
    "            num_classes=0,\n",
    "        )\n",
    "        self.backbone = SwinTransformer(**swin_kwargs)\n",
    "\n",
    "        # Read the output width from the backbone rather than hard-coding it.\n",
    "        self.feature_dim = self.backbone.num_features\n",
    "        width = self.feature_dim\n",
    "\n",
    "        # One linear head per task.\n",
    "        self.energy_head = nn.Linear(width, 1)\n",
    "        self.alpha_head = nn.Linear(width, 3)\n",
    "        self.q0_head = nn.Linear(width, 4)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the backbone on ``x`` and apply every head to its output.\n",
    "\n",
    "        ``x`` is expected to be (B, 1, 32, 32). The backbone output is assumed\n",
    "        to be (B, feature_dim) when built with num_classes=0 - TODO confirm\n",
    "        against the installed timm version.\n",
    "        \"\"\"\n",
    "        features = self.backbone(x)\n",
    "        energy_logits = self.energy_head(features)\n",
    "        alpha_logits = self.alpha_head(features)\n",
    "        q0_logits = self.q0_head(features)\n",
    "        return {\n",
    "            'energy_loss_output': energy_logits,\n",
    "            'alpha_output': alpha_logits,\n",
    "            'q0_output': q0_logits,\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "293371ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 3: Loss Computation\n",
    "\n",
    "def compute_loss(outputs, targets):\n",
    "    \"\"\"Return the unweighted sum of the three task losses.\n",
    "\n",
    "    Args:\n",
    "        outputs: dict of logits with keys 'energy_loss_output',\n",
    "            'alpha_output' and 'q0_output'.\n",
    "        targets: dict with the same keys. The energy target is binary; the\n",
    "            alpha and q0 targets are passed to cross-entropy unchanged.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor: BCE-with-logits(energy) + CE(alpha) + CE(q0), each with\n",
    "        the default mean reduction.\n",
    "    \"\"\"\n",
    "    energy_logits = outputs['energy_loss_output']\n",
    "    # BCE-with-logits needs a floating target shaped exactly like the logits.\n",
    "    # Coerce here so integer labels or labels of shape (B,) also work; a target\n",
    "    # that is already float and correctly shaped passes through unchanged.\n",
    "    energy_target = targets['energy_loss_output'].to(energy_logits.dtype)\n",
    "    energy_target = energy_target.reshape(energy_logits.shape)\n",
    "\n",
    "    # Functional losses avoid building new loss modules on every call.\n",
    "    loss_energy = nn.functional.binary_cross_entropy_with_logits(energy_logits, energy_target)\n",
    "    loss_alpha = nn.functional.cross_entropy(outputs['alpha_output'], targets['alpha_output'])\n",
    "    loss_q0 = nn.functional.cross_entropy(outputs['q0_output'], targets['q0_output'])\n",
    "    return loss_energy + loss_alpha + loss_q0\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60d57236",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 4: Training Loop Function\n",
    "\n",
    "def train_one_epoch(model, loader, optimizer, device):\n",
    "    \"\"\"Run one optimisation pass over ``loader`` and return the mean batch loss.\n",
    "\n",
    "    Args:\n",
    "        model: module called as ``model(x)``; it is put into train mode.\n",
    "        loader: iterable yielding ``(x, y)`` where ``y`` is a dict of tensors.\n",
    "            It does not need to support ``len()``.\n",
    "        optimizer: optimizer stepped once per batch.\n",
    "        device: device that inputs and targets are moved to.\n",
    "\n",
    "    Returns:\n",
    "        float: arithmetic mean of the per-batch losses, or ``nan`` if the\n",
    "        loader yields no batches.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0.0\n",
    "    num_batches = 0\n",
    "    for x, y in loader:\n",
    "        x = x.to(device)\n",
    "        y = {k: v.to(device) for k, v in y.items()}\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(x)\n",
    "        loss = compute_loss(outputs, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "        # Count batches actually seen instead of calling len(loader), which\n",
    "        # fails for iterables that do not define __len__.\n",
    "        num_batches += 1\n",
    "    if num_batches == 0:\n",
    "        # No data: report nan rather than dividing by zero or returning a\n",
    "        # misleading 0.0 loss.\n",
    "        return float('nan')\n",
    "    return total_loss / num_batches\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e45a2152",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 5: Main Training Script\n",
    "\n",
    "\n",
    "# Run configuration: tune these instead of editing the code below.\n",
    "SEED = 42\n",
    "BATCH_SIZE = 32\n",
    "LEARNING_RATE = 1e-4\n",
    "NUM_EPOCHS = 5\n",
    "\n",
    "# Seed before the model and data are created so weight init, the random\n",
    "# dataset and the shuffle order repeat on Restart & Run All.\n",
    "# (Some GPU kernels may still be non-deterministic.)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = SwinMultiHeadClassifier().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "train_loader = DataLoader(DummyDataset(), batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    loss = train_one_epoch(model, train_loader, optimizer, device)\n",
    "    print(f\"Epoch {epoch+1}, Loss: {loss:.4f}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
