{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import datasets\n",
    "from privacy_estimates.experiments.aml import JobList\n",
    "from sklearn.metrics import roc_curve, roc_auc_score\n",
    "import os\n",
    "from datetime import datetime\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_performance(scores_members, scores_non_members):\n",
    "    \"\"\"Evaluate a membership inference attack from per-sample scores.\n",
    "\n",
    "    Members are the positive class (label 1) and non-members the negative\n",
    "    class (label 0). A one-line summary of the metrics is printed.\n",
    "\n",
    "    Args:\n",
    "        scores_members: iterable of attack scores for member samples.\n",
    "        scores_non_members: iterable of attack scores for non-member samples.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'auc', 'tpr_at_0.01', 'tpr_at_0.05', 'tpr_at_0.1' (TPR\n",
    "        linearly interpolated on the ROC curve at those FPRs), plus the\n",
    "        full 'fpr' and 'tpr' arrays returned by roc_curve.\n",
    "    \"\"\"\n",
    "    positives = list(scores_members)\n",
    "    negatives = list(scores_non_members)\n",
    "    labels = [1] * len(positives) + [0] * len(negatives)\n",
    "    all_scores = positives + negatives\n",
    "\n",
    "    auc = roc_auc_score(labels, all_scores)\n",
    "    fpr, tpr, _ = roc_curve(labels, all_scores)\n",
    "\n",
    "    mia_performance = {'auc': auc}\n",
    "    mia_performance.update(\n",
    "        (f'tpr_at_{target_fpr}', np.interp(target_fpr, fpr, tpr))\n",
    "        for target_fpr in (0.01, 0.05, 0.1)\n",
    "    )\n",
    "    print(f\"AUC: {mia_performance['auc']}, TPR@0.01: {mia_performance['tpr_at_0.01']}, TPR@0.05: {mia_performance['tpr_at_0.05']}, TPR@0.1: {mia_performance['tpr_at_0.1']}\")\n",
    "\n",
    "    # Keep the raw ROC curve alongside the scalar metrics.\n",
    "    mia_performance['fpr'] = fpr\n",
    "    mia_performance['tpr'] = tpr\n",
    "    return mia_performance\n",
    "\n",
    "def compute_performance_from_url(url, job_name = None):\n",
    "    \"\"\"Fetch (or reuse cached) attack outputs for one job and score the attack.\n",
    "\n",
    "    The 'scores' and 'challenge_bits' inputs of the job's 'estimate_privacy'\n",
    "    node are stored under ./mia_results/<job_name>/ and loaded from there.\n",
    "\n",
    "    Args:\n",
    "        url: job URL, passed to JobList.from_urls.\n",
    "        job_name: cache sub-directory name. Defaults to the current timestamp,\n",
    "            in which case nothing is reused from earlier calls.\n",
    "\n",
    "    Returns:\n",
    "        The dict returned by compute_performance, with records whose\n",
    "        'challenge_bit' is 1 treated as members and 0 as non-members.\n",
    "    \"\"\"\n",
    "    if job_name is None:\n",
    "        job_name = str(datetime.now())\n",
    "\n",
    "    # Check each input on its own: testing only the parent directory would\n",
    "    # treat a run that died after the first download as fully cached and then\n",
    "    # fail in load_from_disk on every later call.\n",
    "    node = None\n",
    "    for input_name in ('scores', 'challenge_bits'):\n",
    "        local_path = f'./mia_results/{job_name}/{input_name}'\n",
    "        if os.path.exists(local_path):\n",
    "            continue\n",
    "        if node is None:\n",
    "            # Resolve the job lazily so fully cached results need no job lookup.\n",
    "            node = JobList.from_urls([url])[0].get_node('estimate_privacy')\n",
    "        node.download_input(input_name, local_path)\n",
    "\n",
    "    scores = datasets.load_from_disk(f'./mia_results/{job_name}/scores')\n",
    "    bits = datasets.load_from_disk(f'./mia_results/{job_name}/challenge_bits')\n",
    "\n",
    "    membership_scores = np.array([k['score'] for k in scores])\n",
    "    membership_labels = np.array([k['challenge_bit'] for k in bits])\n",
    "    members = membership_scores[membership_labels == 1]\n",
    "    non_members = membership_scores[membership_labels == 0]\n",
    "    return compute_performance(members, non_members)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the experiment URLs; later cells index this by 'no_synthetic' and 'synthetic'.\n",
    "with open('ppl_experiment_urls.json') as url_file:\n",
    "    all_urls = json.load(url_file)\n",
    "\n",
    "all_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-synthetic runs first: keys are the perplexity labels, values the job URLs.\n",
    "no_synthetic_jobs = all_urls['no_synthetic']\n",
    "no_synthetic_ppls = list(no_synthetic_jobs)\n",
    "no_synthetic_urls = [no_synthetic_jobs[ppl_key] for ppl_key in no_synthetic_ppls]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attack performance per perplexity label for the non-synthetic runs.\n",
    "non_synthetic_results = {}\n",
    "\n",
    "for ppl, url in zip(no_synthetic_ppls, no_synthetic_urls):\n",
    "    print(f\"Perplexity: {ppl}\")\n",
    "    non_synthetic_results[ppl] = compute_performance_from_url(url, job_name=f'no_synthetic_{ppl}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic runs: keys are the perplexity labels, values the job URLs.\n",
    "synthetic_jobs = all_urls['synthetic']\n",
    "synthetic_ppls = list(synthetic_jobs)\n",
    "synthetic_urls = [synthetic_jobs[ppl_key] for ppl_key in synthetic_ppls]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attack performance per perplexity label for the synthetic runs.\n",
    "synthetic_results = {}\n",
    "\n",
    "for ppl, url in zip(synthetic_ppls, synthetic_urls):\n",
    "    print(f\"Perplexity: {ppl}\")\n",
    "    synthetic_results[ppl] = compute_performance_from_url(url, job_name=f'synthetic_{ppl}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ppl_val(ppl_str: str):\n",
    "    \"\"\"Return the midpoint of a '<low>_<high>' perplexity label as a float.\n",
    "\n",
    "    Only the first two '_'-separated fields are used; both must parse as int.\n",
    "    \"\"\"\n",
    "    fields = ppl_str.split('_')\n",
    "    low = int(fields[0])\n",
    "    high = int(fields[1])\n",
    "    return (low + high) / 2\n",
    "\n",
    "# 'inf' and 'in_canary' are not '<low>_<high>' ranges, so they get no numeric value here.\n",
    "special_ppl_keys = ('inf', 'in_canary')\n",
    "\n",
    "no_synthetic_ppl_vals = [get_ppl_val(key) for key in no_synthetic_ppls if key not in special_ppl_keys]\n",
    "synthetic_ppl_vals = [get_ppl_val(key) for key in synthetic_ppls if key not in special_ppl_keys]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only point this at a trusted local file.\n",
    "with open('ppl_in_canaries_boxplot', 'rb') as boxplot_file:\n",
    "    ppl_in_canaries = pickle.load(boxplot_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# x-position at which the 'inf' results are drawn in the AUC-vs-perplexity plot.\n",
    "# Presumably the median perplexity of random-token canaries (per the name) - TODO record where this value comes from.\n",
    "median_random_token_ppl = 139_524.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_auc_curve(ppl_vals, ppl_keys, results, color, label):\n",
    "    \"\"\"Plot AUC against canary perplexity for one series of results.\n",
    "\n",
    "    The 'inf' result is appended at x = median_random_token_ppl. The\n",
    "    'in_canary' result is skipped here; it is drawn as a boxplot instead.\n",
    "    \"\"\"\n",
    "    range_keys = [p for p in ppl_keys if p not in ('inf', 'in_canary')]\n",
    "    plt.plot(ppl_vals + [median_random_token_ppl],\n",
    "             [results[p]['auc'] for p in range_keys] + [results['inf']['auc']],\n",
    "             '-o', alpha=0.8, color=color, markersize=10, linewidth=2,\n",
    "             markeredgewidth=2, markeredgecolor='white', label=label)\n",
    "\n",
    "\n",
    "def _plot_in_canary_box(y_position, edge_color, face_color, flier_alpha=0.0, **boxplot_kwargs):\n",
    "    \"\"\"Draw the horizontal ppl_in_canaries boxplot at height y_position.\n",
    "\n",
    "    Extra keyword arguments (e.g. label) are forwarded to plt.boxplot.\n",
    "    \"\"\"\n",
    "    plt.boxplot(\n",
    "        ppl_in_canaries,\n",
    "        vert=False,\n",
    "        positions=[y_position],\n",
    "        widths=0.02,  # box height in AUC units; adjust to fit the plot\n",
    "        patch_artist=True,\n",
    "        boxprops=dict(facecolor=face_color, color=edge_color),\n",
    "        whiskerprops=dict(color=edge_color),\n",
    "        capprops=dict(color=edge_color),\n",
    "        flierprops=dict(marker='o', color=edge_color, alpha=flier_alpha),\n",
    "        medianprops=dict(color='white', linewidth=1),\n",
    "        **boxplot_kwargs,\n",
    "    )\n",
    "\n",
    "\n",
    "plt.figure(figsize=(6,6))\n",
    "\n",
    "# Non-synthetic series: curve over the perplexity ranges, box at the in-canary AUC.\n",
    "_plot_auc_curve(no_synthetic_ppl_vals, no_synthetic_ppls, non_synthetic_results,\n",
    "                'darkblue', r'Model - $n_{rep}=4$')\n",
    "_plot_in_canary_box(non_synthetic_results['in_canary']['auc'], 'darkblue', 'darkblue')\n",
    "\n",
    "# Synthetic series, drawn the same way.\n",
    "_plot_auc_curve(synthetic_ppl_vals, synthetic_ppls, synthetic_results,\n",
    "                'darkorange', r'Synthetic (2-gram) - $n_{rep}=16$')\n",
    "_plot_in_canary_box(synthetic_results['in_canary']['auc'], 'darkorange', 'darkorange')\n",
    "\n",
    "# Legend-only entry: y=2.0 lies above the y-limit set below, so the box itself is never visible.\n",
    "_plot_in_canary_box(2.0, 'grey', 'lightgrey', flier_alpha=0.5, label='In distribution')\n",
    "\n",
    "plt.axhline(y=0.5, color='black', linestyle='--', alpha=0.5, label = 'Random guess baseline')\n",
    "\n",
    "plt.xticks([10**k for k in (0, 1, 2, 3, 4, 5)])\n",
    "plt.yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0], labels=['0.5', '0.6', '0.7', '0.8', '0.9', '1.0'])\n",
    "\n",
    "# Enable the grid\n",
    "plt.grid(True, which=\"major\", ls=\"--\", alpha=0.8)\n",
    "\n",
    "plt.legend(loc='lower left', fontsize=11)\n",
    "plt.xlabel('Canary perplexity', fontsize=16)\n",
    "plt.ylabel('AUC', fontsize=16)\n",
    "plt.ylim(0.45, 1.02)\n",
    "plt.xscale('log')\n",
    "\n",
    "# savefig does not create missing directories; make sure the target exists on a fresh checkout.\n",
    "os.makedirs('figures', exist_ok=True)\n",
    "plt.savefig('figures/ppl_experiment.pdf', dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
