{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from utils.load_dataset import load_dataset\n",
    "import logging\n",
    "from azure_blob_storage import get_numpy_from_azure\n",
    "\n",
    "# Experiment configuration.\n",
    "dataset_name = 'imagenet'\n",
    "num_augs = 2  # augmentation score files read per seed by get_scores_and_masks\n",
    "logger = logging.getLogger()  # root logger; same object as logging.getLogger(\"\")\n",
    "\n",
    "# Augmentation and shuffling are disabled and normalisation uses mean 0 / std 1\n",
    "# (presumably an identity normalisation inside load_dataset -- TODO confirm).\n",
    "# NOTE(review): 'set path' is a placeholder; point root_path at the data before running.\n",
    "dataset = load_dataset(\n",
    "    dataset=dataset_name,\n",
    "    root_path='set path',\n",
    "    logger=logger,\n",
    "    train_batch_size=128, test_batch_size=128, val_split=0.0,\n",
    "    augment=False, shuffle=False, random_seed=0,\n",
    "    mean=[0, 0, 0], std=[1, 1, 1],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_scores_and_masks(seeds, dataset, method, num_augs):\n",
    "    \"\"\"Load per-seed membership scores and training-membership masks.\n",
    "\n",
    "    Args:\n",
    "        seeds: sequence of experiment seeds; column j of both outputs holds seeds[j].\n",
    "        dataset: object exposing ``name`` and ``train_length``.\n",
    "        method: key into ``aug_scores_file_name`` choosing which score files to read.\n",
    "        num_augs: number of augmentation score files per seed ('mentr' always reads one).\n",
    "\n",
    "    Returns:\n",
    "        (scores, masks): scores has shape (num_augs, train_length, len(seeds));\n",
    "        masks has shape (train_length, len(seeds)) and is 1 at the indices listed in\n",
    "        each seed's subsample/index file, 0 elsewhere.\n",
    "    \"\"\"\n",
    "    h = 0.001  # formatted into the score file names (presumably the curvature step size)\n",
    "    score_container_dir = dataset.name\n",
    "    aug_scores_file_name = {\n",
    "        'curv_zo' : \"curvature_scores_zo_{}_{}_{}_tid{}.pt\",\n",
    "        'curv' : \"curvature_scores_{}_{}_{}_tid{}.pt\",\n",
    "        'prob' : \"prob_{}_{}_{}_tid{}.pt\",\n",
    "        'loss' : \"losses_{}_{}_{}_tid{}.pt\",\n",
    "        'mentr' : \"m_entropy_{}_{}_{}_tid{}.pt\",\n",
    "        'loss_g' : \"loss_g_{}_{}_{}_tid{}.pt\",\n",
    "    }\n",
    "\n",
    "    num_seeds = len(seeds)\n",
    "    score_seeds_v_seed = np.zeros((num_augs, dataset.train_length, num_seeds))\n",
    "    mask_seeds_v_seed = np.zeros((dataset.train_length, num_seeds))\n",
    "    # Loop-invariant: 'mentr' scores exist for a single augmentation only.\n",
    "    augs = 1 if method == 'mentr' else num_augs\n",
    "\n",
    "    # enumerate() yields the output column directly, so seed lists that are not\n",
    "    # contiguous from seeds[0] still land in the right column.\n",
    "    for array_idx, exp_idx in enumerate(seeds):\n",
    "        for aug_idx in range(augs):\n",
    "            file_name = aug_scores_file_name[method].format(dataset.name, exp_idx, h, aug_idx)\n",
    "            score_seeds_v_seed[aug_idx, :, array_idx] = get_numpy_from_azure(\n",
    "                container_name=\"curvature-mi-scores\",\n",
    "                container_dir=score_container_dir,\n",
    "                file_name=file_name)\n",
    "\n",
    "        if dataset.name == 'imagenet':\n",
    "            # NOTE(review): placeholder path; {0.7} is rendered literally as \"0.7\".\n",
    "            # The context manager closes the file handle np.load keeps open for .npz files.\n",
    "            with np.load(f'/path/imagenet-resnet50/{0.7}/{exp_idx}/aux_arrays.npz') as aux:\n",
    "                mask_idxs = aux['subsample_idx']\n",
    "        else:\n",
    "            # NOTE(review): every non-imagenet dataset reads the cifar100 index files.\n",
    "            mask_idxs = get_numpy_from_azure(\"curvature-mi-models\", \"cifar100_indexs\", f\"{exp_idx}.npy\")\n",
    "        mask_seeds_v_seed[mask_idxs, array_idx] = 1\n",
    "\n",
    "    return score_seeds_v_seed, mask_seeds_v_seed\n",
    "\n",
    "def get_score_dist_params(scores, masks, num_augs):\n",
    "    means_in = []\n",
    "    means_out = []\n",
    "    var_in = []\n",
    "    var_out = []\n",
    "    for aug in range(num_augs):\n",
    "        in_scores = np.where(masks == 1, scores[aug], np.NaN)\n",
    "        out_scores = np.where(masks == 0, scores[aug], np.NaN)\n",
    "        means_in.append(np.nanmean(in_scores, axis=1))\n",
    "        means_out.append(np.nanmean(out_scores, axis=1))\n",
    "        var_in.append(np.nanvar(in_scores, axis=1))\n",
    "        var_out.append(np.nanvar(out_scores, axis=1))\n",
    "\n",
    "    return {\n",
    "        'means_in': np.row_stack(means_in),\n",
    "        'means_out': np.row_stack(means_out),\n",
    "        'var_in': np.row_stack(var_in),\n",
    "        'var_out': np.row_stack(var_out)\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = {}\n",
    "masks = {}\n",
    "dist_params = {}\n",
    "train_end_seed = 52\n",
    "train_seeds = list(range(train_end_seed))\n",
    "test_seeds = list(range(train_end_seed, train_end_seed + 3))\n",
    "eps = 1e-22"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_err(x):\n",
    "    \"\"\"Return log(x + eps); the module-level ``eps`` keeps log(0) finite.\"\"\"\n",
    "    shifted = x + eps\n",
    "    return np.log(shifted)\n",
    "\n",
    "def onem_log_err(x):\n",
    "    \"\"\"Return log(1 - x + eps); ``eps`` keeps x == 1 finite. Not referenced by ``methods``.\"\"\"\n",
    "    complement = 1 - x + eps\n",
    "    return np.log(complement)\n",
    "\n",
    "def logit_scale(x):\n",
    "    \"\"\"Logit transform log(x / (1 - x)), with ``eps`` in both the denominator and the log.\"\"\"\n",
    "    odds = x / (1 - x + eps)\n",
    "    return log_err(odds)\n",
    "\n",
    "def nll(curv_score, mean, var):\n",
    "    ll = - ( ( (curv_score - mean)**2) / (2 * (var ** 2) ) ) -0.5 * np.log(var ** 2) - 0.5 * np.log(2 * np.pi)\n",
    "    return -ll\n",
    "\n",
    "def likelihood(curv_score, mean, var):\n",
    "    nll = - ( ( (curv_score - mean)**2) / (2 * (var ** 2) ) ) - 0.5 * np.log(var ** 2) - 0.5 * np.log(2 * np.pi)\n",
    "    likelihood_val = np.exp(nll)\n",
    "    return likelihood_val\n",
    "\n",
    "def get_likelihood_ratio(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    \"\"\"Per-sample ratio of IN to OUT Gaussian likelihoods, summed over augmentations.\n",
    "\n",
    "    test_scores[aug, :, 0] is scored under the per-sample IN (means_in / var_in) and\n",
    "    OUT (means_out / var_out) Gaussians for each of the first ``num_augs`` augmentations.\n",
    "    ``train_scores`` and ``train_masks`` are unused; every predictor shares this signature.\n",
    "    Returns a 1-D array with one ratio per sample.\n",
    "    \"\"\"\n",
    "    in_total = np.zeros(test_scores.shape[1])\n",
    "    out_total = np.zeros_like(in_total)\n",
    "    for aug in range(num_augs):\n",
    "        target = test_scores[aug, :, 0]\n",
    "        in_total += likelihood(target, dist_params['means_in'][aug], dist_params['var_in'][aug] + eps)\n",
    "        out_total += likelihood(target, dist_params['means_out'][aug], dist_params['var_out'][aug] + eps)\n",
    "\n",
    "    return in_total / (out_total + eps)\n",
    "\n",
    "def get_likelihood_ratio_cnst(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    \"\"\"Per-sample IN/OUT likelihood ratio using the fitted means but a fixed unit variance.\n",
    "\n",
    "    For each of the first ``num_augs`` augmentations, test_scores[aug, :, 0] is scored\n",
    "    under N(means_in, 1) and N(means_out, 1); likelihoods are summed over augmentations\n",
    "    and ``eps`` guards the division. ``train_scores``/``train_masks`` are unused.\n",
    "    \"\"\"\n",
    "    unit_var = 1\n",
    "    in_total = np.zeros(test_scores.shape[1])\n",
    "    out_total = np.zeros_like(in_total)\n",
    "    for aug in range(num_augs):\n",
    "        target = test_scores[aug, :, 0]\n",
    "        in_total += likelihood(target, dist_params['means_in'][aug], unit_var)\n",
    "        out_total += likelihood(target, dist_params['means_out'][aug], unit_var)\n",
    "\n",
    "    return in_total / (out_total + eps)\n",
    "\n",
    "def get_nll_ratio(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    \"\"\"Summed log-likelihood ratio log p_in - log p_out per sample (higher = more IN-like).\n",
    "\n",
    "    Computed as -(sum NLL_in - sum NLL_out) over the first ``num_augs`` augmentations.\n",
    "    NOTE(review): both terms use dist_params['var_in'] as the variance, whereas\n",
    "    get_likelihood_ratio uses 'var_out' for the OUT term -- confirm a shared\n",
    "    variance is intended. ``train_scores``/``train_masks`` are unused.\n",
    "    \"\"\"\n",
    "    var_floor = 1e-32  # keeps a zero variance out of log()\n",
    "    nll_in = np.zeros(test_scores.shape[1])\n",
    "    nll_out = np.zeros_like(nll_in)\n",
    "    for aug in range(num_augs):\n",
    "        target = test_scores[aug, :, 0]\n",
    "        shared_var = dist_params['var_in'][aug] + var_floor\n",
    "        nll_in += nll(target, dist_params['means_in'][aug], shared_var)\n",
    "        nll_out += nll(target, dist_params['means_out'][aug], shared_var)\n",
    "\n",
    "    return -(nll_in - nll_out)\n",
    "\n",
    "def get_identity(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    return -test_scores[0]\n",
    "\n",
    "def get_mast(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    aug = 0\n",
    "    threshold = (dist_params['means_in'][aug] +  dist_params['means_out'][aug]) / 2 \n",
    "    return -(test_scores[aug, :, 0] - threshold)\n",
    "\n",
    "def get_mast_offline(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    aug = 0\n",
    "    threshold = dist_params['means_out'][aug]\n",
    "    return -(test_scores[aug, :, 0] - threshold)\n",
    "\n",
    "def get_ye_et_el_attack_r(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    aug = 0\n",
    "    alpha = 1e-4\n",
    "    out_scores = np.where(train_masks == 0, train_scores[aug], np.NaN)\n",
    "    threshold = np.nanpercentile(out_scores, alpha * 100, axis=1)\n",
    "    return -(test_scores[aug, :, 0] - threshold)\n",
    "\n",
    "def get_class_based(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    # This is handled during roc calculations\n",
    "    return -test_scores[0, :, 0]\n",
    "\n",
    "def get_loss_count(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    in_scores = np.where(train_masks == 1, train_scores[0], np.NaN)\n",
    "    threshold = np.nanmean(in_scores)\n",
    "    pred = (test_scores[:, :, 0] < threshold).sum(0) / num_augs\n",
    "    return pred\n",
    "\n",
    "def find_tpr_at_fpr(tpr, fpr, target_fpr=1e-1):\n",
    "    \"\"\"\n",
    "    Finds the TPR when FPR equals target_fpr using interpolation.\n",
    "\n",
    "    Parameters:\n",
    "    tpr (array): Array of true positive rates.\n",
    "    fpr (array): Array of false positive rates.\n",
    "    target_fpr (float): The target false positive rate. Default is 1e-1.\n",
    "\n",
    "    Returns:\n",
    "    float: Interpolated TPR at the target FPR.\n",
    "    \"\"\"\n",
    "    # Ensure the arrays are numpy arrays\n",
    "    tpr = np.array(tpr)\n",
    "    fpr = np.array(fpr)\n",
    "\n",
    "    # Use numpy's interpolation function\n",
    "    tpr_at_target_fpr = np.interp(target_fpr, fpr, tpr)\n",
    "    return tpr_at_target_fpr\n",
    "\n",
    "# One row per attack:\n",
    "#   (score file key, transform applied to raw scores, membership predictor,\n",
    "#    display name, number of augmentations, seeds used to fit IN/OUT statistics)\n",
    "methods = [\n",
    "    # Curvature scores ('curv_zo' files), log-transformed.\n",
    "    ('curv_zo', log_err,     get_nll_ratio,             'Curv ZO NLL',         num_augs, train_seeds),\n",
    "    ('curv_zo', log_err,     get_likelihood_ratio_cnst, 'Curv ZO LR',          num_augs, train_seeds),\n",
    "    # Baselines.\n",
    "    ('prob',    logit_scale, get_nll_ratio,             'Carlini et al.',      num_augs, train_seeds),\n",
    "    # get_identity ignores the fitted statistics, so a single seed is loaded.\n",
    "    ('loss',    lambda x: x, get_identity,              'Yeom et al.',         1,        [0]),\n",
    "    ('loss',    lambda x: x, get_mast,                  'Sablayrolles et al.', 1,        train_seeds),\n",
    "    ('loss',    lambda x: x, get_mast_offline,          'Watson et al.',       1,        train_seeds),\n",
    "    ('loss',    lambda x: x, get_ye_et_el_attack_r,     'Ye et al.',           1,        train_seeds),\n",
    "    ('mentr',   lambda x: x, get_class_based,           'Song et al.',         1,        train_seeds),\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_194773/2068398035.py:49: RuntimeWarning: Mean of empty slice\n",
      "  means_in.append(np.nanmean(in_scores, axis=1))\n",
      "/tmp/ipykernel_194773/2068398035.py:50: RuntimeWarning: Mean of empty slice\n",
      "  means_out.append(np.nanmean(out_scores, axis=1))\n",
      "/tmp/ipykernel_194773/2068398035.py:51: RuntimeWarning: Degrees of freedom <= 0 for slice.\n",
      "  var_in.append(np.nanvar(in_scores, axis=1))\n",
      "/tmp/ipykernel_194773/2068398035.py:52: RuntimeWarning: Degrees of freedom <= 0 for slice.\n",
      "  var_out.append(np.nanvar(out_scores, axis=1))\n"
     ]
    }
   ],
   "source": [
    "# Fit IN/OUT statistics on the training seeds for every method; methods already in\n",
    "# `scores` are skipped, so re-running this cell does not reload them.\n",
    "for info_type, score_func, _, method_name, augs, seeds in methods:\n",
    "    if method_name not in scores:\n",
    "        score, mask = get_scores_and_masks(seeds=seeds, dataset=dataset, method=info_type, num_augs=augs)\n",
    "        score = score_func(score)\n",
    "        params = get_score_dist_params(score, mask, augs)\n",
    "        scores[method_name], masks[method_name], dist_params[method_name] = score, mask, params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Curv ZO NLL\n",
      "Curv ZO LR\n",
      "Carlini et al.\n",
      "Yeom et al.\n",
      "Sablayrolles et al.\n",
      "Watson et al.\n",
      "Ye et al.\n",
      "Song et al.\n"
     ]
    }
   ],
   "source": [
    "# Load the held-out test-seed scores and masks, applying each method's score transform.\n",
    "test_scores, test_masks = {}, {}\n",
    "\n",
    "for info_type, score_func, _, method_name, augs, _ in methods:\n",
    "    print(method_name)\n",
    "    test_score, test_mask = get_scores_and_masks(\n",
    "        seeds=test_seeds, dataset=dataset, method=info_type, num_augs=augs)\n",
    "    test_scores[method_name], test_masks[method_name] = score_func(test_score), test_mask\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================== Seed 52 ======================\n",
      "AUC Curv ZO NLL: 77.33, Acc 69.07\n",
      "AUC Curv ZO LR: 72.22, Acc 68.70\n",
      "AUC Carlini et al.: 73.47, Acc 66.15\n",
      "AUC Yeom et al.: 63.26, Acc 58.51\n",
      "AUC Sablayrolles et al.: 76.53, Acc 66.98\n",
      "AUC Watson et al.: 69.51, Acc 61.49\n",
      "AUC Ye et al.: 75.78, Acc 66.14\n",
      "AUC Song et al.: 63.31, Acc 57.91\n",
      "=================== Seed 53 ======================\n",
      "AUC Curv ZO NLL: 77.49, Acc 69.16\n",
      "AUC Curv ZO LR: 72.31, Acc 68.77\n",
      "AUC Carlini et al.: 73.48, Acc 66.15\n",
      "AUC Yeom et al.: 63.19, Acc 58.47\n",
      "AUC Sablayrolles et al.: 76.44, Acc 66.86\n",
      "AUC Watson et al.: 69.41, Acc 61.38\n",
      "AUC Ye et al.: 75.85, Acc 66.19\n",
      "AUC Song et al.: 63.25, Acc 57.84\n",
      "=================== Seed 54 ======================\n",
      "AUC Curv ZO NLL: 77.53, Acc 69.26\n",
      "AUC Curv ZO LR: 72.30, Acc 68.79\n",
      "AUC Carlini et al.: 73.44, Acc 66.13\n",
      "AUC Yeom et al.: 63.24, Acc 58.51\n",
      "AUC Sablayrolles et al.: 76.53, Acc 66.93\n",
      "AUC Watson et al.: 69.38, Acc 61.34\n",
      "AUC Ye et al.: 75.74, Acc 66.14\n",
      "AUC Song et al.: 63.30, Acc 57.88\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import roc_auc_score, roc_curve, balanced_accuracy_score\n",
    "\n",
    "auroc_all = {}\n",
    "bal_acc_all = {}\n",
    "\n",
    "\n",
    "def _auc_and_balanced_acc(y_true, y_pred):\n",
    "    \"\"\"Return (AUROC, balanced accuracy) for 1-D membership labels and scores.\n",
    "\n",
    "    Predictions are binarised at the ROC threshold where TPR + FPR is closest to 1.\n",
    "    NOTE(review): that is the point where TPR ~= TNR, which is not necessarily the\n",
    "    threshold that maximises balanced accuracy.\n",
    "    \"\"\"\n",
    "    fpr, tpr, thr = roc_curve(y_true, y_pred, pos_label=1)\n",
    "    auc = roc_auc_score(y_true, y_pred)\n",
    "\n",
    "    # \"Balanced accuracy is symmetric. That is, the metric\n",
    "    # assigns equal cost to false-positives and to false-negatives.\"\n",
    "    # - LiRA https://arxiv.org/pdf/2112.03570.pdf\n",
    "    optimal_threshold = thr[np.argmin(np.abs(fpr + tpr - 1))]\n",
    "    y_pred_binarized = (y_pred >= optimal_threshold).astype(int)\n",
    "    return auc, balanced_accuracy_score(y_true, y_pred_binarized)\n",
    "\n",
    "\n",
    "def _load_train_labels():\n",
    "    \"\"\"Class label of every training sample, used by the class-based (Song et al.) attack.\"\"\"\n",
    "    if dataset.name == 'imagenet':\n",
    "        # NOTE(review): allow_pickle=True can execute code embedded in the file; only load trusted files.\n",
    "        with np.load('../HIKER_consolidated/analysis_checkpoints/imagenet/imagenet_index.npz', allow_pickle=True) as npz:\n",
    "            return npz['tr_labels'].flatten()\n",
    "    return np.asarray(get_numpy_from_azure(\n",
    "        container_name=\"curvature-mi-scores\",\n",
    "        container_dir=dataset.name,\n",
    "        file_name=f\"true_labels_cifar100_{train_end_seed}_0.1.pt\")).ravel()\n",
    "\n",
    "\n",
    "def _class_based_metrics(y_true, y_pred, labels, n_classes):\n",
    "    \"\"\"Macro-average AUROC and balanced accuracy over classes, each computed on that class's samples.\n",
    "\n",
    "    Classes whose samples are all members or all non-members (no ROC curve is defined)\n",
    "    are skipped and excluded from the average; returns (nan, nan) if every class is skipped.\n",
    "    \"\"\"\n",
    "    per_class = []\n",
    "    for label in range(n_classes):\n",
    "        in_class = labels == label\n",
    "        if np.unique(y_true[in_class]).size < 2:\n",
    "            continue\n",
    "        per_class.append(_auc_and_balanced_acc(y_true[in_class], y_pred[in_class]))\n",
    "    if not per_class:\n",
    "        return float('nan'), float('nan')\n",
    "    auc, balanced_accuracy = np.mean(per_class, axis=0)\n",
    "    return auc, balanced_accuracy\n",
    "\n",
    "\n",
    "for test_seed_idx, test_seed in enumerate(test_seeds):\n",
    "    print(f\"=================== Seed {test_seed} ======================\")\n",
    "    for info_type, _, pred_func, method_name, augs, _ in methods:\n",
    "        # Predictors return one value per sample; ravel also accepts an (N, 1) result.\n",
    "        y_pred = np.ravel(pred_func(\n",
    "            test_scores[method_name][..., test_seed_idx, np.newaxis],\n",
    "            dist_params[method_name],\n",
    "            augs,\n",
    "            scores[method_name],\n",
    "            masks[method_name]))\n",
    "        y_true = test_masks[method_name][:, test_seed_idx]\n",
    "\n",
    "        # Drop samples with a NaN prediction (e.g. samples with no IN or no OUT training seeds).\n",
    "        valid = ~np.isnan(y_pred)\n",
    "        y_pred, y_true = y_pred[valid], y_true[valid]\n",
    "\n",
    "        if method_name == 'Song et al.':\n",
    "            # Labels get the same mask so they stay aligned with y_true / y_pred, and each\n",
    "            # class's AUROC and balanced accuracy are computed on that class's samples only.\n",
    "            labels = _load_train_labels()[valid]\n",
    "            auc, balanced_accuracy = _class_based_metrics(y_true, y_pred, labels, dataset.num_classes)\n",
    "        else:\n",
    "            auc, balanced_accuracy = _auc_and_balanced_acc(y_true, y_pred)\n",
    "\n",
    "        print(f\"AUC {method_name}: {auc*100:.2f}, Acc {balanced_accuracy * 100:.2f}\")\n",
    "        auroc_all.setdefault(method_name, []).append(auc)\n",
    "        bal_acc_all.setdefault(method_name, []).append(balanced_accuracy)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Curv ZO NLL, auroc 77.45 $\\pm$ 0.09, bal 69.16 $\\pm$ 0.08\n",
      "Curv ZO LR, auroc 72.28 $\\pm$ 0.04, bal 68.76 $\\pm$ 0.04\n",
      "Carlini et al., auroc 73.46 $\\pm$ 0.02, bal 66.14 $\\pm$ 0.01\n",
      "Yeom et al., auroc 63.23 $\\pm$ 0.03, bal 58.50 $\\pm$ 0.02\n",
      "Sablayrolles et al., auroc 76.50 $\\pm$ 0.04, bal 66.93 $\\pm$ 0.05\n",
      "Watson et al., auroc 69.44 $\\pm$ 0.05, bal 61.40 $\\pm$ 0.06\n",
      "Ye et al., auroc 75.79 $\\pm$ 0.05, bal 66.16 $\\pm$ 0.02\n",
      "Song et al., auroc 63.29 $\\pm$ 0.03, bal 57.88 $\\pm$ 0.03\n"
     ]
    }
   ],
   "source": [
    "# Mean +/- population std over the test seeds, in percent, formatted for a LaTeX table.\n",
    "for k, v in auroc_all.items():\n",
    "    auc_for_k = np.array(v)\n",
    "    bal_acc_for_k = np.array(bal_acc_all[k])\n",
    "\n",
    "    # The backslash of the LaTeX command is doubled: a lone backslash before 'p' is an\n",
    "    # invalid escape sequence (SyntaxWarning on Python 3.12+).\n",
    "    print(f\"{k}, auroc {auc_for_k.mean() * 100:.2f} $\\\\pm$ {auc_for_k.std() * 100:.2f}, \"\n",
    "          f\"bal {bal_acc_for_k.mean() * 100:.2f} $\\\\pm$ {bal_acc_for_k.std() * 100:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.9",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
