{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 🧪  Local Testing Notebook\n",
    "# Author: John Doe\n",
    "#\n",
    "# Smoke test of the data pipeline: load a few samples from the training split,\n",
    "# build a JetDataset and plot one image together with its target labels.\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from data.loader import load_split_from_csv, JetDataset\n",
    "\n",
    "# ✅ Configuration\n",
    "# The dataset location can be overridden with the JET_DATA_ROOT environment variable\n",
    "# so the notebook is not tied to one machine; the default is the original local path.\n",
    "DEFAULT_ROOT_DIR = \"/home/johndoe/Projects/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_1000_balanced_unshuffled\"\n",
    "root_dir = os.environ.get(\"JET_DATA_ROOT\", DEFAULT_ROOT_DIR)\n",
    "# Normalization constant passed to JetDataset; presumably the maximum pixel value over\n",
    "# the full dataset (TODO confirm). Keep it identical to the value used for training.\n",
    "global_max = 121.79151153564453\n",
    "N_SAMPLES = 16  # size of the quick-test subset\n",
    "train_file = os.path.join(root_dir, \"train_files.csv\")\n",
    "\n",
    "# Fail early with an actionable message rather than whatever error the loader raises.\n",
    "if not os.path.isfile(train_file):\n",
    "    raise FileNotFoundError(\n",
    "        f\"Training split not found: {train_file} \"\n",
    "        \"(set JET_DATA_ROOT to the dataset directory)\"\n",
    "    )\n",
    "\n",
    "# ✅ Load file-label list\n",
    "train_list = load_split_from_csv(train_file, root_dir)\n",
    "\n",
    "# ✅ Subset to test quickly\n",
    "train_subset = train_list[:N_SAMPLES]\n",
    "assert len(train_subset) > 0, f\"No samples listed in {train_file}\"\n",
    "\n",
    "# ✅ Build Dataset\n",
    "train_dataset = JetDataset(train_subset, global_max=global_max)\n",
    "\n",
    "# ✅ Fetch sample\n",
    "img, labels = train_dataset[0]\n",
    "\n",
    "# ✅ Display Image\n",
    "# squeeze(0) drops a leading singleton axis; assumes img is (1, H, W) - TODO confirm.\n",
    "fig, ax = plt.subplots()\n",
    "im = ax.imshow(img.squeeze(0), cmap='hot')\n",
    "# Raw .item() values of float32 tensors print as e.g. 0.20000000298023224;\n",
    "# format them compactly so the title stays readable.\n",
    "ax.set_title(f\"Energy Loss: {labels['energy_loss_output'].item():.4g}, \"\n",
    "             f\"Alpha: {labels['alpha_output'].item():.4g}, \"\n",
    "             f\"Q0: {labels['q0_output'].item():.4g}\")\n",
    "fig.colorbar(im, ax=ax)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from models.model import create_model\n",
    "\n",
    "# ✅ Choose backbone (one of BACKBONES)\n",
    "BACKBONES = ('efficientnet', 'convnext', 'swin', 'mamba')\n",
    "backbone = 'efficientnet'\n",
    "assert backbone in BACKBONES, f\"Unknown backbone {backbone!r}; choose one of {BACKBONES}\"\n",
    "\n",
    "# ✅ Model hyper-parameters\n",
    "# Presumably (channels, height, width) of one JetDataset image - TODO confirm it\n",
    "# matches img.shape from the data-loading cell.\n",
    "INPUT_SHAPE = (1, 32, 32)\n",
    "LEARNING_RATE = 1e-4\n",
    "\n",
    "# ✅ Create model and optimizer\n",
    "model, optimizer = create_model(backbone=backbone, input_shape=INPUT_SHAPE, learning_rate=LEARNING_RATE)\n",
    "\n",
    "# ✅ Move to GPU if available\n",
    "# NOTE(review): PyTorch advises moving a model before building its optimizer. Module.to()\n",
    "# converts parameters in place by default, so the optimizer from create_model should keep\n",
    "# valid references - confirm if overwrite_module_params_on_conversion is ever enabled.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# ✅ Print model summary\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(f\"Backbone: {backbone} | device: {device} | parameters: {n_params:,} ({n_trainable:,} trainable)\")\n",
    "print(model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
