{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "724e7053",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 🔧 Add the project root to sys.path for module imports (Jupyter-safe)\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# Top-level packages this notebook imports from (train_utils.*, experiments.*);\n",
    "# the project root is the nearest ancestor of the CWD that contains all of them.\n",
    "ROOT_MARKERS = (\"train_utils\", \"experiments\")\n",
    "\n",
    "\n",
    "def find_project_root(start: Path, markers=ROOT_MARKERS) -> Path:\n",
    "    \"\"\"Return the nearest ancestor of `start` (inclusive) containing every marker directory.\n",
    "\n",
    "    A fixed number of `.parent` hops silently points at the wrong directory when the\n",
    "    notebook sits at a different depth (e.g. if notebooks/ lives inside\n",
    "    experiments/exp_*/, the root is three levels up, not two). Falls back to the\n",
    "    original two-level heuristic (`start.parent.parent`) when nothing matches.\n",
    "    \"\"\"\n",
    "    for candidate in (start, *start.parents):\n",
    "        if all((candidate / marker).is_dir() for marker in markers):\n",
    "            return candidate\n",
    "    return start.parent.parent\n",
    "\n",
    "\n",
    "ROOT = find_project_root(Path.cwd())\n",
    "if str(ROOT) not in sys.path:\n",
    "    # Prepend so project modules (e.g. `config`, `data`) take precedence over\n",
    "    # same-named packages installed in the environment.\n",
    "    sys.path.insert(0, str(ROOT))\n",
    "    print(f\"[INFO] Added ROOT to sys.path: {ROOT}\")\n",
    "else:\n",
    "    print(f\"[INFO] ROOT already on sys.path: {ROOT}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db9135a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import json\n",
    "\n",
    "from config import get_config\n",
    "from train_utils.gpu_utils import get_device_summary\n",
    "from data.loader import get_dataloaders\n",
    "# from models.model import create_model\n",
    "from train_utils.resume import init_resume_state\n",
    "from train_utils.resume import fill_trackers_from_history\n",
    "from train_utils.training_loop import run_training_loop\n",
    "from train_utils.train_epoch import train_one_epoch\n",
    "from train_utils.evaluate import evaluate\n",
    "from train_utils.train_metrics_logger import update_train_logs\n",
    "from train_utils.train_metrics_logger import update_val_logs\n",
    "from train_utils.checkpoint_saver import save_epoch_checkpoint\n",
    "from train_utils.train_metrics_logger import record_and_save_epoch\n",
    "from train_utils.scheduler_utils import create_scheduler\n",
    "from train_utils.early_stopping import check_early_stopping\n",
    "from train_utils.training_summary import finalize_training_summary\n",
    "from train_utils.training_summary import print_best_model_summary\n",
    "from train_utils.plot_metrics import plot_train_val_metrics\n",
    "from train_utils.plot_metrics import plot_loss_accuracy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "033f6562",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional experiment-specific YAML config; None -> use get_config()'s default.\n",
    "CONFIG_PATH = None\n",
    "# Alternatives used previously (uncomment one to select it):\n",
    "# CONFIG_PATH = (\"/\"\n",
    "#                \"experiments/exp_25072401_hybrid_mamaba_vit_stack/config/\"\n",
    "#                \"hybrid_mambaout_base_plus_rw_ViT_tiny_patch16_224_bs16_ep50_lr1e-04_ds1008_g500_sched-RLRP.yml\")\n",
    "# CONFIG_PATH = (\"/\"\n",
    "#                \"experiments/exp_25072401_hybrid_mamaba_vit_stack/config/\"\n",
    "#                \"hybrid_mambaout_base_plus_rw_ViT_tiny_patch16_224_bs16_ep50_lr1e-04_ds1000_g1_sched-RLRP.yml\")\n",
    "\n",
    "cfg = get_config(config_path=CONFIG_PATH) if CONFIG_PATH is not None else get_config()\n",
    "# default=str stringifies values json cannot encode natively (e.g. Path objects, if the\n",
    "# config holds any) instead of raising TypeError halfway through the dump.\n",
    "print(json.dumps(vars(cfg), indent=2, default=str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a91e4ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute device used below for the model, the dataloaders and checkpoint resumption.\n",
    "# NOTE(review): get_device_summary presumably also prints a hardware summary - see train_utils.gpu_utils.\n",
    "device = get_device_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d4f7f73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment-specific model factory (used instead of the generic models.model.create_model).\n",
    "from experiments.exp_25072401_hybrid_mamaba_vit_stack.models.hybrid_mamba_vit import (\n",
    "    create_model,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cf41709",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and optimizer: create_model builds both from the config's backbone,\n",
    "# input shape and learning rate; the model is then moved onto `device`.\n",
    "backbone, input_shape, learning_rate = cfg.backbone, cfg.input_shape, cfg.learning_rate\n",
    "model, optimizer = create_model(backbone, input_shape, learning_rate)\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04b236de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data: loaders are built from cfg, with `device` forwarded to get_dataloaders.\n",
    "# test_loader is unpacked for completeness but not used later in this notebook.\n",
    "loaders = get_dataloaders(cfg, device=device)\n",
    "train_loader, val_loader, test_loader = loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca2ffb22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use all visible GPUs via DataParallel when there is more than one.\n",
    "n_gpus = torch.cuda.device_count()\n",
    "if n_gpus > 1:\n",
    "    if isinstance(model, torch.nn.DataParallel):\n",
    "        # Re-running this cell must not wrap the model a second time\n",
    "        # (DataParallel(DataParallel(model)) nests `module.` prefixes in state_dict keys).\n",
    "        print(f\"Model already parallelized across {n_gpus} GPUs\")\n",
    "    else:\n",
    "        print(f\"Parallelizing model across {n_gpus} GPUs\")\n",
    "        model = torch.nn.DataParallel(model)\n",
    "elif n_gpus == 1:\n",
    "    print(\"No parallelization, using single GPU\")\n",
    "else:\n",
    "    print(\"No GPU available, using CPU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f01e6c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LR scheduler bound to the optimizer above; train_loader is passed as well\n",
    "# (presumably so per-step schedules can size themselves from the batch count).\n",
    "scheduler = create_scheduler(\n",
    "    optimizer,\n",
    "    cfg,\n",
    "    train_loader=train_loader,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5408d84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One loss per output head (keys presumably match the model's output names - confirm).\n",
    "# BCEWithLogitsLoss applies the sigmoid itself, so the energy head must emit raw logits\n",
    "# (an earlier version used nn.BCELoss, which expects probabilities instead).\n",
    "criterion = dict(\n",
    "    energy_loss_output=nn.BCEWithLogitsLoss(),\n",
    "    alpha_output=nn.CrossEntropyLoss(),\n",
    "    q0_output=nn.CrossEntropyLoss(),\n",
    ")\n",
    "print(f\"[INFO] Loss functions:{criterion}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffdf8b64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch metric histories: one list per head (energy, alpha, q0) plus an overall\n",
    "# list (presumably the combined value). Each name gets its own fresh, empty list.\n",
    "print(\"[INFO] Init Training Trackers\")\n",
    "train_loss_energy_list, train_loss_alpha_list, train_loss_q0_list, train_loss_list = ([] for _ in range(4))\n",
    "train_acc_energy_list, train_acc_alpha_list, train_acc_q0_list, train_acc_list = ([] for _ in range(4))\n",
    "\n",
    "print(\"[INFO] Init Validation Trackers\")\n",
    "val_loss_energy_list, val_loss_alpha_list, val_loss_q0_list, val_loss_list = ([] for _ in range(4))\n",
    "val_acc_energy_list, val_acc_alpha_list, val_acc_q0_list, val_acc_list = ([] for _ in range(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e603aab1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore (or initialise) training state - see train_utils.resume.init_resume_state.\n",
    "# `scheduler` was created earlier around the pre-resume optimizer; remember that object\n",
    "# so we can detect whether init_resume_state hands back a different optimizer.\n",
    "optimizer_before_resume = optimizer\n",
    "(\n",
    "    model,\n",
    "    optimizer,\n",
    "    start_epoch,\n",
    "    best_acc,\n",
    "    early_stop_counter,\n",
    "    best_epoch,\n",
    "    best_metrics,\n",
    "    training_summary,\n",
    "    all_epoch_metrics,\n",
    "    summary_status,\n",
    ") = init_resume_state(model, optimizer, device, cfg)\n",
    "\n",
    "if optimizer is not optimizer_before_resume and scheduler is not None:\n",
    "    # A PyTorch scheduler only adjusts the optimizer it was constructed with; rebuild it\n",
    "    # so LR updates reach the optimizer that will actually be used for training.\n",
    "    print(\"[INFO] Optimizer replaced during resume; rebuilding scheduler\")\n",
    "    scheduler = create_scheduler(optimizer, cfg, train_loader=train_loader)\n",
    "del optimizer_before_resume"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dc7a942",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replay metrics already recorded in all_epoch_metrics (e.g. when resuming) into the\n",
    "# tracker lists, presumably in place. Order here: per-head lists first, overall last.\n",
    "fill_trackers_from_history(\n",
    "    all_epoch_metrics,\n",
    "    # training losses: energy, alpha, q0, overall\n",
    "    train_loss_energy_list,\n",
    "    train_loss_alpha_list,\n",
    "    train_loss_q0_list,\n",
    "    train_loss_list,\n",
    "    # training accuracies: energy, alpha, q0, overall\n",
    "    train_acc_energy_list,\n",
    "    train_acc_alpha_list,\n",
    "    train_acc_q0_list,\n",
    "    train_acc_list,\n",
    "    # validation losses: energy, alpha, q0, overall\n",
    "    val_loss_energy_list,\n",
    "    val_loss_alpha_list,\n",
    "    val_loss_q0_list,\n",
    "    val_loss_list,\n",
    "    # validation accuracies: energy, alpha, q0, overall\n",
    "    val_acc_energy_list,\n",
    "    val_acc_alpha_list,\n",
    "    val_acc_q0_list,\n",
    "    val_acc_list,\n",
    "    summary_status,\n",
    "    best_epoch,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3ba017d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run training via train_utils.training_loop.run_training_loop.\n",
    "# NOTE: everything is positional. best_acc, best_metrics, best_epoch go in, but the call\n",
    "# returns best_epoch, best_acc, best_metrics; metric lists are passed overall-first here\n",
    "# (unlike fill_trackers_from_history, which takes them per-head-first).\n",
    "best_epoch, best_acc, best_metrics = run_training_loop(\n",
    "    cfg, train_loader, val_loader,\n",
    "    device, model, criterion,\n",
    "    optimizer, scheduler,\n",
    "    start_epoch, early_stop_counter,\n",
    "    best_acc, best_metrics, best_epoch,\n",
    "    # training losses, then accuracies: overall, energy, alpha, q0\n",
    "    train_loss_list, train_loss_energy_list, train_loss_alpha_list, train_loss_q0_list,\n",
    "    train_acc_list, train_acc_energy_list, train_acc_alpha_list, train_acc_q0_list,\n",
    "    # validation losses, then accuracies: overall, energy, alpha, q0\n",
    "    val_loss_list, val_loss_energy_list, val_loss_alpha_list, val_loss_q0_list,\n",
    "    val_acc_list, val_acc_energy_list, val_acc_alpha_list, val_acc_q0_list,\n",
    "    all_epoch_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5ebaed4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finalise the training summary (given cfg.output_dir), then report the best epoch.\n",
    "best_results = dict(best_epoch=best_epoch, best_acc=best_acc, best_metrics=best_metrics)\n",
    "finalize_training_summary(summary=training_summary, output_dir=cfg.output_dir, **best_results)\n",
    "print_best_model_summary(**best_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b30accf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning curves (cfg.output_dir is passed as the output location): overall\n",
    "# train-vs-val first, then per-head breakdowns for each split.\n",
    "# plot_loss_accuracy positional order: overall loss, energy/alpha/q0 losses,\n",
    "# overall accuracy, energy/alpha/q0 accuracies.\n",
    "train_series = (train_loss_list, train_loss_energy_list, train_loss_alpha_list, train_loss_q0_list,\n",
    "                train_acc_list, train_acc_energy_list, train_acc_alpha_list, train_acc_q0_list)\n",
    "val_series = (val_loss_list, val_loss_energy_list, val_loss_alpha_list, val_loss_q0_list,\n",
    "              val_acc_list, val_acc_energy_list, val_acc_alpha_list, val_acc_q0_list)\n",
    "\n",
    "plot_train_val_metrics(train_loss_list, val_loss_list, train_acc_list, val_acc_list, cfg.output_dir)\n",
    "plot_loss_accuracy(*train_series, cfg.output_dir, title=\"Train Loss and Accuracy per Epoch\")\n",
    "plot_loss_accuracy(*val_series, cfg.output_dir, title=\"Validation Loss and Accuracy per Epoch\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
