{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import datasets\n",
    "from privacy_estimates.experiments.aml import JobList\n",
    "from sklearn.metrics import roc_curve, roc_auc_score, auc\n",
    "import re\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_performance(scores_members, scores_non_members):\n",
    "    \"\"\"Score a membership-inference attack from member / non-member scores.\n",
    "\n",
    "    Members are labelled 1 and non-members 0; the two score collections are\n",
    "    concatenated in that order and handed to sklearn's `roc_auc_score` and\n",
    "    `roc_curve`. A one-line summary is printed.\n",
    "\n",
    "    Args:\n",
    "        scores_members: iterable of attack scores for member examples.\n",
    "        scores_non_members: iterable of attack scores for non-member examples.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys 'auc', 'tpr_at_0.01', 'tpr_at_0.05', 'tpr_at_0.1'\n",
    "        (TPR linearly interpolated on the ROC curve at that FPR), followed by\n",
    "        the full 'fpr' and 'tpr' arrays returned by `roc_curve`.\n",
    "    \"\"\"\n",
    "    positives = list(scores_members)\n",
    "    negatives = list(scores_non_members)\n",
    "    # build the label / score vectors once and reuse them for both sklearn calls\n",
    "    y_true = [1] * len(positives) + [0] * len(negatives)\n",
    "    y_score = positives + negatives\n",
    "\n",
    "    mia_performance = {'auc': roc_auc_score(y_true, y_score)}\n",
    "    fpr, tpr, _ = roc_curve(y_true, y_score)\n",
    "    mia_performance.update(\n",
    "        {f'tpr_at_{target_fpr}': np.interp(target_fpr, fpr, tpr) for target_fpr in (0.01, 0.05, 0.1)}\n",
    "    )\n",
    "    print(f\"AUC: {mia_performance['auc']}, TPR@0.01: {mia_performance['tpr_at_0.01']}, TPR@0.05: {mia_performance['tpr_at_0.05']}, TPR@0.1: {mia_performance['tpr_at_0.1']}\")\n",
    "\n",
    "    # keep the raw curves so they can be plotted / exported later\n",
    "    mia_performance['fpr'] = fpr\n",
    "    mia_performance['tpr'] = tpr\n",
    "    return mia_performance\n",
    "\n",
    "# let's also make a function that computes the performance directly from the url\n",
    "\n",
    "from datetime import datetime\n",
    "\n",
    "def compute_performance_from_url(url, job_name = None):\n",
    "    \"\"\"Fetch an AML job's attack scores and challenge bits and evaluate them.\n",
    "\n",
    "    The 'scores' and 'challenge_bits' inputs of the job's 'estimate_privacy'\n",
    "    node are downloaded to ./mia_results/<job_name>/ (unless already present),\n",
    "    loaded with `datasets.load_from_disk`, split by challenge bit and passed\n",
    "    to `compute_performance`.\n",
    "\n",
    "    Args:\n",
    "        url: URL of the AML run, handed to `JobList.from_urls`.\n",
    "        job_name: name of the local cache directory. Defaults to the run id\n",
    "            found after '/runs/' in `url`, so re-running the notebook reuses\n",
    "            the earlier download; falls back to a timestamp if the URL has no\n",
    "            such segment.\n",
    "\n",
    "    Returns:\n",
    "        The dict returned by `compute_performance`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the number of scores and challenge bits differ.\n",
    "    \"\"\"\n",
    "    if job_name is None:\n",
    "        # A timestamp default gives every call a fresh directory, so the cache\n",
    "        # check below could never hit; key the cache on the run id instead.\n",
    "        run_id = re.search(r'/runs/([\\w\\-]+)', url)\n",
    "        job_name = run_id.group(1) if run_id else str(datetime.now())\n",
    "\n",
    "    scores_dir = f'./mia_results/{job_name}/scores'\n",
    "    bits_dir = f'./mia_results/{job_name}/challenge_bits'\n",
    "\n",
    "    # Check the two inputs separately: if an earlier call fetched only one of\n",
    "    # them, the missing one is downloaded instead of failing on load.\n",
    "    missing = [(name, path) for name, path in (('scores', scores_dir), ('challenge_bits', bits_dir))\n",
    "               if not os.path.exists(path)]\n",
    "    if missing:\n",
    "        # only contact AML when something actually has to be downloaded\n",
    "        node = JobList.from_urls([url])[0].get_node('estimate_privacy')\n",
    "        for input_name, target_dir in missing:\n",
    "            node.download_input(input_name, target_dir)\n",
    "\n",
    "    scores = datasets.load_from_disk(scores_dir)\n",
    "    bits = datasets.load_from_disk(bits_dir)\n",
    "\n",
    "    membership_scores = np.array([k['score'] for k in scores])\n",
    "    membership_labels = np.array([k['challenge_bit'] for k in bits])\n",
    "    # the boolean masks below pair scores and bits by position\n",
    "    if len(membership_scores) != len(membership_labels):\n",
    "        raise ValueError(\n",
    "            f'{len(membership_scores)} scores but {len(membership_labels)} challenge bits for job {job_name}'\n",
    "        )\n",
    "    members = membership_scores[membership_labels == 1]\n",
    "    non_members = membership_scores[membership_labels == 0]\n",
    "    return compute_performance(members, non_members)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_aml_urls(log_file_path):\n",
    "    # Initialize a list to store the extracted URLs\n",
    "    aml_urls = []\n",
    "\n",
    "    # Define a regular expression to match the log entries containing AML URLs\n",
    "    aml_url_pattern = re.compile(r'AML URL: (https://ml\\.azure\\.com/runs/[\\w\\-\\?&=/]+)')\n",
    "\n",
    "    # Open and read the log file\n",
    "    with open(log_file_path, 'r') as file:\n",
    "        for line in file:\n",
    "            match = aml_url_pattern.search(line)\n",
    "            if match:\n",
    "                aml_urls.append(match.group(1))\n",
    "\n",
    "    return aml_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# first the non synthetic results\n",
    "\n",
    "# Canary perplexity grid: half-decade steps from 10^1 to 10^5 (9 values).\n",
    "# Entry i of each URL list below is the job for perplexities[i].\n",
    "perplexities = [10, 10**1.5, 10**2, 10**2.5, 10**3, 10**3.5, 10**4, 10**4.5, 10**5]\n",
    "\n",
    "# we recycle two of the urls from previous experiments, so I'm writing the urls out to avoid any mess\n",
    "no_synthetic_urls = ['https://ml.azure.com/runs/crimson_street_3wtxkmjtj5?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/musing_sail_wmjq1sh67n?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/nifty_brush_m79xdknbbm?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/goofy_foot_rs86w6b5rp?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/silly_wolf_1pwzrw5zl1?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/silver_rabbit_vdf9vrby6b?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/joyful_pen_1zjx6wwwgs?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/amusing_music_ms4rgwx19z?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/keen_energy_w8hmyn721f?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47']\n",
    "\n",
    "# jobs for the synthetic-data setting, in the same perplexity order\n",
    "synthetic_urls = ['https://ml.azure.com/runs/serene_crayon_qkgkmsx782?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/ashy_knee_1pxmw4c3wf?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/loving_bottle_ys6qmynzkw?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/upbeat_eagle_w9wh27mk25?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/eager_nose_lcxm2pbn34?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/strong_reggae_b6mkgkrtw0?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/orange_milk_9wnz380ggr?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/willing_candle_2h1b0yj6pv?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    " 'https://ml.azure.com/runs/lucid_lizard_gq68gf94s4?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47']\n",
    "\n",
    "# the lists are paired with perplexities by index, so they must have one URL per perplexity\n",
    "assert len(no_synthetic_urls) == len(perplexities)\n",
    "assert len(synthetic_urls) == len(perplexities)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _performance_or_placeholder(url):\n",
    "    \"\"\"Evaluate the job at `url`; on failure print the error and return a placeholder.\n",
    "\n",
    "    The placeholder carries every key that later cells read: 'auc' and the\n",
    "    three 'tpr_at_*' entries are None, and 'fpr' / 'tpr' are empty arrays so\n",
    "    that the ROC export cell does not raise KeyError after a failed job.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        performance = compute_performance_from_url(url)\n",
    "    except Exception as e:\n",
    "        # best effort: one failed job should not abort the whole sweep\n",
    "        print(e)\n",
    "        performance = {'auc': None, 'tpr_at_0.01': None, 'tpr_at_0.05': None, 'tpr_at_0.1': None,\n",
    "                       'fpr': np.array([]), 'tpr': np.array([])}\n",
    "    print('---')\n",
    "    return performance\n",
    "\n",
    "\n",
    "# results keyed by canary perplexity, one dict per setting\n",
    "non_synthetic_results = dict()\n",
    "synthetic_results = dict()\n",
    "\n",
    "for perplexity, no_synthetic_url, synthetic_url in zip(perplexities, no_synthetic_urls, synthetic_urls):\n",
    "    print(f\"Perplexity: {perplexity}\")\n",
    "    print(\"Non synthetic\")\n",
    "    non_synthetic_results[perplexity] = _performance_or_placeholder(no_synthetic_url)\n",
    "\n",
    "    print('Synthetic')\n",
    "    synthetic_results[perplexity] = _performance_or_placeholder(synthetic_url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AUC of the attack as a function of canary perplexity, one line per setting.\n",
    "plt.figure(figsize=(6,6))\n",
    "plt.plot(perplexities, [non_synthetic_results[p]['auc'] for p in perplexities], '-o', alpha=0.8, \n",
    "             color ='darkblue', markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white', label=r'Model - $n_{rep}=4$')\n",
    "plt.plot(perplexities, [synthetic_results[p]['auc'] for p in perplexities], '-o', alpha=0.8, \n",
    "             color ='darkorange', markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white', label=r'Synthetic (2-gram) - $n_{rep}=16$')\n",
    "\n",
    "# AUC = 0.5 reference line\n",
    "plt.axhline(y=0.5, color='black', linestyle='--', alpha=0.5, label = 'Random guess baseline')\n",
    "\n",
    "\n",
    "# NOTE(review): xscale('log') is applied further down, after these ticks are set;\n",
    "# changing the scale may reset the tick locator - confirm the ticks render as intended.\n",
    "plt.xticks([10**k for k in (0, 1, 2, 3, 4, 5)])\n",
    "plt.yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0], labels=['0.5', '0.6', '0.7', '0.8', '0.9', '1.0'])\n",
    "\n",
    "# Enable the grid\n",
    "plt.grid(True, which=\"major\", ls=\"--\", alpha=0.8)\n",
    "\n",
    "plt.legend(loc='lower left', fontsize=14)\n",
    "plt.xlabel('Canary perplexity', fontsize=16)\n",
    "plt.ylabel('AUC', fontsize=16)\n",
    "plt.ylim(0.4, 1.02)\n",
    "plt.xscale('log')\n",
    "# savefig does not create missing directories, so make sure the target exists\n",
    "os.makedirs('figures', exist_ok=True)\n",
    "plt.savefig('figures/ppl_experiment_agnews.pdf', dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To also include the in-canary results, first run the notebook `notebooks/compute_perplexity_canaries.ipynb` (on a GPU instance). \n",
    "\n",
    "Then we can load its output: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle \n",
    "\n",
    "# Load the perplexities of the in-canary examples; the file is presumably written by\n",
    "# notebooks/compute_perplexity_canaries.ipynb - confirm the path matches its output.\n",
    "# NOTE: pickle.load can execute arbitrary code, so only load a file you produced yourself.\n",
    "with open('ppl_in_canaries_boxplot_agnews', 'rb') as f:\n",
    "    ppl_in_canaries = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# also get the performance\n",
    "# The two in-canary jobs are stored under the extra key 'in_canary' in the same result\n",
    "# dicts as the perplexity sweep. Unlike the sweep, failures are not caught here.\n",
    "\n",
    "in_canary_no_synthetic_url =  \"https://ml.azure.com/runs/jovial_crowd_tqywfhsgfj?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#\"\n",
    "performance = compute_performance_from_url(in_canary_no_synthetic_url)\n",
    "non_synthetic_results['in_canary'] = performance\n",
    "\n",
    "in_canary_synthetic_url = \"https://ml.azure.com/runs/icy_balloon_ww1kjz7gb5?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47\"\n",
    "performance = compute_performance_from_url(in_canary_synthetic_url)\n",
    "synthetic_results['in_canary'] = performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same AUC-vs-perplexity plot as before, plus one horizontal box plot per setting:\n",
    "# the box shows the distribution of ppl_in_canaries along x and sits at the\n",
    "# in-canary AUC of that setting along y.\n",
    "plt.figure(figsize=(6,6))\n",
    "plt.plot(perplexities, [non_synthetic_results[p]['auc'] for p in perplexities], '-o', alpha=0.8, \n",
    "             color ='darkblue', markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white', label=r'Model - $n_{rep}=4$')\n",
    "\n",
    "plt.boxplot(\n",
    "    ppl_in_canaries,\n",
    "    vert=False,\n",
    "    positions=[non_synthetic_results['in_canary']['auc']],\n",
    "    widths=0.02,  # Adjust width to fit your plot\n",
    "    patch_artist=True,\n",
    "    boxprops=dict(facecolor='darkblue', color='darkblue'),\n",
    "    whiskerprops=dict(color='darkblue'),\n",
    "    capprops=dict(color='darkblue'),\n",
    "    flierprops=dict(marker='o', color='darkblue', alpha=0.0),  # alpha=0 hides the outliers\n",
    "    medianprops=dict(color='white', linewidth=1)  # white median line, visible on the filled box\n",
    ")\n",
    "\n",
    "plt.plot(perplexities,\n",
    "        [synthetic_results[p]['auc'] for p in perplexities],\n",
    "        '-o', alpha=0.8, color ='darkorange', markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white', label=r'Synthetic (2-gram) - $n_{rep}=16$')\n",
    "\n",
    "plt.boxplot(\n",
    "    ppl_in_canaries,\n",
    "    vert=False,\n",
    "    positions=[synthetic_results['in_canary']['auc']],\n",
    "    widths=0.02,  # Adjust width to fit your plot\n",
    "    patch_artist=True,\n",
    "    boxprops=dict(facecolor='darkorange', color='darkorange'),\n",
    "    whiskerprops=dict(color='darkorange'),\n",
    "    capprops=dict(color='darkorange'),\n",
    "    flierprops=dict(marker='o', color='darkorange', alpha=0.0),  # alpha=0 hides the outliers\n",
    "    medianprops=dict(color='white', linewidth=1)  # white median line, visible on the filled box\n",
    ")\n",
    "\n",
    "# just for the label: a grey box drawn at y=2.0, outside the y-limits set below,\n",
    "# so only its legend entry is visible\n",
    "plt.boxplot(\n",
    "    ppl_in_canaries,\n",
    "    vert=False,\n",
    "    positions=[2.0],\n",
    "    widths=0.02,  # Adjust width to fit your plot\n",
    "    patch_artist=True,\n",
    "    boxprops=dict(facecolor='lightgrey', color='grey'),\n",
    "    whiskerprops=dict(color='grey'),\n",
    "    capprops=dict(color='grey'),\n",
    "    flierprops=dict(marker='o', color='grey', alpha=0.5),\n",
    "    medianprops=dict(color='white', linewidth=1), \n",
    "    label = 'In distribution'\n",
    ")\n",
    "\n",
    "# AUC = 0.5 reference line\n",
    "plt.axhline(y=0.5, color='black', linestyle='--', alpha=0.5, label = 'Random guess baseline')\n",
    "\n",
    "# boxplot() puts its positions on the y-axis ticks, so the ticks are set explicitly afterwards\n",
    "plt.xticks([10**k for k in (0, 1, 2, 3, 4, 5)])\n",
    "plt.yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0], labels=['0.5', '0.6', '0.7', '0.8', '0.9', '1.0'])\n",
    "\n",
    "# Enable the grid\n",
    "plt.grid(True, which=\"major\", ls=\"--\", alpha=0.8)\n",
    "\n",
    "plt.legend(loc='lower left', fontsize=11)\n",
    "plt.xlabel('Canary perplexity', fontsize=16)\n",
    "plt.ylabel('AUC', fontsize=16)\n",
    "plt.ylim(0.45, 1.02)\n",
    "plt.xscale('log')\n",
    "# savefig does not create missing directories, so make sure the target exists\n",
    "os.makedirs('figures', exist_ok=True)\n",
    "plt.savefig('figures/ppl_experiment_agnews_w_in_canaries.pdf', dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "# Five-number summary (min, quartiles, max) of the in-canary perplexities, one row per\n",
    "# quantile level; the level is exposed as the 'percentile' column.\n",
    "id_canaries = pd.DataFrame({\"ppl\": ppl_in_canaries}).quantile([0, 0.25, 0.5, 0.75, 1]).rename_axis('percentile').reset_index()\n",
    "# The in-canary AUC of each setting is a single number, repeated on all 5 rows.\n",
    "id_canaries[\"model\"] = [non_synthetic_results['in_canary']['auc']]*5\n",
    "id_canaries[\"synthetic\"] = [synthetic_results['in_canary']['auc']]*5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tables to export, keyed by name:\n",
    "#   'ood_canaries': one row per perplexity in `perplexities`, with the AUC of both settings\n",
    "#   'id_canaries':  the in-canary perplexity quantiles built in the previous cell\n",
    "df = {\n",
    "    \"ood_canaries\": pd.DataFrame({\"ppl\": perplexities, \"model\": [non_synthetic_results[p]['auc'] for p in perplexities], \"synthetic\": [synthetic_results[p]['auc'] for p in perplexities]}),\n",
    "    \"id_canaries\": id_canaries\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One ROC table (columns 'fpr', 'tpr') per perplexity and setting.\n",
    "# Keys are 'perp_<int(p)>_model' / 'perp_<int(p)>_synthetic'; int() truncates the\n",
    "# non-integer perplexities (e.g. 10**1.5 becomes 31).\n",
    "df_roc = dict()\n",
    "for p in perplexities:\n",
    "    df_roc[f\"perp_{int(p)}_model\"] = pd.DataFrame({\"fpr\": non_synthetic_results[p]['fpr'], \"tpr\": non_synthetic_results[p]['tpr']})\n",
    "    df_roc[f\"perp_{int(p)}_synthetic\"] = pd.DataFrame({\"fpr\": synthetic_results[p]['fpr'], \"tpr\": synthetic_results[p]['tpr']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the AUC tables and every ROC table as .tsv files under data/canary_ppl/agnews/.\n",
    "# NOTE(review): `latex.Project` is not defined in this notebook; presumably a project helper\n",
    "# that locates the paper repository via the LATEX_GIT_PATH environment variable - confirm.\n",
    "from latex import Project\n",
    "overleaf = Project.from_env(path_env_name=\"LATEX_GIT_PATH\")\n",
    "overleaf.add_dataframe(df[\"ood_canaries\"], 'data/canary_ppl/agnews/auc/ood_canaries.tsv')\n",
    "overleaf.add_dataframe(df[\"id_canaries\"], 'data/canary_ppl/agnews/auc/id_canaries.tsv')\n",
    "\n",
    "# one file per ROC curve, named after its df_roc key\n",
    "for df_roc_name, df_roc_data in df_roc.items():\n",
    "    overleaf.add_dataframe(df_roc_data, f'data/canary_ppl/agnews/roc/{df_roc_name}.tsv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
