{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from config import get_config\n",
    "from train_utils.gpu_utils import get_device_summary\n",
    "from data.loader import get_dataloaders\n",
    "from models.model import create_model\n",
    "from train_utils.resume import init_resume_state\n",
    "from train_utils.resume import fill_trackers_from_history\n",
    "from train_utils.resume import load_pretrained_model\n",
    "from train_utils.training_loop import run_training_loop\n",
    "from train_utils.scheduler_utils import create_scheduler\n",
    "from train_utils.training_summary import finalize_training_summary\n",
    "from train_utils.training_summary import print_best_model_summary\n",
    "from train_utils.plot_metrics import plot_train_val_metrics\n",
    "from train_utils.plot_metrics import plot_loss_accuracy\n",
    "from train_utils.plot_metrics import plot_confusion_matrices\n",
    "from train_utils.evaluate import evaluate\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] Config Path: /experiments/exp_loss_weight_sweep/configs/convnext_gaussian_bs32_ep50_lr1e-04_ds7200000_g500_sched-RLRP__lw-S1_balanced.yml\n",
      "[INFO] Detected native Ubuntu host: DS044955\n",
      "[INFO] Using dataset root: /home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\n",
      "[INFO] Using dataset_size from config: 7200000\n",
      "{\n",
      "  \"model_tag\": \"ConvNeXt_Gaussian_g500__lw-S1_balanced\",\n",
      "  \"backbone\": \"convnext_gaussian\",\n",
      "  \"batch_size\": 32,\n",
      "  \"epochs\": 50,\n",
      "  \"learning_rate\": 0.0001,\n",
      "  \"patience\": 12,\n",
      "  \"input_shape\": [\n",
      "    1,\n",
      "    32,\n",
      "    32\n",
      "  ],\n",
      "  \"global_max\": 121.79151153564453,\n",
      "  \"dataset_root_dir\": \"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/\",\n",
      "  \"train_csv\": \"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/file_labels_aggregated_ds7200000_g500_val_folds_out/fold1_train.csv\",\n",
      "  \"val_csv\": \"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/file_labels_aggregated_ds7200000_g500_val_folds_out/fold1_val.csv\",\n",
      "  \"test_csv\": \"/home/johndoe/Projects/110_JetscapeML/hm_jetscapeml_source/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_7200000_balanced_unshuffled/file_labels_aggregated_ds7200000_g500_val_folds_out/fold1_val.csv\",\n",
      "  \"output_dir\": \"experiments/exp_loss_weight_sweep/training_output/ConvNeXt_Gaussian_g500__lw-S1_balanced_bs32_ep50_lr1e-04_ds7200000_g500_sched_ReduceLROnPlateau_preloaded_weighted_lossenergy_loss_output_1.0_alpha_output_1.0_q0_output_1.0\",\n",
      "  \"group_size\": 500,\n",
      "  \"scheduler\": {\n",
      "    \"type\": \"ReduceLROnPlateau\",\n",
      "    \"mode\": \"max\",\n",
      "    \"factor\": 0.5,\n",
      "    \"patience\": 4,\n",
      "    \"verbose\": true\n",
      "  },\n",
      "  \"dataset_size\": 7200000,\n",
      "  \"preload_model_path\": \"experiments/exp_best_trained_models/training_output/ConvNeXt_tiny_Gaussian_g500_bs32_ep50_lr1e-04_ds7200000_g500_sched_ReduceLROnPlateau/best_model.pth\",\n",
      "  \"loss_weights\": {\n",
      "    \"energy_loss_output\": 1.0,\n",
      "    \"alpha_output\": 1.0,\n",
      "    \"q0_output\": 1.0\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Load the experiment configuration and echo it for the run log.\n",
    "# get_config() is called without a path; its stdout (below) reports which config\n",
    "# file was loaded. To pin one explicitly, pass config_path, e.g.\n",
    "#   cfg = get_config(config_path=\"/experiments/exp_loss_weight_sweep/configs/\"\n",
    "#                    \"convnext_gaussian_bs32_ep50_lr1e-04_ds7200000_g500_sched-RLRP__lw-S1_balanced.yml\")\n",
    "# Earlier experiments used other model factories exposing the same create_model name\n",
    "# (models.model_vit, models.model_mamba,\n",
    "#  experiments.exp_mamaba_vit_stack.models.hybrid_mamba_vit); when switching to one of\n",
    "# those, change the create_model import in the first cell to match.\n",
    "\n",
    "cfg = get_config()\n",
    "print(json.dumps(vars(cfg), indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plots and the training summary produced below are all written under cfg.output_dir.\n",
    "os.makedirs(name=cfg.output_dir, exist_ok=True)\n",
    "print(\"[INFO] Saving all outputs to: {}\".format(cfg.output_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the torch device for this run (get_device_summary presumably also prints\n",
    "# a hardware overview — see train_utils.gpu_utils).\n",
    "device = get_device_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data: one loader per split (presumably built from cfg.train_csv / val_csv / test_csv).\n",
    "(train_loader,\n",
    " val_loader,\n",
    " test_loader) = get_dataloaders(cfg, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and optimizer for the backbone named in the config.\n",
    "model, optimizer = create_model(cfg.backbone, cfg.input_shape, cfg.learning_rate)\n",
    "# nn.Module.to() moves parameters in place and returns the module itself; assigning the\n",
    "# result keeps the cell from rendering the full (very long) model repr as its output.\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model in DataParallel only when more than one GPU is visible.\n",
    "# NOTE(review): the wrap happens before init_resume_state / load_pretrained_model run;\n",
    "# DataParallel prefixes state_dict keys with 'module.', so confirm those helpers handle it.\n",
    "n_gpus = torch.cuda.device_count()\n",
    "if n_gpus == 0:\n",
    "    print(\"No GPU available, using CPU\")\n",
    "elif n_gpus == 1:\n",
    "    print(\"No parallelization, using single GPU\")\n",
    "else:\n",
    "    print(f\"Parallelizing model across {n_gpus} GPUs\")\n",
    "    model = torch.nn.DataParallel(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LR scheduler as described by cfg.scheduler. train_loader is handed over as well\n",
    "# (presumably for schedulers that need the number of steps per epoch).\n",
    "scheduler = create_scheduler(optimizer, cfg, train_loader=train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One loss per output head, keyed by head name.\n",
    "# BCEWithLogitsLoss applies the sigmoid itself, so the energy-loss head must emit raw\n",
    "# logits (the earlier nn.BCELoss variant expected probabilities instead).\n",
    "criterion = dict(\n",
    "    energy_loss_output=nn.BCEWithLogitsLoss(),\n",
    "    alpha_output=nn.CrossEntropyLoss(),\n",
    "    q0_output=nn.CrossEntropyLoss(),\n",
    ")\n",
    "print(f\"[INFO] Loss functions:{criterion}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch history: loss and accuracy for each head (energy, alpha, q0) plus an\n",
    "# overall (unsuffixed) list. Each name gets its own fresh, independent list.\n",
    "print(\"[INFO] Init Training Trackers\")\n",
    "train_loss_energy_list, train_loss_alpha_list, train_loss_q0_list, train_loss_list = ([] for _ in range(4))\n",
    "train_acc_energy_list, train_acc_alpha_list, train_acc_q0_list, train_acc_list = ([] for _ in range(4))\n",
    "\n",
    "print(\"[INFO] Init Validation Trackers\")\n",
    "val_loss_energy_list, val_loss_alpha_list, val_loss_q0_list, val_loss_list = ([] for _ in range(4))\n",
    "val_acc_energy_list, val_acc_alpha_list, val_acc_q0_list, val_acc_list = ([] for _ in range(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resume bookkeeping: starting epoch, best-so-far tracking and any saved per-epoch\n",
    "# history (presumably restored from a previous run in cfg.output_dir — see train_utils.resume).\n",
    "(model, optimizer,\n",
    " start_epoch, best_acc, early_stop_counter, best_epoch, best_metrics,\n",
    " training_summary, all_epoch_metrics, summary_status) = init_resume_state(model, optimizer, device, cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-populate the tracker lists from the restored history so later plots can cover the\n",
    "# whole run rather than only this session (presumably appends to the lists in place).\n",
    "fill_trackers_from_history(\n",
    "    all_epoch_metrics,\n",
    "    # training losses, then accuracies\n",
    "    train_loss_energy_list, train_loss_alpha_list, train_loss_q0_list, train_loss_list,\n",
    "    train_acc_energy_list, train_acc_alpha_list, train_acc_q0_list, train_acc_list,\n",
    "    # validation losses, then accuracies\n",
    "    val_loss_energy_list, val_loss_alpha_list, val_loss_q0_list, val_loss_list,\n",
    "    val_acc_energy_list, val_acc_alpha_list, val_acc_q0_list, val_acc_list,\n",
    "    summary_status,\n",
    "    best_epoch,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally initialize weights from a pretrained checkpoint (the printed config carries\n",
    "# a preload_model_path). `preloaded` is not used further in this notebook.\n",
    "# NOTE(review): this runs after init_resume_state; confirm load_pretrained_model skips\n",
    "# preloading when a run is being resumed, otherwise resumed weights would be overwritten.\n",
    "model, preloaded = load_pretrained_model(model, device, cfg)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manual smoke test: run a single training epoch before launching the full loop.\n",
    "# NOTE: train_one_epoch is not imported anywhere in this notebook — add its import\n",
    "# (presumably from train_utils) before uncommenting the lines below.\n",
    "# train_metrics = train_one_epoch(train_loader, model, criterion, optimizer, device)\n",
    "# print(f\"[INFO] Training metrics: {train_metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main training loop. The tracker lists are passed in (presumably appended to in place);\n",
    "# the returned best_* values replace the ones produced by init_resume_state.\n",
    "best_epoch, best_acc, best_metrics = run_training_loop(\n",
    "    cfg, train_loader, val_loader,\n",
    "    device, model, criterion,\n",
    "    optimizer, scheduler,\n",
    "    start_epoch, early_stop_counter,\n",
    "    best_acc, best_metrics, best_epoch,\n",
    "    # training trackers: overall, energy, alpha, q0 loss; then the same for accuracy\n",
    "    train_loss_list, train_loss_energy_list, train_loss_alpha_list, train_loss_q0_list,\n",
    "    train_acc_list, train_acc_energy_list, train_acc_alpha_list, train_acc_q0_list,\n",
    "    # validation trackers, same order\n",
    "    val_loss_list, val_loss_energy_list, val_loss_alpha_list, val_loss_q0_list,\n",
    "    val_acc_list, val_acc_energy_list, val_acc_alpha_list, val_acc_q0_list,\n",
    "    all_epoch_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record the best epoch in the training summary (written under cfg.output_dir),\n",
    "# then echo the same best-model figures to stdout.\n",
    "best_kwargs = dict(best_epoch=best_epoch, best_acc=best_acc, best_metrics=best_metrics)\n",
    "finalize_training_summary(summary=training_summary, output_dir=cfg.output_dir, **best_kwargs)\n",
    "print_best_model_summary(**best_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overview: overall train/val loss and accuracy.\n",
    "plot_train_val_metrics(train_loss_list, val_loss_list, train_acc_list, val_acc_list, cfg.output_dir)\n",
    "\n",
    "# Per-head detail for each split: overall, energy, alpha, q0 — losses first, then accuracies.\n",
    "plot_loss_accuracy(\n",
    "    train_loss_list, train_loss_energy_list, train_loss_alpha_list, train_loss_q0_list,\n",
    "    train_acc_list, train_acc_energy_list, train_acc_alpha_list, train_acc_q0_list,\n",
    "    cfg.output_dir,\n",
    "    title=\"Train Loss and Accuracy per Epoch\",\n",
    ")\n",
    "plot_loss_accuracy(\n",
    "    val_loss_list, val_loss_energy_list, val_loss_alpha_list, val_loss_q0_list,\n",
    "    val_acc_list, val_acc_energy_list, val_acc_alpha_list, val_acc_q0_list,\n",
    "    cfg.output_dir,\n",
    "    title=\"Validation Loss and Accuracy per Epoch\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confusion matrices from best_metrics (as returned by run_training_loop), saved under <output_dir>/val.\n",
    "val_output_dir = f\"{cfg.output_dir}/val\"\n",
    "plot_confusion_matrices(best_metrics, output_dir=val_output_dir, color_map=\"Oranges\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class labels passed to evaluate() as alpha_class_names / q0_class_names.\n",
    "# NOTE(review): the order must match the class indices of each head. The dataset folder\n",
    "# name lists q0 values 1.5 / 2.0 / 2.5 only — confirm that \"1.0\" is really a q0 class here.\n",
    "alpha_names = (\"0.2\", \"0.3\", \"0.4\")\n",
    "q0_names = (\"1.0\", \"1.5\", \"2.0\", \"2.5\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation-split evaluation with per-class probability plots.\n",
    "# NOTE(review): evaluate() runs on `model` as left by run_training_loop. If that is the\n",
    "# last-epoch model rather than the best checkpoint, these numbers will differ from\n",
    "# best_metrics — confirm whether run_training_loop restores the best weights.\n",
    "\n",
    "# experiment folder for artifacts\n",
    "art_dir = os.path.join(cfg.output_dir, \"val/prob_plots\")\n",
    "os.makedirs(art_dir, exist_ok=True)\n",
    "\n",
    "# Path stems only: evaluate() adds the .png/.pdf extensions.\n",
    "# os.path.join already returns str, so no str() wrapping is needed.\n",
    "alpha_hist_path    = os.path.join(art_dir, \"alpha_pred_hist\")\n",
    "alpha_avgprob_path = os.path.join(art_dir, \"alpha_avgprob\")\n",
    "q0_avgprob_path    = os.path.join(art_dir, \"q0_avgprob\")\n",
    "\n",
    "metrics_val = evaluate(\n",
    "    val_loader, model, criterion, device,\n",
    "    loss_weights=getattr(cfg, \"loss_weights\", None),\n",
    "    # plots:\n",
    "    make_alpha_fig=True,\n",
    "    alpha_fig_path=alpha_hist_path,\n",
    "    make_alpha_avgprob_fig=True,\n",
    "    alpha_avgprob_fig_path=alpha_avgprob_path,\n",
    "    make_q0_avgprob_fig=True,\n",
    "    q0_avgprob_fig_path=q0_avgprob_path,\n",
    "    alpha_class_names=alpha_names,\n",
    "    q0_class_names=q0_names,\n",
    ")\n",
    "\n",
    "# (label, metrics key) for every artifact path evaluate() reports back.\n",
    "print(\"Saved images:\")\n",
    "for label, key in (\n",
    "    (\"  α_s histogram:        \", \"alpha_hist_path\"),\n",
    "    (\"  α_s avg-prob bars:    \", \"alpha_avgprob_hist_path\"),\n",
    "    (\"  α_s probabilities:     \", \"alpha_probs_csv\"),\n",
    "    (\"  α_s heatmap:          \", \"alpha_heatmap_path\"),\n",
    "    (\"  Q0  avg-prob bars:    \", \"q0_avgprob_hist_path\"),\n",
    "    (\"  Q0  probabilities:     \", \"q0_probs_csv\"),\n",
    "    (\"  Q0  heatmap:          \", \"q0_heatmap_path\"),\n",
    "):\n",
    "    print(label, metrics_val.get(key))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (optional) Evaluate the test split with the same probability plots as validation.\n",
    "# Skipped (have_test = False) when no test loader exists. Unlike the former blanket\n",
    "# `except NameError`, errors raised inside evaluate() itself now surface instead of\n",
    "# silently marking the test split as missing.\n",
    "_test_loader = globals().get(\"test_loader\")\n",
    "have_test = _test_loader is not None\n",
    "\n",
    "if not have_test:\n",
    "    print(\"[INFO] No test_loader available; skipping test-split evaluation.\")\n",
    "else:\n",
    "    # If test and validation read the same CSV, the 'test' numbers just repeat validation.\n",
    "    _test_csv = getattr(cfg, \"test_csv\", None)\n",
    "    if _test_csv is not None and _test_csv == getattr(cfg, \"val_csv\", None):\n",
    "        print(\"[WARN] cfg.test_csv is the same file as cfg.val_csv; test metrics duplicate validation.\")\n",
    "\n",
    "    # experiment folder for artifacts\n",
    "    art_dir = os.path.join(cfg.output_dir, \"test/prob_plots\")\n",
    "    os.makedirs(art_dir, exist_ok=True)\n",
    "\n",
    "    # Path stems only: evaluate() adds the .png/.pdf extensions.\n",
    "    alpha_hist_path    = os.path.join(art_dir, \"alpha_pred_hist\")\n",
    "    alpha_avgprob_path = os.path.join(art_dir, \"alpha_avgprob\")\n",
    "    q0_avgprob_path    = os.path.join(art_dir, \"q0_avgprob\")\n",
    "\n",
    "    metrics_test = evaluate(\n",
    "        _test_loader, model, criterion, device,\n",
    "        loss_weights=getattr(cfg, \"loss_weights\", None),\n",
    "        # plots:\n",
    "        make_alpha_fig=True,\n",
    "        alpha_fig_path=alpha_hist_path,\n",
    "        make_alpha_avgprob_fig=True,\n",
    "        alpha_avgprob_fig_path=alpha_avgprob_path,\n",
    "        make_q0_avgprob_fig=True,\n",
    "        q0_avgprob_fig_path=q0_avgprob_path,\n",
    "        alpha_class_names=alpha_names,\n",
    "        q0_class_names=q0_names,\n",
    "    )\n",
    "\n",
    "    # (label, metrics key) for every artifact path evaluate() reports back.\n",
    "    print(\"Saved images:\")\n",
    "    for label, key in (\n",
    "        (\"  α_s histogram:        \", \"alpha_hist_path\"),\n",
    "        (\"  α_s avg-prob bars:    \", \"alpha_avgprob_hist_path\"),\n",
    "        (\"  α_s probabilities:     \", \"alpha_probs_csv\"),\n",
    "        (\"  α_s heatmap:          \", \"alpha_heatmap_path\"),\n",
    "        (\"  Q0  avg-prob bars:    \", \"q0_avgprob_hist_path\"),\n",
    "        (\"  Q0  probabilities:     \", \"q0_probs_csv\"),\n",
    "        (\"  Q0  heatmap:          \", \"q0_heatmap_path\"),\n",
    "    ):\n",
    "        print(label, metrics_test.get(key))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
