{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ed49939",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from mamba_ssm import Mamba\n",
    "\n",
    "# Smoke test: one Mamba block must map a (batch, length, dim) tensor to the same shape.\n",
    "# Input and module are moved to \"cuda\" explicitly, so this cell needs a CUDA device.\n",
    "batch = 2\n",
    "length = 64\n",
    "dim = 16\n",
    "x = torch.randn(batch, length, dim).to(\"cuda\")\n",
    "\n",
    "# Roughly 3 * expand * d_model^2 parameters.\n",
    "#   d_state: SSM state expansion factor\n",
    "#   d_conv:  local convolution width\n",
    "#   expand:  block expansion factor\n",
    "model = Mamba(d_model=dim, d_state=16, d_conv=4, expand=2).to(\"cuda\")\n",
    "y = model(x)\n",
    "assert y.shape == x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2bed972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 1: Imports and Dummy Dataset\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import SwinModel, SwinConfig\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.datasets import ImageFolder\n",
    "import os\n",
    "\n",
    "# Synthetic stand-in dataset (replace with the ML-JET loader).\n",
    "class DummyDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Random images plus three random label heads keyed like the model outputs.\n",
    "\n",
    "    Each item is ``(image, labels)``: ``image`` has shape\n",
    "    ``(in_chans, img_size, img_size)`` with values in [0, 1), and ``labels`` maps\n",
    "    ``'energy_loss_output'`` -> float tensor of shape (1,) with values in {0., 1.},\n",
    "    ``'alpha_output'`` -> int64 class index in [0, num_alpha_classes),\n",
    "    ``'q0_output'`` -> int64 class index in [0, num_q0_classes).\n",
    "\n",
    "    Args:\n",
    "        size: number of samples.\n",
    "        seed: if given, all tensors are drawn from a private ``torch.Generator``\n",
    "            seeded with it, making the dataset reproducible; if None (default)\n",
    "            the global torch RNG is used, exactly as before.\n",
    "        img_size, in_chans: image geometry (defaults 32 and 1).\n",
    "        num_alpha_classes, num_q0_classes: label cardinalities (defaults 3 and 4).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, size=1000, seed=None, img_size=32, in_chans=1,\n",
    "                 num_alpha_classes=3, num_q0_classes=4):\n",
    "        gen = None if seed is None else torch.Generator().manual_seed(seed)\n",
    "        self.data = torch.rand(size, in_chans, img_size, img_size, generator=gen)\n",
    "        # Shape (size, 1) float so each item matches the single energy logit.\n",
    "        self.labels_energy = torch.randint(0, 2, (size, 1), generator=gen).float()\n",
    "        self.labels_alpha = torch.randint(0, num_alpha_classes, (size,), generator=gen)\n",
    "        self.labels_q0 = torch.randint(0, num_q0_classes, (size,), generator=gen)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx], {\n",
    "            'energy_loss_output': self.labels_energy[idx],\n",
    "            'alpha_output': self.labels_alpha[idx],\n",
    "            'q0_output': self.labels_q0[idx],\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3db239d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MambaVisionMultiHead(nn.Module):\n",
    "    \"\"\"Conv projection -> one Mamba block -> three task heads.\n",
    "\n",
    "    Shapes (H = W = img_size):\n",
    "      x (B, in_chans, H, W) -> Conv2d -> (B, embed_dim, H, W)\n",
    "        -> Flatten(2) -> (B, embed_dim, H*W) -> Linear(H*W -> img_size)\n",
    "        -> (B, embed_dim, L=img_size) -> transpose -> (B, L, embed_dim)\n",
    "        -> Mamba -> (B, L, embed_dim) -> last step -> LayerNorm -> heads.\n",
    "\n",
    "    Returns a dict of raw logits: 'energy_loss_output' (B, 1),\n",
    "    'alpha_output' (B, num_alpha_classes), 'q0_output' (B, num_q0_classes).\n",
    "\n",
    "    NOTE(review): despite their names, ``mamba_layers`` is passed to Mamba as\n",
    "    ``d_conv`` (local convolution width) and ``mamba_hidden`` as ``d_state``;\n",
    "    only ONE Mamba block is built. Names kept for call/checkpoint compatibility.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_chans=1, img_size=32, embed_dim=128, mamba_layers=4, mamba_hidden=256,\n",
    "                 num_alpha_classes=3, num_q0_classes=4):\n",
    "        super().__init__()\n",
    "        self.proj = nn.Sequential(\n",
    "            nn.Conv2d(in_chans, embed_dim, kernel_size=3, padding=1),\n",
    "            nn.Flatten(2),\n",
    "            # Compress the H*W spatial positions into a length-img_size sequence.\n",
    "            nn.Linear(img_size * img_size, img_size),\n",
    "        )\n",
    "        self.norm = nn.LayerNorm(embed_dim)\n",
    "        self.mamba = Mamba(d_model=embed_dim, d_state=mamba_hidden, d_conv=mamba_layers)\n",
    "        # Parameter-free; kept for attribute compatibility (forward uses the last step).\n",
    "        self.pool = nn.AdaptiveAvgPool1d(1)\n",
    "        self.head_energy = nn.Linear(embed_dim, 1)\n",
    "        self.head_alpha = nn.Linear(embed_dim, num_alpha_classes)\n",
    "        self.head_q0 = nn.Linear(embed_dim, num_q0_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = self.proj(x)                   # (B, embed_dim, L)\n",
    "        # Mamba consumes (batch, length, dim). The previous permute(2, 0, 1) gave\n",
    "        # (L, B, embed_dim), so the scan ran across the batch and mixed samples.\n",
    "        z = z.transpose(1, 2)              # (B, L, embed_dim)\n",
    "        out_seq = self.mamba(z)            # (B, L, embed_dim)\n",
    "        feat = self.norm(out_seq[:, -1])   # last step -> (B, embed_dim)\n",
    "        return {\n",
    "            'energy_loss_output': self.head_energy(feat),\n",
    "            'alpha_output': self.head_alpha(feat),\n",
    "            'q0_output': self.head_q0(feat),\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8847b47",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "# Run on the GPU when one is visible, otherwise on the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Same values as the constructor defaults, written out for visibility.\n",
    "model = MambaVisionMultiHead(in_chans=1, img_size=32, embed_dim=128,\n",
    "                             mamba_layers=4, mamba_hidden=256).to(device)\n",
    "\n",
    "# One criterion per head: BCE-with-logits for the single energy logit,\n",
    "# cross-entropy for the alpha and q0 class heads.\n",
    "crit_e, crit_a, crit_q = nn.BCEWithLogitsLoss(), nn.CrossEntropyLoss(), nn.CrossEntropyLoss()\n",
    "\n",
    "def composite_loss(preds, labels):\n",
    "    \"\"\"Return (total_loss, (energy, alpha, q0)) for one batch.\n",
    "\n",
    "    ``total_loss`` is the differentiable sum of the three head losses computed\n",
    "    with the module-level crit_e, crit_a and crit_q; the tuple holds the same\n",
    "    three losses as Python floats for logging. Energy targets are cast to float.\n",
    "    \"\"\"\n",
    "    parts = (\n",
    "        crit_e(preds['energy_loss_output'], labels['energy_loss_output'].float()),\n",
    "        crit_a(preds['alpha_output'], labels['alpha_output']),\n",
    "        crit_q(preds['q0_output'], labels['q0_output']),\n",
    "    )\n",
    "    total = parts[0] + parts[1] + parts[2]\n",
    "    return total, tuple(part.item() for part in parts)\n",
    "\n",
    "# Adam over every model parameter.\n",
    "optimizer = optim.Adam(params=model.parameters(), lr=1e-4)\n",
    "\n",
    "# StepLR multiplies the learning rate by gamma=0.1 every 15 scheduler.step() calls\n",
    "# (the epoch loop calls step() once per epoch).\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=15, gamma=0.1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07f6219b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "def train_one_epoch(loader):\n",
    "    \"\"\"Run one training pass over ``loader``; return the per-sample mean loss.\n",
    "\n",
    "    Uses the notebook-level ``model``, ``optimizer``, ``device`` and\n",
    "    ``composite_loss``. Each batch loss is weighted by its batch size, so a\n",
    "    short final batch is not over-weighted. Returns NaN for an empty loader\n",
    "    instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    loss_sum, n_seen = 0.0, 0\n",
    "    for x, labels in tqdm(loader, desc='Train'):\n",
    "        x = x.to(device)\n",
    "        labels = {k: v.to(device) for k, v in labels.items()}\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        preds = model(x)\n",
    "        loss, _ = composite_loss(preds, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        batch_size = x.size(0)\n",
    "        loss_sum += loss.item() * batch_size\n",
    "        n_seen += batch_size\n",
    "    return loss_sum / n_seen if n_seen else float('nan')\n",
    "\n",
    "@torch.no_grad()\n",
    "def validate(loader):\n",
    "    \"\"\"Evaluate the notebook-level ``model`` on ``loader`` without gradients.\n",
    "\n",
    "    Returns {'loss', 'acc_e', 'acc_a', 'acc_q'}: the per-sample mean loss and the\n",
    "    accuracy of each head (energy thresholded at sigmoid > 0.5, the class heads\n",
    "    by argmax). All values are NaN when the loader is empty.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    loss_sum = 0.0\n",
    "    correct_e = correct_a = correct_q = total = 0\n",
    "\n",
    "    for x, labels in tqdm(loader, desc='Val'):\n",
    "        x = x.to(device)\n",
    "        labels = {k: v.to(device) for k, v in labels.items()}\n",
    "        preds = model(x)\n",
    "        loss, _ = composite_loss(preds, labels)\n",
    "        batch_size = x.size(0)\n",
    "        loss_sum += loss.item() * batch_size\n",
    "\n",
    "        # Flatten energy predictions/targets so a (B, 1) vs (B,) layout mismatch\n",
    "        # cannot broadcast into a (B, B) comparison.\n",
    "        e_pred = (torch.sigmoid(preds['energy_loss_output']) > 0.5).long().reshape(-1)\n",
    "        e_true = labels['energy_loss_output'].reshape(-1)\n",
    "        a_pred = preds['alpha_output'].argmax(dim=1)\n",
    "        q_pred = preds['q0_output'].argmax(dim=1)\n",
    "\n",
    "        correct_e += (e_pred == e_true).sum().item()\n",
    "        correct_a += (a_pred == labels['alpha_output']).sum().item()\n",
    "        correct_q += (q_pred == labels['q0_output']).sum().item()\n",
    "        total += batch_size\n",
    "\n",
    "    if total == 0:\n",
    "        nan = float('nan')\n",
    "        return {'loss': nan, 'acc_e': nan, 'acc_a': nan, 'acc_q': nan}\n",
    "    return {\n",
    "        'loss':   loss_sum / total,\n",
    "        'acc_e':  correct_e / total,\n",
    "        'acc_a':  correct_a / total,\n",
    "        'acc_q':  correct_q / total\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1197ec93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the training split is shuffled; evaluation order does not change the\n",
    "# metrics, so val/test are read in a fixed order. Each split is an independent\n",
    "# DummyDataset.\n",
    "train_loader = DataLoader(DummyDataset(), batch_size=32, shuffle=True)\n",
    "val_loader = DataLoader(DummyDataset(), batch_size=32, shuffle=False)\n",
    "test_loader = DataLoader(DummyDataset(), batch_size=32, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bd98449",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch loop: train, validate, step the LR schedule, and checkpoint whenever\n",
    "# the validation loss reaches a new minimum.\n",
    "# NOTE(review): calls the one-argument train_one_epoch(loader); a later cell\n",
    "# redefines train_one_epoch(model, loader, optimizer, device), after which\n",
    "# re-running this cell raises TypeError.\n",
    "num_epochs = 1\n",
    "best_val = float('inf')\n",
    "\n",
    "for epoch in range(1, num_epochs + 1):\n",
    "    train_loss = train_one_epoch(train_loader)\n",
    "    val_metrics = validate(val_loader)\n",
    "    scheduler.step()\n",
    "\n",
    "    print(f\"Epoch {epoch:02d} | \"\n",
    "          f\"Train Loss: {train_loss:.4f} | \"\n",
    "          f\"Val Loss: {val_metrics['loss']:.4f} | \"\n",
    "          f\"E Acc: {val_metrics['acc_e']:.2%} | \"\n",
    "          f\"A Acc: {val_metrics['acc_a']:.2%} | \"\n",
    "          f\"Q Acc: {val_metrics['acc_q']:.2%} | \"\n",
    "          f\"LR: {scheduler.get_last_lr()[0]:.2e}\")\n",
    "\n",
    "    val_loss = val_metrics['loss']\n",
    "    if val_loss < best_val:\n",
    "        best_val = val_loss\n",
    "        torch.save(model.state_dict(), 'best_mamba_vision.pth')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fdced46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 3: Loss Computation\n",
    "\n",
    "def compute_loss(outputs, targets):\n",
    "    \"\"\"Sum of BCE-with-logits (energy head) and cross-entropy (alpha, q0 heads).\n",
    "\n",
    "    Args:\n",
    "        outputs: dict of logits keyed 'energy_loss_output', 'alpha_output', 'q0_output'.\n",
    "        targets: dict with the same keys; the energy target may be float or integer\n",
    "            0/1 (it is cast to the logits' dtype), the others are class indices.\n",
    "    Returns:\n",
    "        Scalar loss tensor.\n",
    "\n",
    "    NOTE(review): duplicates composite_loss from an earlier cell (minus the\n",
    "    per-head breakdown); consider keeping a single loss function.\n",
    "    \"\"\"\n",
    "    bce = nn.BCEWithLogitsLoss()\n",
    "    ce = nn.CrossEntropyLoss()\n",
    "    energy_logits = outputs['energy_loss_output']\n",
    "    # BCEWithLogitsLoss needs a floating-point target; match the logits' dtype.\n",
    "    loss_energy = bce(energy_logits, targets['energy_loss_output'].to(energy_logits.dtype))\n",
    "    loss_alpha = ce(outputs['alpha_output'], targets['alpha_output'])\n",
    "    loss_q0 = ce(outputs['q0_output'], targets['q0_output'])\n",
    "    return loss_energy + loss_alpha + loss_q0\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "737191ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 4: Training Loop Function\n",
    "\n",
    "def train_one_epoch(model, loader, optimizer, device, loss_fn=None):\n",
    "    \"\"\"Train ``model`` for one pass over ``loader``; return the per-sample mean loss.\n",
    "\n",
    "    NOTE(review): this redefines the earlier one-argument train_one_epoch(loader);\n",
    "    cells that still call that form fail once this cell has run.\n",
    "\n",
    "    Args:\n",
    "        model: module whose output is passed to ``loss_fn``.\n",
    "        loader: iterable of ``(x, targets_dict)`` batches.\n",
    "        optimizer: optimizer over ``model``'s parameters.\n",
    "        device: device the batch tensors are moved to.\n",
    "        loss_fn: ``loss_fn(outputs, targets) -> scalar tensor``; defaults to\n",
    "            ``compute_loss``.\n",
    "    Returns:\n",
    "        Batch-size-weighted mean loss, or NaN if ``loader`` yields no batches.\n",
    "    \"\"\"\n",
    "    if loss_fn is None:\n",
    "        loss_fn = compute_loss\n",
    "    model.train()\n",
    "    loss_sum, n_seen = 0.0, 0\n",
    "    for x, y in loader:\n",
    "        x = x.to(device)\n",
    "        y = {k: v.to(device) for k, v in y.items()}\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(x)\n",
    "        loss = loss_fn(outputs, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        batch_size = x.size(0)\n",
    "        loss_sum += loss.item() * batch_size\n",
    "        n_seen += batch_size\n",
    "    return loss_sum / n_seen if n_seen else float('nan')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df824744",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e374f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cell 5: Main Training Script\n",
    "# Builds a fresh model, optimizer and training loader, then trains for 5 epochs\n",
    "# with the four-argument train_one_epoch(model, loader, optimizer, device).\n",
    "# NOTE(review): rebinds the notebook-level model / optimizer / train_loader; the\n",
    "# StepLR scheduler created earlier still wraps the previous optimizer.\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "model = MambaVisionMultiHead().to(device)\n",
    "optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)\n",
    "train_loader = DataLoader(dataset=DummyDataset(), batch_size=32, shuffle=True)\n",
    "\n",
    "for epoch in range(5):\n",
    "    loss = train_one_epoch(model, train_loader, optimizer, device)\n",
    "    print(f\"Epoch {epoch+1}, Loss: {loss:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dd70517",
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "\n",
    "# Query timm's model registry for names matching 'mamba*', first all of them and\n",
    "# then only those that ship pretrained weights.\n",
    "# NOTE(review): the glob also matches any 'mambaout*' entries (a different,\n",
    "# non-SSM architecture family) if the installed timm has them - confirm intent.\n",
    "all_models = timm.list_models('mamba*')\n",
    "pretrained_models = timm.list_models('mamba*', pretrained=True)\n",
    "\n",
    "print(\"🔍 All Mamba models:\")\n",
    "for name in all_models:\n",
    "    print(\"  -\", name)\n",
    "\n",
    "print(\"\\n✅ Pretrained Mamba models:\")\n",
    "for name in pretrained_models:\n",
    "    print(\"  -\", name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
