{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import datasets\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import roc_curve, roc_auc_score, auc\n",
    "\n",
    "from privacy_estimates.experiments.aml import JobList"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accumulator for ROC curves: later cells append one row per (method, dataset)\n",
    "# holding the 'fpr' and 'tpr' arrays, and the last code cell exports them.\n",
    "df = pd.DataFrame()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (1) Do this for one job\n",
    "\n",
    "Quite simple: how can I compute the AUC and the TPR at low FPR for a single job?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_performance(scores_members, scores_non_members):\n",
    "    \"\"\"Evaluate a membership inference attack from its per-record scores.\n",
    "\n",
    "    Members are labelled 1 and non-members 0 when building the ROC curve.\n",
    "\n",
    "    Args:\n",
    "        scores_members: iterable of scores assigned to member records.\n",
    "        scores_non_members: iterable of scores assigned to non-member records.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'auc', 'tpr_at_0.01', 'tpr_at_0.05' and 'tpr_at_0.1' (TPR\n",
    "        linearly interpolated on the ROC curve at that FPR), plus the full\n",
    "        'fpr' and 'tpr' arrays returned by roc_curve. The scalar metrics are\n",
    "        also printed.\n",
    "    \"\"\"\n",
    "    positives = list(scores_members)\n",
    "    negatives = list(scores_non_members)\n",
    "\n",
    "    # Members come first, then non-members; labels follow the same order.\n",
    "    y_true = [1] * len(positives) + [0] * len(negatives)\n",
    "    y_score = positives + negatives\n",
    "\n",
    "    summary = {'auc': roc_auc_score(y_true, y_score)}\n",
    "    fpr, tpr, _ = roc_curve(y_true, y_score)\n",
    "    summary.update({f'tpr_at_{level}': np.interp(level, fpr, tpr) for level in (0.01, 0.05, 0.1)})\n",
    "    print(f\"AUC: {summary['auc']}, TPR@0.01: {summary['tpr_at_0.01']}, TPR@0.05: {summary['tpr_at_0.05']}, TPR@0.1: {summary['tpr_at_0.1']}\")\n",
    "\n",
    "    # Keep the curves as well so that callers can plot them.\n",
    "    summary['fpr'] = fpr\n",
    "    summary['tpr'] = tpr\n",
    "    return summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paste the URL of the (completed) target job here.\n",
    "# Only the first entry (jobs[0]) is used when downloading the job's inputs.\n",
    "\n",
    "jobs = JobList.from_urls([\n",
    "    'https://ml.azure.com/runs/calm_honey_k70j459lks?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the 'scores' and 'challenge_bits' inputs of the job's 'estimate_privacy' node.\n",
    "# Note that this fails when the target directory already exists, so always use a new name for each job (or remove the old results).\n",
    "\n",
    "job_name = 'JOB_NAME'\n",
    "test = jobs[0].get_node('estimate_privacy').download_input('scores', f'./mia_results/{job_name}/scores')\n",
    "test = jobs[0].get_node('estimate_privacy').download_input('challenge_bits', f'./mia_results/{job_name}/challenge_bits')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load both downloaded datasets from disk and display their first records as a sanity check.\n",
    "scores = datasets.load_from_disk(f'./mia_results/{job_name}/scores')\n",
    "bits = datasets.load_from_disk(f'./mia_results/{job_name}/challenge_bits')\n",
    "scores[0], bits[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the scores by challenge bit: bit 1 -> members, bit 0 -> non_members\n",
    "membership_scores = np.array([record['score'] for record in scores])\n",
    "membership_labels = np.array([record['challenge_bit'] for record in bits])\n",
    "members = membership_scores[membership_labels == 1]\n",
    "non_members = membership_scores[membership_labels == 0]\n",
    "\n",
    "# compute performance\n",
    "performance = compute_performance(members, non_members)\n",
    "fpr, tpr, roc_auc = performance['fpr'], performance['tpr'], performance['auc']\n",
    "\n",
    "# Plot the ROC curve together with the chance diagonal\n",
    "fig, ax = plt.subplots(figsize=(5, 5))\n",
    "ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:0.2f})')\n",
    "ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')\n",
    "ax.set_xlim([0.0, 1.0])\n",
    "ax.set_ylim([0.0, 1.05])\n",
    "ax.set_xlabel('False Positive Rate')\n",
    "ax.set_ylabel('True Positive Rate')\n",
    "ax.set_title('Receiver Operating Characteristic')\n",
    "ax.legend(loc=\"lower right\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's also make a function that computes the performance directly from the url\n",
    "\n",
    "from datetime import datetime\n",
    "\n",
    "def compute_performance_from_url(url, job_name = None):\n",
    "    \"\"\"Download (if needed) the MIA inputs of one AML job and evaluate them.\n",
    "\n",
    "    Args:\n",
    "        url: URL of the AML job; passed to JobList.from_urls.\n",
    "        job_name: sub-directory of ./mia_results/ in which the 'scores' and\n",
    "            'challenge_bits' inputs of the 'estimate_privacy' node are stored.\n",
    "            Defaults to the current timestamp, so a call without a name\n",
    "            always downloads into a new directory.\n",
    "\n",
    "    Returns:\n",
    "        The dict returned by compute_performance for this job.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the number of scores and challenge bits differ.\n",
    "    \"\"\"\n",
    "    if job_name is None:\n",
    "        job_name = str(datetime.now())\n",
    "\n",
    "    # Local directory of each input of the 'estimate_privacy' node.\n",
    "    local_dirs = {\n",
    "        'scores': f'./mia_results/{job_name}/scores',\n",
    "        'challenge_bits': f'./mia_results/{job_name}/challenge_bits',\n",
    "    }\n",
    "\n",
    "    # Check every input on its own: if an earlier run stopped after the first\n",
    "    # download, only the missing input is fetched instead of failing on load.\n",
    "    missing = {name: path for name, path in local_dirs.items() if not os.path.exists(path)}\n",
    "    if missing:\n",
    "        # Only resolve the job when something actually has to be downloaded.\n",
    "        node = JobList.from_urls([url])[0].get_node('estimate_privacy')\n",
    "        for name, path in missing.items():\n",
    "            node.download_input(name, path)\n",
    "\n",
    "    scores = datasets.load_from_disk(local_dirs['scores'])\n",
    "    bits = datasets.load_from_disk(local_dirs['challenge_bits'])\n",
    "\n",
    "    membership_scores = np.array([record['score'] for record in scores])\n",
    "    membership_labels = np.array([record['challenge_bit'] for record in bits])\n",
    "    # Scores and bits are matched by position, so they must line up one-to-one.\n",
    "    if len(membership_scores) != len(membership_labels):\n",
    "        raise ValueError(f'{len(membership_scores)} scores but {len(membership_labels)} challenge bits for job {job_name!r}')\n",
    "\n",
    "    members = membership_scores[membership_labels == 1]\n",
    "    non_members = membership_scores[membership_labels == 0]\n",
    "    return compute_performance(members, non_members)\n",
    "\n",
    "performance = compute_performance_from_url('https://ml.azure.com/runs/calm_honey_k70j459lks?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (2) Do this across jobs submitted using a bash script\n",
    "\n",
    "I just launched a series of jobs using a bash script and I wrote the job output to a txt file. How can I easily compute the MIA performance for all results? \n",
    "\n",
    "For the main experiment I ran\n",
    "\n",
    "`./scripts/launch_no_synthetic_main_experiment.sh > job_launch_outputs/no_synthetic_main_exp.txt`. Let's get these results. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def extract_aml_urls(log_file_path):\n",
    "    \"\"\"Collect the AML run URLs reported in a job-launch log.\n",
    "\n",
    "    Every line is searched for the text 'AML URL: ' followed by an\n",
    "    https://ml.azure.com/runs/... URL; at most one URL is taken per line.\n",
    "\n",
    "    Args:\n",
    "        log_file_path: path of the text file holding the launcher output.\n",
    "\n",
    "    Returns:\n",
    "        list of URL strings, in the order in which they appear in the file.\n",
    "    \"\"\"\n",
    "    # '.', '~', '%', ':' and '+' are accepted inside the URL so that URLs\n",
    "    # containing e.g. 'providers/Microsoft.MachineLearningServices' are not\n",
    "    # cut off at the first dot. '#' is still excluded, as before.\n",
    "    aml_url_pattern = re.compile(r'AML URL: (https://ml\\.azure\\.com/runs/[\\w\\-.~%:+?&=/]+)')\n",
    "\n",
    "    with open(log_file_path, 'r') as file:\n",
    "        matches = (aml_url_pattern.search(line) for line in file)\n",
    "        # A trailing '.' is sentence punctuation, not part of the URL.\n",
    "        return [match.group(1).rstrip('.') for match in matches if match]\n",
    "\n",
    "# Extract the AML URLs\n",
    "aml_urls = extract_aml_urls('../job_launch_outputs/no_synthetic_main_exp.txt')\n",
    "aml_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# list the canary methods in the same order as you did when submitting the jobs\n",
    "# (URLs are paired with these names purely by position: the i-th URL gets canary_methods[i])\n",
    "\n",
    "canary_methods=(\"no_synthetic_sst2_syntheticcanary_canarylabel\", \"no_synthetic_sst2_syntheticcanary_uniformlabel\",\n",
    "                \"no_synthetic_sst2_incanary\", \"no_synthetic_agnews_syntheticcanary_canarylabel\",\n",
    "                \"no_synthetic_agnews_syntheticcanary_uniformlabel\", \"no_synthetic_agnews_incanary\")\n",
    "\n",
    "# one URL per canary method is expected in the log\n",
    "assert len(aml_urls) == len(canary_methods)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now run through it all: the i-th URL belongs to the i-th canary method\n",
    "\n",
    "performances = []\n",
    "for i, url in enumerate(aml_urls):\n",
    "    job_name = f'main_exp_no_synthetic_{canary_methods[i]}'\n",
    "    print(f\"Processing {job_name}\")\n",
    "    performance = compute_performance_from_url(url, job_name=job_name)\n",
    "    performances.append(performance)\n",
    "\n",
    "# one entry per job, in the order of aml_urls\n",
    "all_aucs = [p['auc'] for p in performances]\n",
    "all_tpr_001 = [p['tpr_at_0.01'] for p in performances]\n",
    "all_tpr_01 = [p['tpr_at_0.1'] for p in performances]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's also do this for the synthetic experiment. \n",
    "\n",
    "Here I launched `./scripts/launch_all_synthetic_mias_main_experiment.sh > job_launch_outputs/all_synthetic_main_exp.txt`, for all the synthetic attacks in the main experiment. So let's go. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# list the MIA methods in the same order as you did when submitting the jobs\n",
    "\n",
    "canary_methods=(\"synthetic_sst2_syntheticcanary_canarylabel\", \"synthetic_sst2_syntheticcanary_uniformlabel\",\n",
    "         \"synthetic_sst2_incanary\", \"synthetic_agnews_syntheticcanary_canarylabel\",\n",
    "         \"synthetic_agnews_syntheticcanary_uniformlabel\", \"synthetic_agnews_incanary\")\n",
    "\n",
    "# Define the array of mia_method values\n",
    "mia_methods=(\"jaccard_25\", \"embedding_25\", \"ngram_2\")\n",
    "\n",
    "# Note: this overwrites the aml_urls extracted from the no-synthetic log.\n",
    "aml_urls = extract_aml_urls('../job_launch_outputs/all_synthetic_main_exp.txt')\n",
    "# one URL per (canary method, MIA method) combination is expected\n",
    "assert len(aml_urls) == len(mia_methods) * len(canary_methods)\n",
    "aml_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now run through it all \n",
    "\n",
    "all_aucs, all_tpr_001, all_tpr_01 = [], [], []\n",
    "\n",
    "# URLs are consumed in this order: canary method in the outer position, MIA method in the inner one.\n",
    "job_grid = [(canary, mia) for canary in canary_methods for mia in mia_methods]\n",
    "\n",
    "for i, (canary_method, mia_method) in enumerate(job_grid):\n",
    "    url = aml_urls[i]\n",
    "    job_name = f'main_exp_all_synthetic_{canary_method}_{mia_method}'\n",
    "    print(f\"Processing {job_name}\")\n",
    "    performance = compute_performance_from_url(url, job_name=job_name)\n",
    "    all_aucs.append(performance['auc'])\n",
    "    all_tpr_001.append(performance['tpr_at_0.01'])\n",
    "    all_tpr_01.append(performance['tpr_at_0.1'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (3) let's get the main figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##### AG NEWS synthetic canary uniform label, n_rep=12, ppl=250\n",
    "\n",
    "# One job URL per attack; the keys are reused to index the results.\n",
    "agnews_main_urls = {\n",
    "    'model': 'https://ml.azure.com/runs/clever_knot_1htrpfxl9n?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "    'synthetic_2gram': 'https://ml.azure.com/runs/quirky_insect_7scfdk028l?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "    'synthetic_jaccard25': 'https://ml.azure.com/runs/nifty_yak_yrqxg3k349?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "    'synthetic_embedding25': 'https://ml.azure.com/runs/jovial_melon_fq99g5nr2g?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47'}\n",
    "\n",
    "# Legend labels, matched by position to the keys of the URL dict\n",
    "# (the plotting cells use job_names[i] while enumerating those keys).\n",
    "job_names = ['Model', 'Synthetic (2-gram)', r'Synthetic ($\\text{SIM}_{jac}$ - $k=25$)', r'Synthetic ($\\text{SIM}_{emb}$ - $k=25$)']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every AG News job, keyed by attack name\n",
    "main_mia_results = {\n",
    "    attack: compute_performance_from_url(job_url)\n",
    "    for attack, job_url in agnews_main_urls.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ROC curves of all AG News attacks: linear axes on the left, log-log axes on the right\n",
    "colors = ['darkblue', 'darkorange', 'darkgreen', 'darkred']\n",
    "figure_path = 'figures/tpr_fpr_all_attacks_agnews_syntheticcanary_uniformlabel.pdf'\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize = (10, 5))\n",
    "\n",
    "for i, key in enumerate(agnews_main_urls.keys()):\n",
    "    performance = main_mia_results[key]\n",
    "    fpr, tpr = performance['fpr'], performance['tpr']\n",
    "    # Format the AUC inside the f-string so that a '%' in a job name cannot break the label\n",
    "    label = f\"{job_names[i]} - AUC = {performance['auc']:0.2f}\"\n",
    "\n",
    "    # The same curve is drawn on both panels\n",
    "    for ax in axes:\n",
    "        ax.plot(fpr, tpr, lw=2, label=label, color=colors[i])\n",
    "\n",
    "for ax in axes:\n",
    "    ax.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--', label = 'Random guess baseline')\n",
    "\n",
    "axes[0].set_xlim([0.0, 1.0])\n",
    "axes[0].set_ylim([0.0, 1.05])\n",
    "axes[0].set_xlabel('False Positive Rate', fontsize = 12)\n",
    "axes[0].set_ylabel('True Positive Rate', fontsize = 12)\n",
    "\n",
    "axes[1].set_xscale('log')\n",
    "axes[1].set_yscale('log')\n",
    "axes[1].set_xlim([1e-3, 1e0])\n",
    "axes[1].set_xlabel('False Positive Rate (log)', fontsize = 12)\n",
    "axes[1].set_ylabel('True Positive Rate (log)', fontsize = 12)\n",
    "\n",
    "# Only the left panel carries the legend\n",
    "axes[0].legend(loc=\"lower right\", fontsize = 9)\n",
    "\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing directories, so make sure the target folder exists\n",
    "os.makedirs(os.path.dirname(figure_path), exist_ok=True)\n",
    "plt.savefig(figure_path, dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the ROC curves of the three synthetic-data attacks on AG News\n",
    "# (short method name -> key in main_mia_results)\n",
    "exported_attacks = {\"2_gram\": 'synthetic_2gram', \"jac\": 'synthetic_jaccard25', \"emb\": 'synthetic_embedding25'}\n",
    "\n",
    "agnews_curves = pd.DataFrame({\n",
    "    \"method\": list(exported_attacks.keys()),\n",
    "    \"dataset\": [\"agnews\"] * len(exported_attacks),\n",
    "    \"fpr\": [main_mia_results[attack]['fpr'] for attack in exported_attacks.values()],\n",
    "    \"tpr\": [main_mia_results[attack]['tpr'] for attack in exported_attacks.values()],\n",
    "})\n",
    "df = pd.concat([df, agnews_curves])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# do the same for sst-2\n",
    "# (same keys, in the same order, as agnews_main_urls, so job_names can be reused for the legend)\n",
    "\n",
    "sst2_main_urls = {\n",
    "    'model': 'https://ml.azure.com/runs/gentle_tail_0bfkrfjv50?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "    'synthetic_2gram': 'https://ml.azure.com/runs/quirky_battery_tvw7wcm4wh?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47', \n",
    "    'synthetic_jaccard25': 'https://ml.azure.com/runs/polite_vulture_66jf4synlb?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47', \n",
    "    'synthetic_embedding25': 'https://ml.azure.com/runs/tidy_shelf_q5r6tfffyp?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every SST-2 job, keyed by attack name (replaces the AG News results held in main_mia_results)\n",
    "main_mia_results = {\n",
    "    attack: compute_performance_from_url(job_url)\n",
    "    for attack, job_url in sst2_main_urls.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ROC curves of all SST-2 attacks: linear axes on the left, log-log axes on the right\n",
    "colors = ['darkblue', 'darkorange', 'darkgreen', 'darkred']\n",
    "figure_path = 'figures/tpr_fpr_all_attacks_sst2_syntheticcanary_uniformlabel.pdf'\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize = (10, 5))\n",
    "\n",
    "for i, key in enumerate(sst2_main_urls.keys()):\n",
    "    performance = main_mia_results[key]\n",
    "    fpr, tpr = performance['fpr'], performance['tpr']\n",
    "    # Format the AUC inside the f-string so that a '%' in a job name cannot break the label\n",
    "    label = f\"{job_names[i]} - AUC = {performance['auc']:0.2f}\"\n",
    "\n",
    "    # The same curve is drawn on both panels\n",
    "    for ax in axes:\n",
    "        ax.plot(fpr, tpr, lw=2, label=label, color=colors[i])\n",
    "\n",
    "for ax in axes:\n",
    "    ax.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--', label = 'Random guess baseline')\n",
    "\n",
    "axes[0].set_xlim([0.0, 1.0])\n",
    "axes[0].set_ylim([0.0, 1.05])\n",
    "axes[0].set_xlabel('False Positive Rate', fontsize = 12)\n",
    "axes[0].set_ylabel('True Positive Rate', fontsize = 12)\n",
    "\n",
    "axes[1].set_xscale('log')\n",
    "axes[1].set_yscale('log')\n",
    "axes[1].set_xlim([1e-3, 1e0])\n",
    "axes[1].set_xlabel('False Positive Rate (log)', fontsize = 12)\n",
    "axes[1].set_ylabel('True Positive Rate (log)', fontsize = 12)\n",
    "\n",
    "# Only the left panel carries the legend\n",
    "axes[0].legend(loc=\"lower right\", fontsize = 9)\n",
    "\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing directories, so make sure the target folder exists\n",
    "os.makedirs(os.path.dirname(figure_path), exist_ok=True)\n",
    "plt.savefig(figure_path, dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the ROC curves of the three synthetic-data attacks on SST-2\n",
    "# (short method name -> key in main_mia_results)\n",
    "exported_attacks = {\"2_gram\": 'synthetic_2gram', \"jac\": 'synthetic_jaccard25', \"emb\": 'synthetic_embedding25'}\n",
    "\n",
    "sst2_curves = pd.DataFrame({\n",
    "    \"method\": list(exported_attacks.keys()),\n",
    "    \"dataset\": [\"sst2\"] * len(exported_attacks),\n",
    "    \"fpr\": [main_mia_results[attack]['fpr'] for attack in exported_attacks.values()],\n",
    "    \"tpr\": [main_mia_results[attack]['tpr'] for attack in exported_attacks.values()],\n",
    "})\n",
    "df = pd.concat([df, sst2_curves])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from latex import Project\n",
    "\n",
    "# Hand one (fpr, tpr) table per dataset/method pair to the project object\n",
    "overleaf = Project.from_env(path_env_name=\"LATEX_GIT_PATH\")\n",
    "for ds in [\"agnews\", \"sst2\"]:\n",
    "    dataset_rows = df[df[\"dataset\"] == ds]\n",
    "    for method in [\"2_gram\", \"jac\", \"emb\"]:\n",
    "        df_i = dataset_rows[dataset_rows[\"method\"] == method]\n",
    "        # the first matching row holds the curve arrays for this pair\n",
    "        roc_table = pd.DataFrame({\"fpr\": df_i[\"fpr\"].values[0], \"tpr\": df_i[\"tpr\"].values[0]})\n",
    "        overleaf.add_dataframe(roc_table, f\"data/method/{ds}/roc/{method}.tsv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
