{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import datasets\n",
    "from privacy_estimates.experiments.aml import JobList\n",
    "from sklearn.metrics import roc_curve, roc_auc_score, auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_performance(scores_members, scores_non_members):\n",
    "    mia_performance = {}\n",
    "    \n",
    "    member_vals = [val for val in scores_members]\n",
    "    non_member_vals = [val for val in scores_non_members]\n",
    "    mia_performance['auc'] = roc_auc_score([1]*len(member_vals) + [0]*len(non_member_vals), member_vals + non_member_vals)\n",
    "    fpr, tpr, thresholds = roc_curve([1]*len(member_vals) + [0]*len(non_member_vals), member_vals + non_member_vals)\n",
    "    for target_fpr in (0.01, 0.05, 0.1):\n",
    "        mia_performance[f'tpr_at_{target_fpr}'] = np.interp(target_fpr, fpr, tpr)\n",
    "    print(f\"AUC: {mia_performance['auc']}, TPR@0.01: {mia_performance['tpr_at_0.01']}, TPR@0.05: {mia_performance['tpr_at_0.05']}, TPR@0.1: {mia_performance['tpr_at_0.1']}\")\n",
    "\n",
    "    # also add the curves\n",
    "    mia_performance['fpr'] = fpr\n",
    "    mia_performance['tpr'] = tpr\n",
    "\n",
    "    return mia_performance\n",
    "\n",
    "# let's also make a function that computes the performance directly from the url\n",
    "\n",
    "from datetime import datetime\n",
    "\n",
    "def compute_performance_from_url(url, job_name = None):\n",
    "    jobs = JobList.from_urls([url])\n",
    "    if job_name is None:\n",
    "        job_name = str(datetime.now())\n",
    "    \n",
    "    if not os.path.exists(f'./mia_results/{job_name}'):\n",
    "        test = jobs[0].get_node('estimate_privacy').download_input('scores', f'./mia_results/{job_name}/scores')\n",
    "        test = jobs[0].get_node('estimate_privacy').download_input('challenge_bits', f'./mia_results/{job_name}/challenge_bits')\n",
    "    scores = datasets.load_from_disk(f'./mia_results/{job_name}/scores')\n",
    "    bits = datasets.load_from_disk(f'./mia_results/{job_name}/challenge_bits')\n",
    "    \n",
    "    membership_scores = np.array([k['score'] for k in scores])\n",
    "    membership_labels = np.array([k['challenge_bit'] for k in bits])\n",
    "    members = membership_scores[membership_labels == 1]\n",
    "    non_members = membership_scores[membership_labels == 0]\n",
    "    return compute_performance(members, non_members)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (1) Lets do it for SST-2\n",
    "\n",
    "Vary neighbors k and see what happens to the mia performance, for both datasets.  \n",
    "\n",
    "We ran the following:\n",
    "\n",
    "`./scripts/launch_synthetic_mias_ablation_{DATASET}.sh > synthetic_mias_ablation_{DATASET}.txt`. For synthetic canaries, 12 repetitions, uniform label. \n",
    "\n",
    "Let's get these results. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def extract_aml_urls(log_file_path):\n",
    "    # Initialize a list to store the extracted URLs\n",
    "    aml_urls = []\n",
    "\n",
    "    # Define a regular expression to match the log entries containing AML URLs\n",
    "    aml_url_pattern = re.compile(r'AML URL: (https://ml\\.azure\\.com/runs/[\\w\\-\\?&=/]+)')\n",
    "\n",
    "    # Open and read the log file\n",
    "    with open(log_file_path, 'r') as file:\n",
    "        for line in file:\n",
    "            match = aml_url_pattern.search(line)\n",
    "            if match:\n",
    "                aml_urls.append(match.group(1))\n",
    "\n",
    "    return aml_urls\n",
    "\n",
    "DATASET = 'sst2'\n",
    "\n",
    "# Extract the AML URLs\n",
    "aml_urls = extract_aml_urls(f'../job_launch_outputs/synthetic_mias_ablation_{DATASET}.txt')\n",
    "aml_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mia_methods=(\"jaccard_1\", \"jaccard_5\", \"jaccard_10\", \"jaccard_25\",\n",
    "             \"embedding_1\", \"embedding_5\", \"embedding_10\", \"embedding_25\",\n",
    "             \"ngram_1\", \"ngram_2\", \"ngram_3\", \"ngram_4\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method_to_auc = {}\n",
    "\n",
    "for i, method in enumerate(mia_methods):\n",
    "    job_name = f'ablation_synthetic_{method}'\n",
    "    print(f\"Processing {job_name}\")\n",
    "    performance = compute_performance_from_url(aml_urls[i])\n",
    "    method_to_auc[method] = performance['auc']\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Added the results in a table, still leave the code for figures too. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = ['ngram_loss_2', 'jaccard_25', 'embedding_10']\n",
    "\n",
    "ks = [1, 5, 10, 25]\n",
    "\n",
    "plt.figure(figsize=(6, 6))\n",
    "plt.plot(ks, [method_to_auc[f'jaccard_{k}'] for k in ks], '-o', alpha=0.8, label=r'Synthetic ($\\text{SIM}_{jac}$)', \n",
    "         color = 'darkgreen', markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white')\n",
    "plt.plot(ks, [method_to_auc[f'embedding_{k}'] for k in ks], '-o', alpha=0.8, label=r'Synthetic ($\\text{SIM}_{emb}$)', \n",
    "         color = 'darkred', markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white')\n",
    "plt.xlabel(r'$k$ in similarity-based MIAs', fontsize=15)\n",
    "plt.ylabel('MIA AUC', fontsize=15)\n",
    "plt.grid(True, which=\"both\", ls=\"--\", alpha=0.8)\n",
    "plt.axhline(y=0.5, color='black', linestyle='--', alpha=1.0, label = 'Random guess baseline')\n",
    "plt.legend(fontsize=12)\n",
    "plt.ylim(0.45, 0.7)\n",
    "plt.savefig('figures/ablation_k.pdf', dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ns = [1, 2, 3, 4]\n",
    "\n",
    "plt.figure(figsize=(6, 6))\n",
    "plt.plot(ns, [method_to_auc[f'ngram_{n}'] for n in ns], '-o', alpha=0.8, label=r'Synthetic ($n$-gram)', \n",
    "         color = 'darkorange', markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white')\n",
    "plt.xlabel(r'$n$ in n-gram MIAs', fontsize=15)\n",
    "plt.ylabel('MIA AUC', fontsize=15)\n",
    "plt.grid(True, which=\"both\", ls=\"--\", alpha=0.8)\n",
    "plt.axhline(y=0.5, color='black', linestyle='--', alpha=1.0, label = 'Random guess baseline')\n",
    "plt.legend(fontsize=12)\n",
    "plt.ylim(0.4, 0.7)\n",
    "plt.savefig('figures/ablation_n.pdf', dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (2) Lets do it for AgNews"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET = 'agnews'\n",
    "\n",
    "# Extract the AML URLs\n",
    "aml_urls = extract_aml_urls(f'../synthetic_mias_ablation_{DATASET}.txt')\n",
    "aml_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method_to_auc = {}\n",
    "\n",
    "for i, method in enumerate(mia_methods):\n",
    "    job_name = f'ablation_synthetic_{method}'\n",
    "    print(f\"Processing {job_name}\")\n",
    "    performance = compute_performance_from_url(aml_urls[i])\n",
    "    method_to_auc[method] = performance['auc']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
