{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcb9cace",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 🔧 Add the project root to sys.path so the project-local packages imported below\n",
    "# (config, data, models, train_utils) resolve. Jupyter-safe: re-running never duplicates the entry.\n",
    "# NOTE(review): ROOT is derived from the kernel's working directory, so this only works when the\n",
    "# notebook is launched from the directory described on the ROOT line — TODO confirm the layout.\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "ROOT = Path.cwd().parent.parent  # Go up from /notebooks/ → experiment → project root\n",
    "if str(ROOT) not in sys.path:\n",
    "    sys.path.append(str(ROOT))\n",
    "    print(f\"[INFO] Added ROOT to sys.path: {ROOT}\")\n",
    "else:\n",
    "    # Nothing is appended on a re-run; report that instead of claiming it was added.\n",
    "    print(f\"[INFO] ROOT already on sys.path: {ROOT}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9b2807f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import json\n",
    "\n",
    "from config import get_config\n",
    "from train_utils.gpu_utils import get_device_summary\n",
    "from data.loader import get_dataloaders\n",
    "from train_utils.resume import init_resume_state\n",
    "from train_utils.resume import fill_trackers_from_history\n",
    "from train_utils.resume import load_pretrained_model\n",
    "from train_utils.training_loop import run_training_loop\n",
    "from train_utils.scheduler_utils import create_scheduler\n",
    "from train_utils.training_summary import finalize_training_summary\n",
    "from train_utils.training_summary import print_best_model_summary\n",
    "from train_utils.plot_metrics import plot_train_val_metrics\n",
    "from train_utils.plot_metrics import plot_loss_accuracy\n",
    "from train_utils.plot_metrics import plot_confusion_matrices\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be67c907",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment selection: keep exactly one create_model import and one cfg = get_config(...) active.\n",
    "# NOTE(review): create_model later receives cfg.backbone, so the active model module presumably has\n",
    "# to match the chosen config — TODO confirm get_config()'s default config uses a ViT backbone.\n",
    "# cfg=get_config(config_path=\"config/convnext_fb_in22k_ft_in1k_bs512_ep50_lr1e-04_ds1000.yml\")\n",
    "# cfg=get_config(config_path=\"config/convnext_fb_in1k_bs512_ep50_lr1e-04_ds1000.yml\")\n",
    "# cfg=get_config(config_path=\"config/convnext_gaussian_bs512_ep50_lr1e-04_ds1000.yml\")\n",
    "# cfg=get_config(config_path=\"config/efficientnet_bs512_ep50_lr1e-01_ds1000_sched-RLRP.yml\")\n",
    "# cfg=get_config(config_path=\"config/vit_\" \\\n",
    "# \"bs512_ep50_lr1e-04_ds1000.yml\")\n",
    "# cfg=get_config(config_path=\"config/mambaout_base_plus_rw_bs32_ep50_lr1e-04_ds1000-g1.yml\")\n",
    "# cfg=get_config(config_path=\"config/mambaout_base_plus_rw_bs16_ep50_lr1e-04_ds1008_g500_sched-RLRP.yml\")\n",
    "\n",
    "# from experiments.exp_mamaba_vit_stack.models.hybrid_mamba_vit import create_model\n",
    "# cfg=get_config(config_path=\"/\" \\\n",
    "# \"experiments/exp_mamaba_vit_stack/config/\" \\\n",
    "# \"hybrid_mambaout_base_plus_rw_ViT_tiny_patch16_224_bs64_ep1_lr1e-04_ds1008_g500_sched-RLRP.yml\")\n",
    "\n",
    "from models.model_vit import create_model\n",
    "# cfg=get_config(config_path=\"/\" \\\n",
    "# \"experiments/exp_preload_trained_model_and_train_more/config/\" \\\n",
    "# \"vit_tiny_patch16_224_gaussian_bs32_ep200_lr1e-04_p60_ds1008_g500_sched-RLRP_preload_p12.yml\")\n",
    "\n",
    "# from models.model_mamba import create_model\n",
    "# cfg=get_config(config_path=\"/\" \\\n",
    "# \"experiments/exp_alpha_s_evaluation_with_vit/config/\" \\\n",
    "# \"vit_tiny_patch16_224_gaussian_bs32_ep50_lr1e-04_p12_ds7200000_g500_sched-RLRP_preload_p4.yml\")\n",
    "\n",
    "\n",
    "# cfg=get_config(config_path=\"/\" \\\n",
    "# \"experiments/exp_evaluating_vit_with_q_0_probab_distirbution_plot/config/\" \\\n",
    "# \"vit_tiny_patch16_224_gaussian_bs32_ep50_lr1e-04_p12_ds1008_g500_sched-RLRP_preload_p4.yml\")\n",
    "cfg = get_config()\n",
    "# default=str prevents a TypeError if cfg holds values json cannot encode natively (e.g. Path objects).\n",
    "print(json.dumps(vars(cfg), indent=2, default=str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30787e6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the run's output directory up front (a no-op when it already exists).\n",
    "os.makedirs(name=cfg.output_dir, exist_ok=True)\n",
    "print(f\"[INFO] Saving all outputs to: {cfg.output_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4c87c35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute device shared by the data loaders, the model and evaluation below.\n",
    "device = get_device_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75ae2d7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data: train / validation / test loaders built from the config\n",
    "train_loader, val_loader, test_loader = get_dataloaders(\n",
    "    cfg,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc955a17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and optimizer (positional args: backbone, input shape, learning rate)\n",
    "model, optimizer = create_model(\n",
    "    cfg.backbone,\n",
    "    cfg.input_shape,\n",
    "    cfg.learning_rate,\n",
    ")\n",
    "# nn.Module.to() moves the parameters and returns the module itself, so this bare\n",
    "# expression also renders the model architecture as the cell output.\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95e5622c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model for data parallelism only when more than one GPU is visible.\n",
    "# NOTE(review): DataParallel prefixes state_dict keys with \"module.\", and this wrap happens before\n",
    "# init_resume_state / load_pretrained_model run — TODO confirm those helpers handle the prefix.\n",
    "if torch.cuda.device_count() == 0:\n",
    "    print(\"No GPU available, using CPU\")\n",
    "elif torch.cuda.device_count() == 1:\n",
    "    print(\"No parallelization, using single GPU\")\n",
    "else:\n",
    "    print(f\"Parallelizing model across {torch.cuda.device_count()} GPUs\")\n",
    "    model = torch.nn.DataParallel(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10d67592",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LR scheduler built from cfg; train_loader is passed through, presumably for per-step schedules.\n",
    "# NOTE(review): `scheduler` is not referenced again in this evaluation notebook.\n",
    "scheduler = create_scheduler(\n",
    "    optimizer,\n",
    "    cfg,\n",
    "    train_loader=train_loader,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb7f050",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One loss per output head, keyed by output-head name (energy / alpha_s / Q0).\n",
    "# BCEWithLogitsLoss applies the sigmoid internally (unlike plain nn.BCELoss), so the energy head\n",
    "# is presumably expected to emit raw logits — TODO confirm against the model definition.\n",
    "criterion = dict(\n",
    "    energy_loss_output=nn.BCEWithLogitsLoss(),\n",
    "    alpha_output=nn.CrossEntropyLoss(),\n",
    "    q0_output=nn.CrossEntropyLoss(),\n",
    ")\n",
    "print(f\"[INFO] Loss functions:{criterion}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdaac3c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch history trackers (energy, alpha_s, Q0, total); each name gets its own fresh list.\n",
    "print(\"[INFO] Init Training Trackers\")\n",
    "train_loss_energy_list, train_loss_alpha_list, train_loss_q0_list, train_loss_list = ([] for _ in range(4))\n",
    "train_acc_energy_list, train_acc_alpha_list, train_acc_q0_list, train_acc_list = ([] for _ in range(4))\n",
    "\n",
    "print(\"[INFO] Init Validation Trackers\")\n",
    "val_loss_energy_list, val_loss_alpha_list, val_loss_q0_list, val_loss_list = ([] for _ in range(4))\n",
    "val_acc_energy_list, val_acc_alpha_list, val_acc_q0_list, val_acc_list = ([] for _ in range(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c9f46d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resume bookkeeping: returns the (possibly restored) model and optimizer plus the training state.\n",
    "(\n",
    "    model, optimizer,\n",
    "    start_epoch, best_acc, early_stop_counter, best_epoch, best_metrics,\n",
    "    training_summary, all_epoch_metrics, summary_status,\n",
    ") = init_resume_state(model, optimizer, device, cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5d28f23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replay the stored per-epoch history into the tracker lists.\n",
    "# NOTE(review): the return value is not captured, so the lists are presumably filled in place.\n",
    "fill_trackers_from_history(\n",
    "    all_epoch_metrics,\n",
    "    # training losses, then accuracies (energy, alpha_s, Q0, total)\n",
    "    train_loss_energy_list, train_loss_alpha_list, train_loss_q0_list, train_loss_list,\n",
    "    train_acc_energy_list, train_acc_alpha_list, train_acc_q0_list, train_acc_list,\n",
    "    # validation losses, then accuracies (energy, alpha_s, Q0, total)\n",
    "    val_loss_energy_list, val_loss_alpha_list, val_loss_q0_list, val_loss_list,\n",
    "    val_acc_energy_list, val_acc_alpha_list, val_acc_q0_list, val_acc_list,\n",
    "    summary_status,\n",
    "    best_epoch,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29a29f7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load pretrained weights per cfg; `preloaded` presumably indicates whether any were loaded (TODO confirm).\n",
    "model, preloaded = load_pretrained_model(model, device, cfg)\n",
    "if not preloaded:\n",
    "    # Make a silent fallback visible: the evaluation below is only meaningful on trained weights.\n",
    "    print(\"[WARN] load_pretrained_model did not preload weights; evaluating the model as returned by init_resume_state.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0340a601",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class labels for the alpha_s and Q0 heads, used to annotate the evaluation plots.\n",
    "# NOTE(review): label order is assumed to match each head's class indices — TODO confirm.\n",
    "alpha_names = tuple(f\"{a:.1f}\" for a in (0.2, 0.3, 0.4))       # (\"0.2\", \"0.3\", \"0.4\")\n",
    "q0_names    = tuple(f\"{q:.1f}\" for q in (1.0, 1.5, 2.0, 2.5))  # (\"1.0\", \"1.5\", \"2.0\", \"2.5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9db6feb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from train_utils.evaluate import evaluate\n",
    "\n",
    "# Evaluate on the validation split and write probability plots / CSVs under <output_dir>/val/prob_plots.\n",
    "# NOTE(review): assumes evaluate() switches the model to eval mode and disables grads — TODO confirm.\n",
    "art_dir = os.path.join(cfg.output_dir, \"val\", \"prob_plots\")  # separate components: portable separators\n",
    "os.makedirs(art_dir, exist_ok=True)\n",
    "\n",
    "# File stems only: evaluate adds the .png/.pdf extensions.\n",
    "alpha_hist_path    = os.path.join(art_dir, \"alpha_pred_hist\")\n",
    "alpha_avgprob_path = os.path.join(art_dir, \"alpha_avgprob\")\n",
    "q0_avgprob_path    = os.path.join(art_dir, \"q0_avgprob\")\n",
    "\n",
    "metrics = evaluate(\n",
    "    val_loader, model, criterion, device,\n",
    "    loss_weights=getattr(cfg, \"loss_weights\", None),\n",
    "    # plots (os.path.join already returns str, so no str() wrapping is needed):\n",
    "    make_alpha_fig=True,\n",
    "    alpha_fig_path=alpha_hist_path,\n",
    "    make_alpha_avgprob_fig=True,\n",
    "    alpha_avgprob_fig_path=alpha_avgprob_path,\n",
    "    make_q0_avgprob_fig=True,\n",
    "    q0_avgprob_fig_path=q0_avgprob_path,\n",
    "    alpha_class_names=alpha_names,\n",
    "    q0_class_names=q0_names,\n",
    ")\n",
    "\n",
    "# Report the artifact paths evaluate returned (None when a key is missing from metrics).\n",
    "print(\"Saved images:\")\n",
    "print(\"  α_s histogram:        \", metrics.get(\"alpha_hist_path\"))\n",
    "print(\"  α_s avg-prob bars:    \", metrics.get(\"alpha_avgprob_hist_path\"))\n",
    "print(\"  α_s probabilities:     \", metrics.get(\"alpha_probs_csv\"))\n",
    "print(\"  α_s heatmap:          \", metrics.get(\"alpha_heatmap_path\"))\n",
    "print(\"  Q0  avg-prob bars:    \", metrics.get(\"q0_avgprob_hist_path\"))\n",
    "print(\"  Q0  probabilities:     \", metrics.get(\"q0_probs_csv\"))\n",
    "print(\"  Q0  heatmap:          \", metrics.get(\"q0_heatmap_path\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ba8cce0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled variant: validation-set evaluation producing only the Q0 avg-prob plot\n",
    "# (under <output_dir>/val/hists). Uncomment to run; note that it rebinds `metrics`.\n",
    "# from train_utils.evaluate import evaluate\n",
    "# metrics = evaluate(\n",
    "#     val_loader, model, criterion, device,\n",
    "#     make_q0_avgprob_fig=True,\n",
    "#     q0_avgprob_fig_path=os.path.join(cfg.output_dir, \"val/hists/q0_avgprob_per_true_bin\"),\n",
    "#     q0_class_names=(\"1.0\",\"1.5\",\"2.0\",\"2.5\")\n",
    "# )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59438e11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save confusion matrices for the validation metrics under <output_dir>/val.\n",
    "# Built with os.path.join, consistent with the other output paths in this notebook.\n",
    "plot_confusion_matrices(metrics, output_dir=os.path.join(cfg.output_dir, \"val\"), color_map=\"Oranges\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31a6ca54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation summary: Q0 plot location plus per-head and total accuracy / loss.\n",
    "print(\"Saved:\", metrics.get(\"q0_avgprob_hist_path\"))\n",
    "print(\n",
    "    f\"[INFO] Energy Acc ={metrics['energy']['accuracy']:.4f}, \"\n",
    "    f\"αs Acc = {metrics['alpha']['accuracy']:.4f}, \"\n",
    "    f\"Q0 Acc = {metrics['q0']['accuracy']:.4f}, \"\n",
    "    f\"Total Acc = {metrics['accuracy']:.4f}\"\n",
    ")\n",
    "print(\n",
    "    f\"[INFO] Energy Loss ={metrics['loss_energy']:.4f}, \"\n",
    "    f\"αs Loss = {metrics['loss_alpha']:.4f}, \"\n",
    "    f\"Q0 Loss = {metrics['loss_q0']:.4f}, \"\n",
    "    f\"Total Loss = {metrics['loss']:.4f}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e81253b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the summary for this evaluation-only run into <output_dir>/val. No training happens here,\n",
    "# so best epoch / best accuracy are recorded as \"NA\" and the validation metrics fill best_metrics.\n",
    "finalize_training_summary(\n",
    "    summary=training_summary,\n",
    "    best_metrics=metrics,\n",
    "    best_epoch=\"NA\",\n",
    "    best_acc=\"NA\",\n",
    "    output_dir=os.path.join(cfg.output_dir, \"val\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47e0f6e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: test-set evaluation (Q0 avg-prob plot under <output_dir>/test/hists) plus metric printout.\n",
    "# Uncomment to run; note that it rebinds `metrics`.\n",
    "# from train_utils.evaluate import evaluate\n",
    "# metrics = evaluate(\n",
    "#     test_loader, model, criterion, device,\n",
    "#     make_q0_avgprob_fig=True,\n",
    "#     q0_avgprob_fig_path=os.path.join(cfg.output_dir, \"test/hists/q0_avgprob_per_true_bin\"),\n",
    "#     q0_class_names=(\"1.0\",\"1.5\",\"2.0\",\"2.5\")\n",
    "# )\n",
    "# print(\"Saved:\", metrics.get(\"q0_avgprob_hist_path\"))\n",
    "# print(f\"[INFO] Energy Acc ={metrics['energy']['accuracy']:.4f}, αs Acc = {metrics['alpha']['accuracy']:.4f}, Q0 Acc = {metrics['q0']['accuracy']:.4f}, Total Acc = {metrics['accuracy']:.4f}\")\n",
    "# print(f\"[INFO] Energy Loss ={metrics['loss_energy']:.4f}, αs Loss = {metrics['loss_alpha']:.4f}, Q0 Loss = {metrics['loss_q0']:.4f}, Total Loss = {metrics['loss']:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba0bd8da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: write the training summary for the current `metrics` into <output_dir>/test.\n",
    "# finalize_training_summary(\n",
    "#     summary=training_summary,\n",
    "#     best_epoch=\"NA\",\n",
    "#     best_acc=\"NA\",\n",
    "#     best_metrics=metrics,\n",
    "#     output_dir=os.path.join(cfg.output_dir, \"test\")\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d99e96fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: validation-set evaluation producing the alpha_s avg-prob plot (under <output_dir>/val/hists),\n",
    "# stored as `val_metrics`, followed by a metric printout.\n",
    "# from train_utils.evaluate import evaluate\n",
    "# val_metrics = evaluate(\n",
    "#     val_loader, model, criterion, device,\n",
    "#     make_alpha_avgprob_fig=True, \n",
    "#     alpha_avgprob_fig_path= os.path.join(cfg.output_dir, \"val/hists/alpha_avgprob_per_true_bin\"),\n",
    "#     alpha_class_names=(\"0.2\",\"0.3\",\"0.4\")\n",
    "# )\n",
    "# print(f\"[INFO] Energy Acc ={val_metrics['energy']['accuracy']:.4f}, αs Acc = {val_metrics['alpha']['accuracy']:.4f}, Q0 Acc = {val_metrics['q0']['accuracy']:.4f}, Total Acc = {val_metrics['accuracy']:.4f}\")\n",
    "# print(f\"[INFO] Energy Loss ={val_metrics['loss_energy']:.4f}, αs Loss = {val_metrics['loss_alpha']:.4f}, Q0 Loss = {val_metrics['loss_q0']:.4f}, Total Loss = {val_metrics['loss']:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c28287e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: confusion matrices for `val_metrics` under <output_dir>/val.\n",
    "# plot_confusion_matrices(val_metrics, output_dir=os.path.join(cfg.output_dir, \"val\"), color_map=\"Oranges\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "488bf3f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: write the training summary for `val_metrics` into <output_dir>/val.\n",
    "# finalize_training_summary(\n",
    "#     summary=training_summary,\n",
    "#     best_epoch=\"NA\",\n",
    "#     best_acc=\"NA\",\n",
    "#     best_metrics=val_metrics,\n",
    "#     output_dir=os.path.join(cfg.output_dir, \"val\")\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "003c192a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: test-set alpha_s evaluation (histogram + avg-prob plots under <output_dir>/test/hists),\n",
    "# stored as `test_metrics`. Both printouts read test_metrics, so accuracy and loss come from the same split.\n",
    "# from train_utils.evaluate import evaluate\n",
    "# test_metrics = evaluate(\n",
    "#     test_loader, model, criterion, device,\n",
    "#     make_alpha_fig=True, alpha_fig_path=os.path.join(cfg.output_dir, \"test/hists/alpha_hist_per_true_bin\"), alpha_class_names=(\"0.2\",\"0.3\",\"0.4\"),\n",
    "#     make_alpha_avgprob_fig=True,   alpha_avgprob_fig_path= os.path.join(cfg.output_dir, \"test/hists/alpha_avgprob_per_true_bin\")\n",
    "# )\n",
    "# print(f\"[INFO] Energy Acc ={test_metrics['energy']['accuracy']:.4f}, αs Acc = {test_metrics['alpha']['accuracy']:.4f}, Q0 Acc = {test_metrics['q0']['accuracy']:.4f}, Total Acc = {test_metrics['accuracy']:.4f}\")\n",
    "# print(f\"[INFO] Energy Loss ={test_metrics['loss_energy']:.4f}, αs Loss = {test_metrics['loss_alpha']:.4f}, Q0 Loss = {test_metrics['loss_q0']:.4f}, Total Loss = {test_metrics['loss']:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "352cf993",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: confusion matrices for `test_metrics` under <output_dir>/test.\n",
    "# plot_confusion_matrices(test_metrics, output_dir=os.path.join(cfg.output_dir, \"test\"), color_map=\"Oranges\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "698b395a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: write the training summary for `test_metrics` into <output_dir>/test.\n",
    "# finalize_training_summary(\n",
    "#     summary=training_summary,\n",
    "#     best_epoch=\"NA\",\n",
    "#     best_acc=\"NA\",\n",
    "#     best_metrics=test_metrics,\n",
    "#     output_dir=os.path.join(cfg.output_dir, \"test\")\n",
    "# )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
