{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import datasets\n",
    "from privacy_estimates.experiments.aml import JobList\n",
    "from sklearn.metrics import roc_curve, roc_auc_score, auc\n",
    "import re\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_performance(scores_members, scores_non_members):\n",
    "    mia_performance = {}\n",
    "    \n",
    "    member_vals = [val for val in scores_members]\n",
    "    non_member_vals = [val for val in scores_non_members]\n",
    "    mia_performance['auc'] = roc_auc_score([1]*len(member_vals) + [0]*len(non_member_vals), member_vals + non_member_vals)\n",
    "    fpr, tpr, thresholds = roc_curve([1]*len(member_vals) + [0]*len(non_member_vals), member_vals + non_member_vals)\n",
    "    for target_fpr in (0.01, 0.05, 0.1):\n",
    "        mia_performance[f'tpr_at_{target_fpr}'] = np.interp(target_fpr, fpr, tpr)\n",
    "    print(f\"AUC: {mia_performance['auc']}, TPR@0.01: {mia_performance['tpr_at_0.01']}, TPR@0.05: {mia_performance['tpr_at_0.05']}, TPR@0.1: {mia_performance['tpr_at_0.1']}\")\n",
    "\n",
    "    # also add the curves\n",
    "    mia_performance['fpr'] = fpr\n",
    "    mia_performance['tpr'] = tpr\n",
    "\n",
    "    return mia_performance\n",
    "\n",
    "# let's also make a function that computes the performance directly from the url\n",
    "\n",
    "from datetime import datetime\n",
    "\n",
    "def compute_performance_from_url(url, job_name = None):\n",
    "    jobs = JobList.from_urls([url])\n",
    "    if job_name is None:\n",
    "        job_name = str(datetime.now())\n",
    "    \n",
    "    if not os.path.exists(f'./mia_results/{job_name}'):\n",
    "        test = jobs[0].get_node('estimate_privacy').download_input('scores', f'./mia_results/{job_name}/scores')\n",
    "        test = jobs[0].get_node('estimate_privacy').download_input('challenge_bits', f'./mia_results/{job_name}/challenge_bits')\n",
    "    scores = datasets.load_from_disk(f'./mia_results/{job_name}/scores')\n",
    "    bits = datasets.load_from_disk(f'./mia_results/{job_name}/challenge_bits')\n",
    "    \n",
    "    membership_scores = np.array([k['score'] for k in scores])\n",
    "    membership_labels = np.array([k['challenge_bit'] for k in bits])\n",
    "    members = membership_scores[membership_labels == 1]\n",
    "    non_members = membership_scores[membership_labels == 0]\n",
    "    return compute_performance(members, non_members)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_aml_urls(log_file_path):\n",
    "    # Initialize a list to store the extracted URLs\n",
    "    aml_urls = []\n",
    "\n",
    "    # Define a regular expression to match the log entries containing AML URLs\n",
    "    aml_url_pattern = re.compile(r'AML URL: (https://ml\\.azure\\.com/runs/[\\w\\-\\?&=/]+)')\n",
    "\n",
    "    # Open and read the log file\n",
    "    with open(log_file_path, 'r') as file:\n",
    "        for line in file:\n",
    "            match = aml_url_pattern.search(line)\n",
    "            if match:\n",
    "                aml_urls.append(match.group(1))\n",
    "\n",
    "    return aml_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# first the non synthetic results\n",
    "\n",
    "perplexities = [10, 10**1.5, 10**2, 10**2.5, 10**3, 10**3.5, 10**4, 10**4.5, 10**5]\n",
    "\n",
    "no_synthetic_urls = extract_aml_urls('../job_launch_outputs/no_synthetic_ppl_sst2.txt')\n",
    "synthetic_urls = extract_aml_urls('../job_launch_outputs/synthetic_ppl_sst2.txt')\n",
    "\n",
    "# do the last one too, which was a relaunch with a small fix\n",
    "synthetic_url_last_one = extract_aml_urls('../job_launch_outputs/synthetic_ppl_sst2_lastonerelaunch.txt')\n",
    "synthetic_urls[-1] = synthetic_url_last_one[0]\n",
    "\n",
    "assert len(no_synthetic_urls) == len(perplexities)\n",
    "assert len(synthetic_urls) == len(perplexities)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "no_synthetic_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "synthetic_urls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "non_synthetic_results = dict()\n",
    "synthetic_results = dict()\n",
    "\n",
    "for i in range(len(perplexities)):\n",
    "    print(f\"Perplexity: {perplexities[i]}\")\n",
    "    print(\"Non synthetic\")\n",
    "    try:\n",
    "        performance = compute_performance_from_url(no_synthetic_urls[i])\n",
    "        non_synthetic_results[perplexities[i]] = performance\n",
    "        print('---')\n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "        non_synthetic_results[perplexities[i]] ={'auc': None, 'tpr_at_0.01': None, 'tpr_at_0.05': None, 'tpr_at_0.1': None}\n",
    "        print('---')\n",
    "    \n",
    "    print('Synthetic')\n",
    "    try:\n",
    "        performance = compute_performance_from_url(synthetic_urls[i])\n",
    "        synthetic_results[perplexities[i]] = performance\n",
    "        print('---')\n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "        synthetic_results[perplexities[i]] ={'auc': None, 'tpr_at_0.01': None, 'tpr_at_0.05': None, 'tpr_at_0.1': None}\n",
    "        print('---')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(6,6))\n",
    "plt.plot(perplexities, [non_synthetic_results[p]['auc'] for p in perplexities], '-o', alpha=0.8, \n",
    "             color ='darkblue', markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white', label=r'Model - $n_{rep}=4$')\n",
    "plt.plot(perplexities, [synthetic_results[p]['auc'] for p in perplexities], '-o', alpha=0.8, \n",
    "             color ='darkorange', markersize=10, linewidth=2, markeredgewidth=2, markeredgecolor='white', label=r'Synthetic (2-gram) - $n_{rep}=16$')\n",
    "\n",
    "plt.axhline(y=0.5, color='black', linestyle='--', alpha=0.5, label = 'Random guess baseline')\n",
    "\n",
    "\n",
    "plt.xticks([10**k for k in (0, 1, 2, 3, 4, 5)])\n",
    "plt.yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0], labels=['0.5', '0.6', '0.7', '0.8', '0.9', '1.0'])\n",
    "\n",
    "# Enable the grid\n",
    "plt.grid(True, which=\"major\", ls=\"--\", alpha=0.8)\n",
    "\n",
    "plt.legend(loc='lower left', fontsize=14)\n",
    "plt.xlabel('Canary perplexity', fontsize=16)\n",
    "plt.ylabel('AUC', fontsize=16)\n",
    "plt.ylim(0.4, 1.02)\n",
    "plt.xscale('log')\n",
    "plt.savefig('figures/ppl_experiment_sst2.pdf', dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "df = pd.DataFrame({'ppl': perplexities, 'model': [non_synthetic_results[p]['auc'] for p in perplexities], 'synthetic': [synthetic_results[p]['auc'] for p in perplexities]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_roc = dict()\n",
    "for p in perplexities:\n",
    "    df_roc[f\"perp_{int(p)}_model\"] = pd.DataFrame({\"fpr\": non_synthetic_results[p]['fpr'], \"tpr\": non_synthetic_results[p]['tpr']})\n",
    "    df_roc[f\"perp_{int(p)}_synthetic\"] = pd.DataFrame({\"fpr\": synthetic_results[p]['fpr'], \"tpr\": synthetic_results[p]['tpr']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from latex import Project\n",
    "overleaf = Project(url=\"https://git@git.overleaf.com/667bde737a03ee4008a9359f\")\n",
    "overleaf.push_dataframe(df, 'data/canary_ppl/sst2/auc.tsv')\n",
    "\n",
    "for df_roc_name, df_roc_data in df_roc.items():\n",
    "    overleaf.push_dataframe(df_roc_data, f'data/canary_ppl/sst2/roc/{df_roc_name}.tsv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
