{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c1f9c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Auto-reload edited project modules (configs/, utils/) before each cell executes.\n",
    "# %reload_ext (unlike %load_ext) does not print an 'already loaded' notice when this cell is re-run.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5012e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "import numpy as np\n",
    "from avalanche.models import IncrementalClassifier\n",
    "from collections import defaultdict\n",
    "from configs import default\n",
    "from utils.buffers import CalibrationBuffer\n",
    "from utils.dats import DATS\n",
    "from utils.data import load_benchmark\n",
    "from utils.calibration import one_hot_encode\n",
    "from utils.distance import compute_scores, select_representative_classes, assign_scores\n",
    "from utils.training import get_free_gpu_idx, evaluate, set_seed, compute_metrics, extract_features, load_trained_model, compute_nll\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2381ccbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parsed project defaults; only `augmentation` is read in this notebook.\n",
    "args = default.get_args('')\n",
    "\n",
    "# Benchmark and backbone\n",
    "DATASET = 'tinyimagenet'            # 'cifar10', 'cifar100', 'tinyimagenet', 'bloodmnist', 'dermamnist'\n",
    "MODEL = 'slimresnet'                # 'slimresnet', 'resnet32'\n",
    "AUGMENTATION = args.augmentation    # 'auto', 'standard'\n",
    "\n",
    "# Strategy labels: used in log lines and as the result key f'{CALIBRATION_STRATEGY}_{CL_STRATEGY}'\n",
    "CL_STRATEGY = 'prototype'\n",
    "CALIBRATION_STRATEGY = 'dats'\n",
    "\n",
    "# Passed as `threshold` to select_representative_classes in the calibration loop\n",
    "THRESHOLD = 0.6\n",
    "# NOTE(review): use_tta is not referenced by any other cell of this notebook\n",
    "use_tta = False\n",
    "\n",
    "# Use the GPU index reported by get_free_gpu_idx() when CUDA is available, otherwise the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(f\"cuda:{get_free_gpu_idx()}\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# calibration_res_dict[dataset][method_key] -> metrics dict (filled by the experiment loop)\n",
    "calibration_res_dict = defaultdict(dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2eed85e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds 0..N_SEEDS-1 are run independently; the summary cell reports mean ± std across them.\n",
    "N_SEEDS = 3\n",
    "\n",
    "# Per-seed histories. Each entry holds the results computed for the LAST validation experience\n",
    "# of that seed (the *_run lists below are reset for every experience).\n",
    "before_ece_history = []\n",
    "after_ece_history = []\n",
    "temperature_history = []\n",
    "acc_history = []\n",
    "calibration_times = []  # one timing per validation experience, for every seed\n",
    "logits_history = []\n",
    "\n",
    "for SEED in range(N_SEEDS):\n",
    "    print(f\"Running {CALIBRATION_STRATEGY}_{CL_STRATEGY} for seed {SEED} on {MODEL}-{DATASET}\")\n",
    "    set_seed(SEED)\n",
    "    calibration_buffer = CalibrationBuffer(benchmark_name=DATASET, mode='balanced')\n",
    "    benchmark, n_classes, train_transform, eval_transform = load_benchmark(benchmark_name=DATASET, augmentation=AUGMENTATION, seed=SEED)\n",
    "\n",
    "    calibration_times_run = []\n",
    "    for experience_val in benchmark.valid_stream:\n",
    "        current_task_id = experience_val.current_experience\n",
    "        current_classes = experience_val.classes_in_this_experience\n",
    "        print(f\"Current Classes: {current_classes}\")\n",
    "        # Model setup (not timed): load the checkpoint for `current_task_id`, then rebuild the\n",
    "        # classifier head with `classes_so_far` outputs so the saved head weights fit.\n",
    "        trained_model, checkpoint, classes_so_far = load_trained_model(MODEL, DATASET, SEED, device, current_task_id)\n",
    "        trained_model.output = IncrementalClassifier(trained_model.output.in_features, initial_out_features=classes_so_far)\n",
    "        trained_model.load_state_dict(checkpoint[\"state_dict\"])\n",
    "        # Explicit eval() so buffer scoring, DATS fitting and test inference all use inference-mode\n",
    "        # layers (e.g. BatchNorm, if present) regardless of what the helper functions do.\n",
    "        trained_model = trained_model.to(device).eval()\n",
    "\n",
    "        # Timed span: buffer update + scoring, DATS fit, and the whole test loop below\n",
    "        # (test inference and before/after metrics included). perf_counter is monotonic.\n",
    "        calibration_start_time = time.perf_counter()\n",
    "        calibration_buffer.update(experience_val)\n",
    "        calibration_buffer.update_confidence_scores(trained_model, device)\n",
    "        buffer_logits, buffer_labels, buffer_scores, buffer_class_means, class_distances = compute_scores(trained_model, experience_val, calibration_buffer, device)\n",
    "\n",
    "        # Fit DATS on the calibration-buffer logits, one-hot labels and per-sample scores\n",
    "        dats = DATS(ood_values_num=1)\n",
    "        dats.optimize(buffer_logits.numpy(), one_hot_encode(buffer_labels, buffer_logits.shape[1]), [buffer_scores.numpy()])\n",
    "\n",
    "        acc_run = []\n",
    "        before_ece_run = []\n",
    "        after_ece_run = []\n",
    "        temperature_run = []\n",
    "        logits_run = []\n",
    "        # Evaluate and calibrate on every task seen so far\n",
    "        for experience_te in benchmark.test_stream[:current_task_id + 1]:\n",
    "            test_task_id = experience_te.current_experience\n",
    "            test_classes = experience_te.classes_in_this_experience\n",
    "\n",
    "            test_logits, test_labels = evaluate(trained_model, experience_te.dataset, device)\n",
    "            test_features = extract_features(trained_model, experience_te.dataset, device)\n",
    "            assigned_classes = select_representative_classes(test_features, buffer_class_means, threshold=THRESHOLD)\n",
    "            print(f'Task {test_task_id} - Current classes: {sorted(test_classes)} - Assigned Classes: {sorted(assigned_classes)}')\n",
    "            test_scores = assign_scores(test_labels.shape, assigned_classes, class_distances)\n",
    "\n",
    "            before_acc, before_ece, *_ = compute_metrics(test_logits, test_labels)\n",
    "            print(f'Task {test_task_id} - Accuracy: {before_acc:.4f}, Initial ECE: {before_ece:.4f}')\n",
    "\n",
    "            # Apply calibration\n",
    "            calibrated_np, t_list = dats.calibrate_before_softmax(test_logits.numpy(), [test_scores.numpy()])\n",
    "            calibrated_logits = torch.Tensor(calibrated_np)\n",
    "            after_acc, after_ece, *_ = compute_metrics(calibrated_logits, test_labels)\n",
    "            # np.mean accepts a list or an array, unlike t_list.mean()\n",
    "            avg_temperature = float(np.mean(t_list))\n",
    "            print(f'Task {test_task_id} - Accuracy: {after_acc:.4f}, Final ECE: {after_ece:.4f}, Avg Temperature: {avg_temperature:.4f}')\n",
    "\n",
    "            acc_run.append(after_acc)\n",
    "            before_ece_run.append(before_ece)\n",
    "            after_ece_run.append(after_ece)\n",
    "            temperature_run.append(avg_temperature)\n",
    "            logits_run.append([test_logits, calibrated_logits, test_labels])\n",
    "            print()\n",
    "\n",
    "        calibration_time = time.perf_counter() - calibration_start_time\n",
    "        calibration_times_run.append(calibration_time)\n",
    "\n",
    "        print(f'Task: {current_task_id} - Calibration time: {calibration_time:.2f} seconds')\n",
    "        print()\n",
    "\n",
    "    # An empty validation stream would leave the *_run lists undefined and divide by zero below\n",
    "    assert calibration_times_run, f\"No validation experiences for {DATASET} (seed {SEED})\"\n",
    "\n",
    "    before_ece_history.append(before_ece_run)\n",
    "    after_ece_history.append(after_ece_run)\n",
    "    temperature_history.append(temperature_run)\n",
    "    acc_history.append(acc_run)\n",
    "    calibration_times.append(calibration_times_run)\n",
    "    logits_history.append(logits_run)\n",
    "\n",
    "    total_calibration_time = sum(calibration_times_run)\n",
    "    avg_calibration_time = total_calibration_time / len(calibration_times_run)\n",
    "    print(f\"Seed {SEED} - Total calibration time: {total_calibration_time:.2f} seconds\")\n",
    "    print(f\"Seed {SEED} - Average calibration time per task: {avg_calibration_time:.2f} seconds\")\n",
    "    print(f\"Average accuracy: {sum(acc_run) / len(acc_run):.4f}\")\n",
    "    print(f\"Average ECE: {sum(after_ece_run) / len(after_ece_run):.4f} \\n\")\n",
    "\n",
    "calibration_res_dict[DATASET][f'{CALIBRATION_STRATEGY}_{CL_STRATEGY}'] = {\n",
    "    'before_ece': before_ece_history,\n",
    "    'after_ece': after_ece_history,\n",
    "    'acc': acc_history,\n",
    "    'temperature': temperature_history,\n",
    "    'calibration_times': calibration_times,\n",
    "    'logits': logits_history,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08065543",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate per-seed results: each metric is averaged over tasks within a seed, then\n",
    "# reported as mean ± std across seeds (accuracy/ECE in percent, time in seconds).\n",
    "dict_results = calibration_res_dict[DATASET][f'{CALIBRATION_STRATEGY}_{CL_STRATEGY}']\n",
    "\n",
    "def _mean_std(per_seed, scale=1.0):\n",
    "    \"\"\"Return (mean, std) of a 1-D per-seed tensor, each multiplied by `scale`.\n",
    "\n",
    "    torch's std() uses the unbiased (N-1) estimator, so a single seed yields NaN.\n",
    "    \"\"\"\n",
    "    return per_seed.mean().item() * scale, per_seed.std().item() * scale\n",
    "\n",
    "# Convert each metric once: rows are seeds, columns are the evaluated tasks\n",
    "after_ece_all = torch.Tensor(dict_results['after_ece'])\n",
    "before_ece_all = torch.Tensor(dict_results['before_ece'])\n",
    "acc_all = torch.Tensor(dict_results['acc'])\n",
    "calibration_times_all = torch.Tensor(dict_results['calibration_times'])\n",
    "\n",
    "after_ece_mean, after_ece_std = _mean_std(after_ece_all.mean(-1), 100)\n",
    "before_ece_mean, before_ece_std = _mean_std(before_ece_all.mean(-1), 100)\n",
    "acc_mean, acc_std = _mean_std(acc_all.mean(-1), 100)\n",
    "# Last = the final column, i.e. the most recently added task\n",
    "last_ece_mean, last_ece_std = _mean_std(after_ece_all[:, -1], 100)\n",
    "# Per-seed total calibration time, summed over experiences\n",
    "execution_time_mean, execution_time_std = _mean_std(calibration_times_all.sum(-1))\n",
    "delta_last_ece_mean, delta_last_ece_std = _mean_std(after_ece_all[:, -1] - before_ece_all[:, -1], 100)\n",
    "# NOTE(review): the two calls pass different flags (calibrated=False vs from_logits=True);\n",
    "# confirm in compute_nll that they pick the uncalibrated vs calibrated logits respectively.\n",
    "before_nll_mean, before_nll_std = compute_nll(dict_results, calibrated=False)\n",
    "after_nll_mean, after_nll_std = compute_nll(dict_results, from_logits=True)\n",
    "\n",
    "print(f\"Results for {CALIBRATION_STRATEGY}_{CL_STRATEGY} on {MODEL}-{DATASET}\")\n",
    "print(f\"Accuracy: {acc_mean:.2f} ± {acc_std:.2f}\")\n",
    "print(f\"Before ECE: {before_ece_mean:.2f} ± {before_ece_std:.2f}\")\n",
    "print(f\"After ECE: {after_ece_mean:.2f} ± {after_ece_std:.2f}\")\n",
    "print(f\"Last ECE: {last_ece_mean:.2f} ± {last_ece_std:.2f}\")\n",
    "print(f'Execution time: {execution_time_mean:.2f} ± {execution_time_std:.2f} sec.')\n",
    "print(f'Delta Last ECE: {delta_last_ece_mean:.2f} ± {delta_last_ece_std:.2f}')\n",
    "print(f'Before NLL: {before_nll_mean:.2f} ± {before_nll_std:.2f}')\n",
    "print(f'After NLL: {after_nll_mean:.2f} ± {after_nll_std:.2f}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cl-cal",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
