{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from utils.load_dataset import load_dataset\n",
    "import logging\n",
    "from azure_blob_storage import get_numpy_from_azure\n",
    "from tqdm import tqdm\n",
    "\n",
    "logger = logging.getLogger(\"\")\n",
    "dataset_name = 'cifar10'\n",
    "num_augs = 2\n",
    "\n",
    "dataset = load_dataset(\n",
    "    dataset=dataset_name,\n",
    "    train_batch_size=128,\n",
    "    test_batch_size=128,\n",
    "    val_split=0.0,\n",
    "    augment=False,\n",
    "    shuffle=False,\n",
    "    root_path='set path',\n",
    "    random_seed=0,\n",
    "    mean=[0, 0, 0],\n",
    "    std=[1, 1, 1],\n",
    "    logger=logger,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_scores_and_masks(seeds, dataset, method, num_augs):\n",
    "    h = 0.001\n",
    "    num_seeds = len(seeds)\n",
    "    score_container_dir = dataset.name\n",
    "    aug_scores_file_name = {\n",
    "        'curv_zo' : \"curvature_scores_zo_{}_resnet18_ds{}_shadow.ckpt_{}_tid{}.pt\",\n",
    "        'curv' : \"curvature_scores_{}_resnet18_ds{}_shadow.ckpt_{}_tid{}.pt\",\n",
    "        'prob' : \"prob_{}_resnet18_ds{}_shadow.ckpt_{}_tid{}.pt\",\n",
    "        'loss' : \"losses_{}_resnet18_ds{}_shadow.ckpt_{}_tid{}.pt\",\n",
    "        'mentr' : \"m_entropy_{}_resnet18_ds{}_shadow.ckpt_{}_tid{}.pt\",\n",
    "        'loss_g' : \"loss_g_{}_resnet18_ds{}_shadow.ckpt_{}_tid{}.pt\",\n",
    "    }\n",
    "\n",
    "    score_seeds_v_seed = np.zeros((num_augs, dataset.train_length, num_seeds))\n",
    "    mask_seeds_v_seed = np.zeros((dataset.train_length, num_seeds))\n",
    "    start_seed = seeds[0]\n",
    "    for exp_idx in seeds:\n",
    "        array_idx = exp_idx - start_seed\n",
    "        augs = 1 if method == 'mentr' else num_augs\n",
    "        for aug_idx in range(augs):\n",
    "            file_name = aug_scores_file_name[method].format(dataset.name, exp_idx, h, aug_idx)\n",
    "            scores = get_numpy_from_azure(\n",
    "                container_name= \"curvature-mi-scores\", \n",
    "                container_dir=score_container_dir, \n",
    "                file_name=file_name)\n",
    "\n",
    "            score_seeds_v_seed[aug_idx, :, array_idx] = scores\n",
    "\n",
    "        if dataset == 'imagenet':\n",
    "            array = np.load(f'/path/imagenet-resnet50/{0.7}/{exp_idx}/aux_arrays.npz')\n",
    "            mask_idxs = array['subsample_idx']\n",
    "            mask_seeds_v_seed[mask_idxs, array_idx] = 1\n",
    "        else:\n",
    "            blob_container = \"curvature-mi-models\"\n",
    "            container_dir = \"cifar100_indexs\"\n",
    "            mask_idxs = get_numpy_from_azure(blob_container, container_dir, f\"{exp_idx}.npy\")\n",
    "            mask_seeds_v_seed[mask_idxs, array_idx] = 1\n",
    "\n",
    "    return score_seeds_v_seed, mask_seeds_v_seed\n",
    "\n",
    "def get_score_dist_params(scores, masks, num_augs):\n",
    "    means_in = []\n",
    "    means_out = []\n",
    "    var_in = []\n",
    "    var_out = []\n",
    "    for aug in range(num_augs):\n",
    "        in_scores = np.where(masks == 1, scores[aug], np.NaN)\n",
    "        out_scores = np.where(masks == 0, scores[aug], np.NaN)\n",
    "        means_in.append(np.nanmean(in_scores, axis=1))\n",
    "        means_out.append(np.nanmean(out_scores, axis=1))\n",
    "        var_in.append(np.nanvar(in_scores, axis=1))\n",
    "        var_out.append(np.nanvar(out_scores, axis=1))\n",
    "\n",
    "    return {\n",
    "        'means_in': np.row_stack(means_in),\n",
    "        'means_out': np.row_stack(means_out),\n",
    "        'var_in': np.row_stack(var_in),\n",
    "        'var_out': np.row_stack(var_out)\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = {}\n",
    "masks = {}\n",
    "dist_params = {}\n",
    "train_end_seed = 64\n",
    "train_seeds = list(range(train_end_seed))\n",
    "test_seeds = list(range(train_end_seed, train_end_seed + 3))\n",
    "eps = 1e-22"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_err(x):\n",
    "    return np.log(x + eps)\n",
    "\n",
    "def logit_scale(x):\n",
    "    x = x / (1 - x + eps)\n",
    "    return log_err(x)\n",
    "\n",
    "def nll(curv_score, mean, var):\n",
    "    ll = - ( ( (curv_score - mean)**2) / (2 * (var ** 2) ) ) -0.5 * np.log(var ** 2) - 0.5 * np.log(2 * np.pi)\n",
    "    return -ll\n",
    "\n",
    "def likelihood(curv_score, mean, var):\n",
    "    nll = - ( ( (curv_score - mean)**2) / (2 * (var ** 2) ) ) - 0.5 * np.log(var ** 2) - 0.5 * np.log(2 * np.pi)\n",
    "    likelihood_val = np.exp(nll)\n",
    "    return likelihood_val\n",
    "\n",
    "def get_likelihood_ratio(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    in_likelihood = np.zeros((test_scores.shape[1]))\n",
    "    out_likelihood = np.zeros_like(in_likelihood)\n",
    "    for aug in range(num_augs):\n",
    "        test_score_aug = test_scores[aug, :, 0]\n",
    "        in_likelihood += likelihood(\n",
    "            test_score_aug, \n",
    "            dist_params['means_in'][aug], \n",
    "            dist_params['var_in'][aug] + eps)\n",
    "\n",
    "        out_likelihood += likelihood(\n",
    "            test_score_aug, \n",
    "            dist_params['means_out'][aug], \n",
    "            dist_params['var_out'][aug] + eps)\n",
    "\n",
    "    likelihood_ratio = in_likelihood / (out_likelihood + eps)\n",
    "    return likelihood_ratio\n",
    "\n",
    "def get_likelihood_ratio_c(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    in_likelihood = np.zeros((test_scores.shape[1]))\n",
    "    out_likelihood = np.zeros_like(in_likelihood)\n",
    "    for aug in range(num_augs):\n",
    "        test_score_aug = test_scores[aug, :, 0]\n",
    "        in_likelihood += likelihood(\n",
    "            test_score_aug, \n",
    "            dist_params['means_in'][aug], \n",
    "            1)\n",
    "\n",
    "        out_likelihood += likelihood(\n",
    "            test_score_aug, \n",
    "            dist_params['means_out'][aug], \n",
    "            1)\n",
    "\n",
    "    likelihood_ratio = in_likelihood / (out_likelihood + eps)\n",
    "    return likelihood_ratio\n",
    "\n",
    "def get_nll_ratio_c(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    in_likelihood = np.zeros((test_scores.shape[1]))\n",
    "    out_likelihood = np.zeros_like(in_likelihood)\n",
    "    for aug in range(num_augs):\n",
    "        test_score_aug = test_scores[aug, :, 0]\n",
    "        in_likelihood += nll(\n",
    "            test_score_aug, \n",
    "            dist_params['means_in'][aug], \n",
    "            1)\n",
    "\n",
    "        out_likelihood += nll(\n",
    "            test_score_aug, \n",
    "            dist_params['means_out'][aug], \n",
    "            1)\n",
    "\n",
    "    likelihood_ratio = in_likelihood - out_likelihood\n",
    "    return -likelihood_ratio\n",
    "\n",
    "def get_nll_ratio(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    in_likelihood = np.zeros((test_scores.shape[1]))\n",
    "    out_likelihood = np.zeros_like(in_likelihood)\n",
    "    for aug in range(num_augs):\n",
    "        test_score_aug = test_scores[aug, :, 0]\n",
    "        in_likelihood += nll(\n",
    "            test_score_aug, \n",
    "            dist_params['means_in'][aug], \n",
    "            dist_params['var_in'][aug] + 1e-32)\n",
    "\n",
    "        out_likelihood += nll(\n",
    "            test_score_aug, \n",
    "            dist_params['means_out'][aug], \n",
    "            dist_params['var_in'][aug] + 1e-32)\n",
    "\n",
    "    likelihood_ratio = in_likelihood - out_likelihood\n",
    "    return -likelihood_ratio\n",
    "\n",
    "def get_identity(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    return -test_scores[0]\n",
    "\n",
    "def get_mast(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    aug = 0\n",
    "    threshold = (dist_params['means_in'][aug] +  dist_params['means_out'][aug]) / 2 \n",
    "    return -(test_scores[aug, :, 0] - threshold)\n",
    "\n",
    "def get_mast_offline(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    aug = 0\n",
    "    threshold = dist_params['means_out'][aug]\n",
    "    return -(test_scores[aug, :, 0] - threshold)\n",
    "\n",
    "def get_ye_et_el_attack_r(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    aug = 0\n",
    "    alpha = 1e-4\n",
    "    out_scores = np.where(train_masks == 0, train_scores[aug], np.NaN)\n",
    "    threshold = np.nanpercentile(out_scores, alpha * 100, axis=1)\n",
    "    return -(test_scores[aug, :, 0] - threshold)\n",
    "\n",
    "def get_class_based(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    # This is handled during roc calculations\n",
    "    return -test_scores[0, :, 0]\n",
    "\n",
    "def get_loss_count(test_scores, dist_params, num_augs, train_scores, train_masks):\n",
    "    in_scores = np.where(train_masks == 1, train_scores[0], np.NaN)\n",
    "    threshold = np.nanmean(in_scores)\n",
    "    pred = (test_scores[:, :, 0] < threshold).sum(0) / num_augs\n",
    "    return pred\n",
    "\n",
    "methods = [\n",
    "    ('curv_zo', log_err, get_nll_ratio, 'Curv ZO NLL', num_augs, train_seeds), \n",
    "    ('curv_zo', log_err, get_likelihood_ratio, 'Curv ZO LR', num_augs, train_seeds), \n",
    "    ('prob', logit_scale, get_likelihood_ratio, 'Carlini et al.', num_augs, train_seeds),\n",
    "    ('loss', lambda x: x, get_identity, 'Yeom et al.', 1, [0]),\n",
    "    ('loss', lambda x: x, get_mast, 'Sablayrolles et al.', 1, train_seeds),\n",
    "    ('loss', lambda x: x, get_mast_offline, 'Watson et al.', 1, train_seeds),\n",
    "    ('loss', lambda x: x, get_ye_et_el_attack_r, 'Ye et al.', 1, train_seeds),\n",
    "    ('mentr', lambda x: x, get_class_based, 'Song et al.', 1, train_seeds),\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 3/8 [02:26<04:03, 48.79s/it]/tmp/ipykernel_199063/2565310228.py:49: RuntimeWarning: Mean of empty slice\n",
      "  means_in.append(np.nanmean(in_scores, axis=1))\n",
      "/tmp/ipykernel_199063/2565310228.py:50: RuntimeWarning: Mean of empty slice\n",
      "  means_out.append(np.nanmean(out_scores, axis=1))\n",
      "/tmp/ipykernel_199063/2565310228.py:51: RuntimeWarning: Degrees of freedom <= 0 for slice.\n",
      "  var_in.append(np.nanvar(in_scores, axis=1))\n",
      "/tmp/ipykernel_199063/2565310228.py:52: RuntimeWarning: Degrees of freedom <= 0 for slice.\n",
      "  var_out.append(np.nanvar(out_scores, axis=1))\n",
      "100%|██████████| 8/8 [04:40<00:00, 35.02s/it]\n"
     ]
    }
   ],
   "source": [
    "for info_type, score_func, _, method_name, augs, seeds in tqdm(methods):\n",
    "    if method_name in scores:\n",
    "        continue\n",
    "    \n",
    "    score, mask = get_scores_and_masks(\n",
    "        seeds=seeds, \n",
    "        dataset=dataset,\n",
    "        method=info_type,\n",
    "        num_augs=augs)\n",
    "\n",
    "    score = score_func(score)\n",
    "    params = get_score_dist_params(score, mask, augs)\n",
    "    scores[method_name] = score\n",
    "    masks[method_name] = mask\n",
    "    dist_params[method_name] = params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:14<00:00,  1.85s/it]\n"
     ]
    }
   ],
   "source": [
    "test_scores = {}\n",
    "test_masks = {}\n",
    "\n",
    "for info_type, score_func, _, method_name, augs, _ in tqdm(methods):\n",
    "    test_score, test_mask = get_scores_and_masks(\n",
    "        seeds=test_seeds, \n",
    "        dataset=dataset,\n",
    "        method=info_type,\n",
    "        num_augs=augs)\n",
    "\n",
    "    test_scores[method_name] = score_func(test_score)\n",
    "    test_masks[method_name] = test_mask\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================== Seed 64 ======================\n",
      "AUC Curv ZO NLL: 70.15, Acc 62.92\n",
      "AUC Curv ZO LR: 59.25, Acc 55.12\n",
      "AUC Carlini et al.: 61.93, Acc 58.44\n",
      "AUC Yeom et al.: 61.12, Acc 55.97\n",
      "AUC Sablayrolles et al.: 61.28, Acc 56.48\n",
      "AUC Watson et al.: 58.42, Acc 54.78\n",
      "AUC Ye et al.: 68.39, Acc 60.32\n",
      "AUC Song et al.: 61.17, Acc 56.10\n",
      "=================== Seed 65 ======================\n",
      "AUC Curv ZO NLL: 67.05, Acc 60.80\n",
      "AUC Curv ZO LR: 58.36, Acc 54.76\n",
      "AUC Carlini et al.: 61.29, Acc 57.82\n",
      "AUC Yeom et al.: 59.39, Acc 54.84\n",
      "AUC Sablayrolles et al.: 62.57, Acc 57.40\n",
      "AUC Watson et al.: 59.70, Acc 55.63\n",
      "AUC Ye et al.: 65.55, Acc 58.43\n",
      "AUC Song et al.: 59.23, Acc 54.78\n",
      "=================== Seed 66 ======================\n",
      "AUC Curv ZO NLL: 69.27, Acc 62.03\n",
      "AUC Curv ZO LR: 59.05, Acc 55.11\n",
      "AUC Carlini et al.: 61.99, Acc 58.43\n",
      "AUC Yeom et al.: 60.80, Acc 55.90\n",
      "AUC Sablayrolles et al.: 60.66, Acc 56.07\n",
      "AUC Watson et al.: 57.62, Acc 54.19\n",
      "AUC Ye et al.: 67.97, Acc 60.11\n",
      "AUC Song et al.: 60.86, Acc 56.02\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import roc_auc_score, roc_curve, balanced_accuracy_score\n",
    "\n",
    "auroc_all = {}\n",
    "bal_acc_all = {}\n",
    "\n",
    "for test_seed_idx, test_seed in enumerate(test_seeds):\n",
    "    print(f\"=================== Seed {test_seed} ======================\")\n",
    "    for idx, (info_type, _, pred_func, method_name, augs, _) in enumerate(methods):\n",
    "        y_pred = pred_func(\n",
    "            test_scores[method_name][..., test_seed_idx, np.newaxis],\n",
    "            dist_params[method_name],\n",
    "            augs,\n",
    "            scores[method_name],\n",
    "            masks[method_name])\n",
    "\n",
    "        y_true = test_masks[method_name][..., test_seed_idx, np.newaxis]\n",
    "        if method_name == 'Song et al.':\n",
    "            labels = get_numpy_from_azure(\n",
    "                container_name= \"curvature-mi-scores\", \n",
    "                container_dir=dataset.name, \n",
    "                file_name=f\"true_labels_cifar10_{0}_0.001.pt\")\n",
    "            \n",
    "            auc = 0 \n",
    "            n_classes = dataset.num_classes\n",
    "            balanced_accuracy = 0\n",
    "            for label in range(n_classes):\n",
    "                indices = np.where(labels == label)[0]\n",
    "                # Compute ROC curve for each class\n",
    "                fpr, tpr, thr = roc_curve(y_true[indices], y_pred[indices])\n",
    "                auc += roc_auc_score(y_true, y_pred)\n",
    "                \n",
    "                # Find the optimal threshold: where the sum of FPR and TPR is closest to 1\n",
    "                optimal_idx = np.argmin(np.abs(fpr + tpr - 1))\n",
    "                optimal_threshold = thr[optimal_idx]\n",
    "                # Binarize predictions based on the optimal threshold\n",
    "                y_pred_binarized = (y_pred >= optimal_threshold).astype(int)\n",
    "\n",
    "                # Calculate Balanced Accuracy\n",
    "                balanced_accuracy += balanced_accuracy_score(y_true, y_pred_binarized)\n",
    "            \n",
    "            auc = auc / n_classes\n",
    "            balanced_accuracy = balanced_accuracy / n_classes\n",
    "\n",
    "        else:\n",
    "            fpr, tpr, thr = roc_curve(y_true, y_pred, pos_label=1)\n",
    "            auc = roc_auc_score(y_true, y_pred)\n",
    "\n",
    "            # Find the optimal threshold: where the sum of FPR and TPR is closest to 1\n",
    "            # \"Balanced accuracy is symmetric. That is, the metric\n",
    "            # assigns equal cost to false-positives and to false-negatives.\"\n",
    "            # - LiRA https://arxiv.org/pdf/2112.03570.pdf\n",
    "            optimal_idx = np.argmin(np.abs(fpr + tpr - 1))\n",
    "            optimal_threshold = thr[optimal_idx]\n",
    "\n",
    "            # Binarize predictions based on the optimal threshold\n",
    "            y_pred_binarized = (y_pred >= optimal_threshold).astype(int)\n",
    "\n",
    "            # Calculate Balanced Accuracy\n",
    "            balanced_accuracy = balanced_accuracy_score(y_true, y_pred_binarized)\n",
    "        \n",
    "        print(f\"AUC {method_name}: {auc*100:.2f}, Acc {balanced_accuracy * 100:.2f}\")\n",
    "        if method_name in auroc_all:\n",
    "            auroc_all[method_name] += [auc]\n",
    "            bal_acc_all[method_name] += [balanced_accuracy]\n",
    "        else:\n",
    "            auroc_all[method_name] = [auc]\n",
    "            bal_acc_all[method_name] = [balanced_accuracy]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Curv ZO NLL, auroc 68.82 $\\pm$ 1.30, bal 61.92 $\\pm$ 0.87\n",
      "Curv ZO LR, auroc 58.89 $\\pm$ 0.38, bal 55.00 $\\pm$ 0.17\n",
      "Carlini et al., auroc 61.73 $\\pm$ 0.32, bal 58.23 $\\pm$ 0.29\n",
      "Yeom et al., auroc 60.44 $\\pm$ 0.75, bal 55.57 $\\pm$ 0.52\n",
      "Sablayrolles et al., auroc 61.50 $\\pm$ 0.79, bal 56.65 $\\pm$ 0.56\n",
      "Watson et al., auroc 58.58 $\\pm$ 0.86, bal 54.86 $\\pm$ 0.59\n",
      "Ye et al., auroc 67.30 $\\pm$ 1.25, bal 59.62 $\\pm$ 0.84\n",
      "Song et al., auroc 60.42 $\\pm$ 0.85, bal 55.63 $\\pm$ 0.61\n"
     ]
    }
   ],
   "source": [
    "for k, v in auroc_all.items():\n",
    "    auc_for_k = np.array(v)\n",
    "    bal_acc_for_k = np.array(bal_acc_all[k])\n",
    "\n",
    "    print(f\"{k}, auroc {auc_for_k.mean() * 100:.2f} $\\pm$ {auc_for_k.std()* 100:.2f}, bal {bal_acc_for_k.mean()* 100:.2f} $\\pm$ {bal_acc_for_k.std()* 100:.2f}\") "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.9",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
