{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import datasets\n",
    "from privacy_estimates.experiments.aml import JobList\n",
    "from sklearn.metrics import roc_curve, roc_auc_score, auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "\n",
    "def compute_performance(scores_members, scores_non_members, target_fprs=(0.01, 0.05, 0.1)):\n",
    "    \"\"\"Score a membership-inference attack (MIA) from per-example attack scores.\n",
    "\n",
    "    Members are the positive class, so a higher score must mean \"more likely a member\".\n",
    "\n",
    "    Args:\n",
    "        scores_members: iterable of attack scores for the member examples.\n",
    "        scores_non_members: iterable of attack scores for the non-member examples.\n",
    "        target_fprs: false-positive rates at which the TPR is reported; each one adds a\n",
    "            key f'tpr_at_{fpr}'. Defaults to (0.01, 0.05, 0.1).\n",
    "\n",
    "    Returns:\n",
    "        dict with 'auc' (ROC AUC), one 'tpr_at_{fpr}' entry per target FPR (the ROC curve\n",
    "        linearly interpolated at that FPR), and the raw curve as 'fpr' / 'tpr' arrays.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if either score collection is empty (the ROC AUC is undefined).\n",
    "    \"\"\"\n",
    "    member_vals = list(scores_members)\n",
    "    non_member_vals = list(scores_non_members)\n",
    "    if not member_vals or not non_member_vals:\n",
    "        raise ValueError(\n",
    "            f'need at least one member and one non-member score, got '\n",
    "            f'{len(member_vals)} members and {len(non_member_vals)} non-members'\n",
    "        )\n",
    "\n",
    "    # Build labels/scores once and share them between the AUC and the ROC curve.\n",
    "    labels = [1] * len(member_vals) + [0] * len(non_member_vals)\n",
    "    scores = member_vals + non_member_vals\n",
    "\n",
    "    mia_performance = {'auc': roc_auc_score(labels, scores)}\n",
    "    fpr, tpr, _ = roc_curve(labels, scores)\n",
    "    for target_fpr in target_fprs:\n",
    "        mia_performance[f'tpr_at_{target_fpr}'] = np.interp(target_fpr, fpr, tpr)\n",
    "\n",
    "    summary = [f\"AUC: {mia_performance['auc']}\"]\n",
    "    summary += [f\"TPR@{t}: {mia_performance[f'tpr_at_{t}']}\" for t in target_fprs]\n",
    "    print(', '.join(summary))\n",
    "\n",
    "    # also add the full ROC curve\n",
    "    mia_performance['fpr'] = fpr\n",
    "    mia_performance['tpr'] = tpr\n",
    "\n",
    "    return mia_performance\n",
    "\n",
    "def compute_performance_from_url(url, job_name = None):\n",
    "    \"\"\"Fetch an AML privacy-estimation job's MIA scores and challenge bits, then score the attack.\n",
    "\n",
    "    The 'scores' and 'challenge_bits' inputs of the job's 'estimate_privacy' node are\n",
    "    downloaded to ./mia_results/{job_name}/ and loaded with datasets.load_from_disk.\n",
    "    Each input folder that already exists is reused, and AML is only contacted when\n",
    "    at least one of them is missing.\n",
    "\n",
    "    Args:\n",
    "        url: Azure ML run URL of the job.\n",
    "        job_name: name of the local cache folder. Defaults to a timestamp, so calls\n",
    "            without an explicit name always download into a fresh folder.\n",
    "\n",
    "    Returns:\n",
    "        The dict returned by compute_performance (challenge bit 1 = member, 0 = non-member).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the number of scores and challenge bits differ.\n",
    "    \"\"\"\n",
    "    if job_name is None:\n",
    "        # str(datetime.now()) contains ' ' and ':', which are not valid in Windows paths.\n",
    "        job_name = datetime.now().strftime('%Y%m%d-%H%M%S-%f')\n",
    "\n",
    "    scores_dir = f'./mia_results/{job_name}/scores'\n",
    "    bits_dir = f'./mia_results/{job_name}/challenge_bits'\n",
    "\n",
    "    # Check each input separately so a run that stopped after downloading only one of\n",
    "    # them is completed instead of failing in load_from_disk; skip AML when both exist.\n",
    "    missing = [(name, path) for name, path in (('scores', scores_dir), ('challenge_bits', bits_dir))\n",
    "               if not os.path.exists(path)]\n",
    "    if missing:\n",
    "        privacy_node = JobList.from_urls([url])[0].get_node('estimate_privacy')\n",
    "        for input_name, target_dir in missing:\n",
    "            privacy_node.download_input(input_name, target_dir)\n",
    "\n",
    "    scores = datasets.load_from_disk(scores_dir)\n",
    "    bits = datasets.load_from_disk(bits_dir)\n",
    "\n",
    "    membership_scores = np.array([row['score'] for row in scores])\n",
    "    membership_labels = np.array([row['challenge_bit'] for row in bits])\n",
    "    if len(membership_scores) != len(membership_labels):\n",
    "        raise ValueError(\n",
    "            f'{len(membership_scores)} scores but {len(membership_labels)} challenge bits '\n",
    "            f'in ./mia_results/{job_name}'\n",
    "        )\n",
    "    # Scores and challenge bits are paired by position.\n",
    "    members = membership_scores[membership_labels == 1]\n",
    "    non_members = membership_scores[membership_labels == 0]\n",
    "    return compute_performance(members, non_members)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's measure MIA AUC across synthetic multiples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def extract_aml_urls(log_file_path):\n",
    "    \"\"\"Return every Azure ML run URL logged as 'AML URL: <url>' in a text file.\n",
    "\n",
    "    Args:\n",
    "        log_file_path: path of the log file to scan.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: the URLs in file order, taking at most the first match on each\n",
    "        line. Only https://ml.azure.com/runs/... URLs are recognised, and a URL ends\n",
    "        at the first character that is not a word character or one of -?&=/.\n",
    "    \"\"\"\n",
    "    url_regex = re.compile(r'AML URL: (https://ml\\.azure\\.com/runs/[\\w\\-\\?&=/]+)')\n",
    "\n",
    "    with open(log_file_path, 'r') as log_file:\n",
    "        line_matches = (url_regex.search(line) for line in log_file)\n",
    "        return [m.group(1) for m in line_matches if m]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MIA methods and synthetic-data multiples covered by the launched jobs.\n",
    "# Order matters: the AUC cells read `aml_urls` positionally (outer loop over\n",
    "# `multiples`, inner loop over `mia_methods`). NOTE(review): this must match the\n",
    "# order in which the jobs were launched — confirm against the launch script.\n",
    "mia_methods = (\n",
    "    \"jaccard_25\",\n",
    "    \"embedding_25\",\n",
    "    \"ngram_2\",\n",
    ")\n",
    "multiples = (1, 2, 4, 8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (1) Let's first do SST-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET = 'sst2'\n",
    "\n",
    "# Azure ML run URLs listed in this dataset's launch log, in file order\n",
    "aml_urls = extract_aml_urls(\n",
    "    f'../job_launch_outputs/synthetic_multiples_all_mias_{DATASET}.txt'\n",
    ")\n",
    "aml_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect one MIA AUC per synthetic multiple for each method (lists follow the order of `multiples`).\n",
    "# The URLs are consumed positionally: outer loop over `multiples`, inner loop over `mia_methods`.\n",
    "expected_runs = len(multiples) * len(mia_methods)\n",
    "assert len(aml_urls) == expected_runs, (\n",
    "    f'expected {expected_runs} AML URLs for {DATASET}, found {len(aml_urls)}'\n",
    ")\n",
    "\n",
    "method_to_aucs = {method: [] for method in mia_methods}\n",
    "url_iter = iter(aml_urls)\n",
    "for multiple in multiples:\n",
    "    for method in mia_methods:\n",
    "        # The name is the cache folder under ./mia_results/; include the dataset so that\n",
    "        # different datasets never share (and silently reuse) each other's downloads.\n",
    "        # NOTE: the cache is keyed by this name only; delete the folder after relaunching a job.\n",
    "        job_name = f'{DATASET}_multiple_{multiple}_synthetic_{method}'\n",
    "        print(f\"Processing {job_name}\")\n",
    "        performance = compute_performance_from_url(next(url_iter), job_name=job_name)\n",
    "        method_to_aucs[method].append(performance['auc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method_to_aucs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = ['ngram_2', 'jaccard_25', 'embedding_25']\n",
    "labels = ['Synthetic (2-gram)', r'Synthetic ($\\text{SIM}_{jac}$ - $k=25$)', r'Synthetic ($\\text{SIM}_{emb}$ - $k=25$)']\n",
    "colors = ['darkorange', 'darkgreen', 'darkred']\n",
    "# NOTE(review): presumably the MIA AUC against the target model itself; the same value is\n",
    "# hard-coded for every dataset in this notebook — confirm it is right for this one.\n",
    "model_auc = 0.998\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "ax.axhline(y=model_auc, color='darkblue', linestyle='--', linewidth=2, alpha=1, label='Model')\n",
    "for method, label, color in zip(methods, labels, colors):\n",
    "    # x positions come from `multiples` so they always match the length of the AUC lists.\n",
    "    ax.plot(list(multiples), method_to_aucs[method], '-o', alpha=0.8, label=label, color=color,\n",
    "            markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white')\n",
    "ax.set_xscale('log', base=2)\n",
    "ax.set_xlabel(r'Synthetic multiple $m$', fontsize=15)\n",
    "ax.set_ylabel('MIA AUC', fontsize=15)\n",
    "ax.grid(True, which=\"both\", ls=\"--\", alpha=0.8)\n",
    "ax.axhline(y=0.5, color='black', linestyle='--', alpha=1.0, label='Random guess baseline')\n",
    "ax.legend(fontsize=12)\n",
    "ax.set_ylim(0.45, 1.02)\n",
    "\n",
    "os.makedirs('figures', exist_ok=True)  # savefig does not create missing directories\n",
    "fig.savefig(f'figures/synthetic_multiple_{DATASET}.pdf', dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (2) Let's repeat this for AG News"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET = 'agnews'\n",
    "\n",
    "# Azure ML run URLs listed in this dataset's launch log, in file order\n",
    "aml_urls = extract_aml_urls(\n",
    "    f'../job_launch_outputs/synthetic_multiples_all_mias_{DATASET}.txt'\n",
    ")\n",
    "aml_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect one MIA AUC per synthetic multiple for each method (lists follow the order of `multiples`).\n",
    "# The URLs are consumed positionally: outer loop over `multiples`, inner loop over `mia_methods`.\n",
    "expected_runs = len(multiples) * len(mia_methods)\n",
    "assert len(aml_urls) == expected_runs, (\n",
    "    f'expected {expected_runs} AML URLs for {DATASET}, found {len(aml_urls)}'\n",
    ")\n",
    "\n",
    "method_to_aucs = {method: [] for method in mia_methods}\n",
    "url_iter = iter(aml_urls)\n",
    "for multiple in multiples:\n",
    "    for method in mia_methods:\n",
    "        # The name is the cache folder under ./mia_results/; include the dataset so that\n",
    "        # different datasets never share (and silently reuse) each other's downloads.\n",
    "        # NOTE: the cache is keyed by this name only; delete the folder after relaunching a job.\n",
    "        job_name = f'{DATASET}_multiple_{multiple}_synthetic_{method}'\n",
    "        print(f\"Processing {job_name}\")\n",
    "        performance = compute_performance_from_url(next(url_iter), job_name=job_name)\n",
    "        method_to_aucs[method].append(performance['auc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = ['ngram_2', 'jaccard_25', 'embedding_25']\n",
    "labels = ['Synthetic (2-gram)', r'Synthetic ($\\text{SIM}_{jac}$ - $k=25$)', r'Synthetic ($\\text{SIM}_{emb}$ - $k=25$)']\n",
    "colors = ['darkorange', 'darkgreen', 'darkred']\n",
    "# NOTE(review): presumably the MIA AUC against the target model itself; the same value is\n",
    "# hard-coded for every dataset in this notebook — confirm it is right for this one.\n",
    "model_auc = 0.998\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "ax.axhline(y=model_auc, color='darkblue', linestyle='--', linewidth=2, alpha=1, label='Model')\n",
    "for method, label, color in zip(methods, labels, colors):\n",
    "    # x positions come from `multiples` so they always match the length of the AUC lists.\n",
    "    ax.plot(list(multiples), method_to_aucs[method], '-o', alpha=0.8, label=label, color=color,\n",
    "            markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white')\n",
    "ax.set_xscale('log', base=2)\n",
    "ax.set_xlabel(r'Synthetic multiple $m$', fontsize=15)\n",
    "ax.set_ylabel('MIA AUC', fontsize=15)\n",
    "ax.grid(True, which=\"both\", ls=\"--\", alpha=0.8)\n",
    "ax.axhline(y=0.5, color='black', linestyle='--', alpha=1.0, label='Random guess baseline')\n",
    "ax.legend(fontsize=12)\n",
    "ax.set_ylim(0.45, 1.02)\n",
    "\n",
    "os.makedirs('figures', exist_ok=True)  # savefig does not create missing directories\n",
    "fig.savefig(f'figures/synthetic_multiple_{DATASET}.pdf', dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
