{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from datetime import datetime\n",
    "\n",
    "import datasets\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "from scipy.stats import pearsonr, rankdata\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import roc_curve, roc_auc_score, auc\n",
    "\n",
    "from privacy_estimates.experiments.aml import JobList"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_performance(scores_members, scores_non_members, target_fprs=(0.01, 0.05, 0.1)):\n",
    "    \"\"\"Compute membership-inference (MIA) metrics from member and non-member scores.\n",
    "\n",
    "    Members are the positive class (label 1) and non-members the negative class (label 0),\n",
    "    so a higher score is treated as stronger evidence of membership.\n",
    "\n",
    "    Args:\n",
    "        scores_members: iterable of attack scores for the member challenge points.\n",
    "        scores_non_members: iterable of attack scores for the non-member challenge points.\n",
    "        target_fprs: false-positive rates at which the TPR is reported, linearly\n",
    "            interpolated along the ROC curve. Defaults to (0.01, 0.05, 0.1).\n",
    "\n",
    "    Returns:\n",
    "        dict with 'auc', one 'tpr_at_{fpr}' entry per target FPR, and the full ROC\n",
    "        curve as the 'fpr' and 'tpr' arrays returned by sklearn's roc_curve.\n",
    "    \"\"\"\n",
    "    member_vals = list(scores_members)\n",
    "    non_member_vals = list(scores_non_members)\n",
    "    # Build labels and scores once and reuse them for both the AUC and the ROC curve.\n",
    "    labels = [1] * len(member_vals) + [0] * len(non_member_vals)\n",
    "    all_scores = member_vals + non_member_vals\n",
    "\n",
    "    mia_performance = {'auc': roc_auc_score(labels, all_scores)}\n",
    "    fpr, tpr, _ = roc_curve(labels, all_scores)\n",
    "    for target_fpr in target_fprs:\n",
    "        mia_performance[f'tpr_at_{target_fpr}'] = np.interp(target_fpr, fpr, tpr)\n",
    "\n",
    "    # The summary is derived from target_fprs, so it always matches the keys filled above.\n",
    "    summary = [f\"AUC: {mia_performance['auc']}\"]\n",
    "    summary += [f\"TPR@{t}: {mia_performance[f'tpr_at_{t}']}\" for t in target_fprs]\n",
    "    print(', '.join(summary))\n",
    "\n",
    "    # also add the curves\n",
    "    mia_performance['fpr'] = fpr\n",
    "    mia_performance['tpr'] = tpr\n",
    "\n",
    "    return mia_performance\n",
    "\n",
    "def compute_performance_from_url(url, job_name = None):\n",
    "    \"\"\"Compute the MIA metrics of an Azure ML job, downloading its scores if needed.\n",
    "\n",
    "    The 'scores' and 'challenge_bits' inputs of the job's 'estimate_privacy' node are\n",
    "    cached in ./mia_results/<job_name>/ and loaded with datasets.load_from_disk.\n",
    "\n",
    "    Args:\n",
    "        url: Azure ML run URL of the job.\n",
    "        job_name: name of the local cache directory. Defaults to the current timestamp,\n",
    "            which means a fresh download on every call.\n",
    "\n",
    "    Returns:\n",
    "        The metrics dict produced by compute_performance (challenge bit 1 = member).\n",
    "    \"\"\"\n",
    "    if job_name is None:\n",
    "        # No spaces or ':' (unlike str(datetime.now())), so the name is a valid directory on every OS.\n",
    "        job_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')\n",
    "    result_dir = os.path.join('.', 'mia_results', job_name)\n",
    "\n",
    "    # Check each input separately: a partially populated cache is completed instead of\n",
    "    # failing in load_from_disk, and AML is only contacted when something is missing.\n",
    "    estimate_privacy_node = None\n",
    "    for input_name in ('scores', 'challenge_bits'):\n",
    "        input_path = os.path.join(result_dir, input_name)\n",
    "        if not os.path.exists(input_path):\n",
    "            if estimate_privacy_node is None:\n",
    "                estimate_privacy_node = JobList.from_urls([url])[0].get_node('estimate_privacy')\n",
    "            estimate_privacy_node.download_input(input_name, input_path)\n",
    "    scores = datasets.load_from_disk(os.path.join(result_dir, 'scores'))\n",
    "    bits = datasets.load_from_disk(os.path.join(result_dir, 'challenge_bits'))\n",
    "\n",
    "    membership_scores = np.array([k['score'] for k in scores])\n",
    "    membership_labels = np.array([k['challenge_bit'] for k in bits])\n",
    "    members = membership_scores[membership_labels == 1]\n",
    "    non_members = membership_scores[membership_labels == 0]\n",
    "    return compute_performance(members, non_members)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this is all for syntheticcanary_uniformlabel with n_rep = 12\n",
    "# All runs live in the same Azure ML workspace, so only the run name differs between the URLs.\n",
    "AML_RUN_URL_TEMPLATE = 'https://ml.azure.com/runs/{}?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47'\n",
    "\n",
    "# all_urls[dataset][attack] -> Azure ML run URL\n",
    "all_urls = {\n",
    "    'sst2': {\n",
    "        'no_synthetic': AML_RUN_URL_TEMPLATE.format('amiable_roof_qs9w099y2f'),\n",
    "        'synthetic_2gram': AML_RUN_URL_TEMPLATE.format('quirky_battery_tvw7wcm4wh'),\n",
    "    },\n",
    "    'agnews': {\n",
    "        'no_synthetic': AML_RUN_URL_TEMPLATE.format('blue_whale_w9k49kxdvn'),\n",
    "        'synthetic_2gram': AML_RUN_URL_TEMPLATE.format('quirky_insect_7scfdk028l'),\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PICK THE DATASET OF CHOICE\n",
    "DATASET = 'agnews'\n",
    "# Both names hold Azure ML run URLs (not job names): the model-loss attack and the synthetic 2-gram attack.\n",
    "no_synthetic_job_name, synthetic_job_name = (\n",
    "    all_urls[DATASET][attack] for attack in ('no_synthetic', 'synthetic_2gram')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve both runs with one call; jobs[0] is used as the non-synthetic run and jobs[1] as the synthetic run throughout the notebook.\n",
    "jobs = JobList.from_urls([no_synthetic_job_name, synthetic_job_name])\n",
    "no_synthetic_job, synthetic_job = jobs[0], jobs[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Clean up before proceeding.**\n",
    "\n",
    "Before running the cells below, run `./clean_for_scatter.sh` from the notebooks directory. It makes sure nothing remains from the previous run, so that everything is downloaded again for the currently selected jobs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (1) Let's first make sure the canaries are the same!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For both jobs, download the canaries and the 'in_indices' input of the final model's\n",
    "# data_for_model node, then load everything from disk.\n",
    "for job, suffix in ((jobs[0], 'no_synthetic'), (jobs[1], 'synthetic')):\n",
    "    job.get_node('add_index_to_dataset_2').get_node('append_column_incrementing').download_input('data', f'canaries_{suffix}')\n",
    "    job.get_node('train_many_models').get_node('train_final_model_group').get_node('train_model_and_predict').get_node('data_for_model').download_input('in_indices', f'in_indices_{suffix}')\n",
    "\n",
    "canaries_no_synthetic = datasets.load_from_disk('canaries_no_synthetic')\n",
    "in_indices_no_synthetic = datasets.load_from_disk('in_indices_no_synthetic')\n",
    "canaries_synthetic = datasets.load_from_disk('canaries_synthetic')\n",
    "in_indices_synthetic = datasets.load_from_disk('in_indices_synthetic')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check over the full datasets that both jobs used the same canaries and in-indices,\n",
    "# then print a small sample side by side for visual inspection.\n",
    "for name, rows_no_synthetic, rows_synthetic in (\n",
    "    ('canaries', canaries_no_synthetic, canaries_synthetic),\n",
    "    ('in indices', in_indices_no_synthetic, in_indices_synthetic),\n",
    "):\n",
    "    n_mismatches = sum(row_a != row_b for row_a, row_b in zip(rows_no_synthetic, rows_synthetic))\n",
    "    print(f'{name}: {len(rows_no_synthetic)} vs {len(rows_synthetic)} rows, {n_mismatches} mismatching rows')\n",
    "\n",
    "# Clamp the sample window so a smaller dataset does not raise an IndexError.\n",
    "sample_stop = min(300, len(canaries_no_synthetic), len(canaries_synthetic),\n",
    "                  len(in_indices_no_synthetic), len(in_indices_synthetic))\n",
    "for i in range(200, sample_stop):\n",
    "    print(f\"Canary {i} no synthetic: {canaries_no_synthetic[i]}\")\n",
    "    print(f\"Canary {i} synthetic: {canaries_synthetic[i]}\")\n",
    "    print(f\"Index {i} no synthetic: {in_indices_no_synthetic[i]}\")\n",
    "    print(f\"Index {i} synthetic: {in_indices_synthetic[i]}\")\n",
    "    print('---')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Also download the 'in_out_data' input of the final model's data_for_model node for both\n",
    "# jobs; we expect it to be identical across the two jobs.\n",
    "for job, suffix in ((jobs[0], 'no_synthetic'), (jobs[1], 'synthetic')):\n",
    "    job.get_node('train_many_models').get_node('train_final_model_group').get_node('train_model_and_predict').get_node('data_for_model').download_input('in_out_data', f'in_out_data_{suffix}')\n",
    "\n",
    "in_out_data_no_synthetic = datasets.load_from_disk('in_out_data_no_synthetic')\n",
    "in_out_data_synthetic = datasets.load_from_disk('in_out_data_synthetic')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify that the in/out data really is identical across the two jobs, then print a few rows.\n",
    "n_mismatches = sum(row_a != row_b for row_a, row_b in zip(in_out_data_no_synthetic, in_out_data_synthetic))\n",
    "print(f'in out data: {len(in_out_data_no_synthetic)} vs {len(in_out_data_synthetic)} rows, {n_mismatches} mismatching rows')\n",
    "\n",
    "for i in range(20, min(30, len(in_out_data_no_synthetic), len(in_out_data_synthetic))):\n",
    "    print(f\"In out data {i} no synthetic: {in_out_data_no_synthetic[i]}\")\n",
    "    print(f\"In out data {i} synthetic: {in_out_data_synthetic[i]}\")\n",
    "    print('---')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (2) Ok nice, let's move to the scores\n",
    "\n",
    "Let's first make sure we get the right MIA scores, ROC curves and AUCs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For both jobs, download the MIA scores and challenge bits fed to the estimate_privacy node,\n",
    "# then load them from disk (challenge bit 1 marks a member, 0 a non-member).\n",
    "for job, suffix in ((jobs[0], 'no_synthetic'), (jobs[1], 'synthetic')):\n",
    "    estimate_privacy_node = job.get_node('estimate_privacy')\n",
    "    estimate_privacy_node.download_input('scores', f'./mia_results/scores_{suffix}/scores')\n",
    "    estimate_privacy_node.download_input('challenge_bits', f'./mia_results/challenge_bits_{suffix}/challenge_bits')\n",
    "\n",
    "scores_no_synthetic = datasets.load_from_disk('./mia_results/scores_no_synthetic/scores')\n",
    "challenge_bits_no_synthetic = datasets.load_from_disk('./mia_results/challenge_bits_no_synthetic/challenge_bits')\n",
    "scores_synthetic = datasets.load_from_disk('./mia_results/scores_synthetic/scores')\n",
    "challenge_bits_synthetic = datasets.load_from_disk('./mia_results/challenge_bits_synthetic/challenge_bits')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify that ALL challenge bits agree between the two jobs (not just a prefix): the\n",
    "# comparisons below pair the two attacks' member / non-member scores by position.\n",
    "assert len(challenge_bits_no_synthetic) == len(challenge_bits_synthetic), 'challenge bits differ in length'\n",
    "for i, (bit_no_synthetic, bit_synthetic) in enumerate(zip(challenge_bits_no_synthetic, challenge_bits_synthetic)):\n",
    "    assert bit_no_synthetic == bit_synthetic, f'challenge bit {i} differs: {bit_no_synthetic} vs {bit_synthetic}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's compare the roc curves again, should match with what we have in the other figures\n",
    "# Each attack's ROC curve is drawn on both panels: linear axes (left) and log-log axes (right).\n",
    "fig, axes = plt.subplots(1, 2, figsize = (10, 5))\n",
    "\n",
    "roc_aucs = []\n",
    "for header, attack_scores, attack_bits, color, legend_name in (\n",
    "    ('For non synthetic we get: ', scores_no_synthetic, challenge_bits_no_synthetic, 'darkorange', 'Model (loss)'),\n",
    "    ('For synthetic we get: ', scores_synthetic, challenge_bits_synthetic, 'darkgreen', 'MIA Synthetic - 2-gram loss'),\n",
    "):\n",
    "    print(header)\n",
    "    membership_scores = np.array([k['score'] for k in attack_scores])\n",
    "    membership_labels = np.array([k['challenge_bit'] for k in attack_bits])\n",
    "    members = membership_scores[membership_labels == 1]\n",
    "    non_members = membership_scores[membership_labels == 0]\n",
    "\n",
    "    performance = compute_performance(members, non_members)\n",
    "    fpr, tpr = performance['fpr'], performance['tpr']\n",
    "    roc_aucs.append(auc(fpr, tpr))\n",
    "    curve_label = '%s (AUC = %0.2f)' % (legend_name, roc_aucs[-1])\n",
    "    for ax in axes:\n",
    "        ax.plot(fpr, tpr, color=color, lw=2, label=curve_label)\n",
    "roc_auc_no_synthetic, roc_auc_synthetic = roc_aucs\n",
    "\n",
    "for ax in axes:\n",
    "    ax.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--', label='Random guess baseline')\n",
    "\n",
    "linear_ax, log_ax = axes\n",
    "\n",
    "# left panel: linear axes\n",
    "linear_ax.set_xlim([0.0, 1.0])\n",
    "linear_ax.set_ylim([0.0, 1.05])\n",
    "linear_ax.set_xlabel('False Positive Rate')\n",
    "linear_ax.set_ylabel('True Positive Rate')\n",
    "\n",
    "# right panel: log-log axes\n",
    "log_ax.set_xscale('log')\n",
    "log_ax.set_yscale('log')\n",
    "log_ax.set_xlim([1e-3, 1e0])\n",
    "log_ax.set_xlabel('False Positive Rate (log)')\n",
    "log_ax.set_ylabel('True Positive Rate (log)')\n",
    "\n",
    "log_ax.legend(loc='lower right')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the two attacks at equal FPR: interpolate both ROC curves on a common FPR grid\n",
    "# and plot the non-synthetic TPR (x) against the synthetic TPR (y). Points above the\n",
    "# y=x line are FPRs at which the synthetic 2-gram attack reaches a higher TPR.\n",
    "fpr_grid = np.linspace(0, 1, 500)\n",
    "tpr_on_grid = {}\n",
    "for suffix, attack_scores, attack_bits in (\n",
    "    ('no_synthetic', scores_no_synthetic, challenge_bits_no_synthetic),\n",
    "    ('synthetic', scores_synthetic, challenge_bits_synthetic),\n",
    "):\n",
    "    membership_scores = np.array([k['score'] for k in attack_scores])\n",
    "    membership_labels = np.array([k['challenge_bit'] for k in attack_bits])\n",
    "    attack_fpr, attack_tpr, _ = roc_curve(membership_labels, membership_scores)\n",
    "    # np.interp evaluates the whole grid at once, no Python loop needed.\n",
    "    tpr_on_grid[suffix] = np.interp(fpr_grid, attack_fpr, attack_tpr)\n",
    "vals_no_synthetic = tpr_on_grid['no_synthetic']\n",
    "vals_synthetic = tpr_on_grid['synthetic']\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize = (5, 5))\n",
    "ax.plot(vals_no_synthetic, vals_synthetic, lw=2)\n",
    "ax.set_xlabel('Non synthetic TPR at FPR')\n",
    "ax.set_ylabel('Synthetic TPR at FPR')\n",
    "ax.set_xlim([0.0, 1.0])\n",
    "ax.set_ylim([0.0, 1.0])\n",
    "ax.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--', label = 'y=x')\n",
    "ax.set_title('TPR of both attacks at equal FPR')\n",
    "ax.legend(loc='lower right')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (3) Let's compare the RMIA scores across attacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter the log RMIA scores of the two attacks against each other (one point per\n",
    "# challenge point), with marginal histograms and a linear fit through all points.\n",
    "alpha = 0.7\n",
    "s = 18\n",
    "\n",
    "# Prepare the data\n",
    "# NOTE(review): np.log assumes strictly positive scores; a zero/negative score gives\n",
    "# -inf/nan, which would make LinearRegression.fit below raise.\n",
    "membership_scores_no_synthetic = np.array([np.log(k['score']) for k in scores_no_synthetic])\n",
    "membership_labels_no_synthetic = np.array([k['challenge_bit'] for k in challenge_bits_no_synthetic])\n",
    "members_no_synthetic = membership_scores_no_synthetic[membership_labels_no_synthetic == 1]\n",
    "non_members_no_synthetic = membership_scores_no_synthetic[membership_labels_no_synthetic == 0]\n",
    "\n",
    "membership_scores_synthetic = np.array([np.log(k['score']) for k in scores_synthetic])\n",
    "membership_labels_synthetic = np.array([k['challenge_bit'] for k in challenge_bits_synthetic])\n",
    "members_synthetic = membership_scores_synthetic[membership_labels_synthetic == 1]\n",
    "non_members_synthetic = membership_scores_synthetic[membership_labels_synthetic == 0]\n",
    "\n",
    "# Create the main plot\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "\n",
    "# Plot the members and non-members\n",
    "scatter_members = ax.scatter(members_no_synthetic, members_synthetic, s=s, alpha=alpha,\n",
    "                             color='darkcyan', label='Members')\n",
    "scatter_non_members = ax.scatter(non_members_no_synthetic, non_members_synthetic, s=s, alpha=alpha,\n",
    "                                 color='slategrey', label='Non-members')\n",
    "\n",
    "# Set labels\n",
    "ax.set_xlabel('RMIA scores (log) - Model - AUC=%0.3f' % roc_auc_no_synthetic, fontsize = 12)\n",
    "ax.set_ylabel('RMIA scores (log) - Synthetic (2-gram) AUC=%0.3f' % roc_auc_synthetic, fontsize = 12)\n",
    "\n",
    "ax.grid()\n",
    "\n",
    "# Create inset axes for histograms\n",
    "divider = make_axes_locatable(ax)\n",
    "ax_hist_x = divider.append_axes(\"top\", 1.2, pad=0.1, sharex=ax)\n",
    "ax_hist_y = divider.append_axes(\"right\", 1.2, pad=0.1, sharey=ax)\n",
    "\n",
    "# Marginal histograms (no synthetic on top, synthetic on the right), with bins shared by\n",
    "# the members and non-members of the same attack\n",
    "for hist_ax, orientation, attack_members, attack_non_members in (\n",
    "    (ax_hist_x, 'vertical', members_no_synthetic, non_members_no_synthetic),\n",
    "    (ax_hist_y, 'horizontal', members_synthetic, non_members_synthetic),\n",
    "):\n",
    "    bins = np.histogram(np.hstack((attack_members, attack_non_members)), bins=50)[1]\n",
    "    hist_ax.hist(attack_members, bins=bins, orientation=orientation, color='darkcyan', alpha=alpha)\n",
    "    hist_ax.hist(attack_non_members, bins=bins, orientation=orientation, color='slategrey', alpha=alpha)\n",
    "    hist_ax.grid()\n",
    "\n",
    "# also fit a linear regression through all scatter points\n",
    "lr = LinearRegression()\n",
    "lr.fit(membership_scores_no_synthetic.reshape(-1, 1), membership_scores_synthetic.reshape(-1, 1))\n",
    "# evaluate the fit over the full x range of the scatter\n",
    "x = np.linspace(membership_scores_no_synthetic.min(), membership_scores_no_synthetic.max(), 100).reshape(-1, 1)\n",
    "y = lr.predict(x)\n",
    "# also compute the pearson correlation and its pvalue\n",
    "corr, pval = pearsonr(membership_scores_no_synthetic, membership_scores_synthetic)\n",
    "print(f'Pearson correlation coefficient: {corr}')\n",
    "print(f'p-value: {pval}')\n",
    "ax.plot(x, y, color='black', lw=2, linestyle='--', label='Correlation=%0.3f' % corr)\n",
    "\n",
    "# Hide tick labels on the inset axes\n",
    "ax_hist_x.tick_params(axis=\"x\", labelbottom=False)\n",
    "ax_hist_y.tick_params(axis=\"y\", labelleft=False)\n",
    "\n",
    "ax.legend(loc=\"lower left\", fontsize = 13)\n",
    "\n",
    "# Show the plot\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing directories, so make sure the output folder exists.\n",
    "os.makedirs('figures', exist_ok=True)\n",
    "plt.savefig(f'figures/scatter_{DATASET}_syntheticcanary_uniformlabel.pdf', dpi=300, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-point quantile plot: replace each attack's scores by their empirical quantile\n",
    "# (rank / number of scores), so both attacks are compared on the same (0, 1] scale\n",
    "# regardless of how their raw scores are distributed. Unlike a classic Q-Q plot, the\n",
    "# points are paired by challenge point, not by sorted order.\n",
    "alpha = 0.7\n",
    "s = 15\n",
    "\n",
    "# Prepare the data ('average' ranks give tied scores the same quantile)\n",
    "membership_scores_no_synthetic = np.array([k['score'] for k in scores_no_synthetic])\n",
    "ranks_no_synthetic = rankdata(membership_scores_no_synthetic, method='average')\n",
    "quantiles_no_synthetic = ranks_no_synthetic / len(membership_scores_no_synthetic)\n",
    "\n",
    "membership_scores_synthetic = np.array([k['score'] for k in scores_synthetic])\n",
    "ranks_synthetic = rankdata(membership_scores_synthetic, method='average')\n",
    "quantiles_synthetic = ranks_synthetic / len(membership_scores_synthetic)\n",
    "\n",
    "membership_labels_no_synthetic = np.array([k['challenge_bit'] for k in challenge_bits_no_synthetic])\n",
    "members_no_synthetic = quantiles_no_synthetic[membership_labels_no_synthetic == 1]\n",
    "non_members_no_synthetic = quantiles_no_synthetic[membership_labels_no_synthetic == 0]\n",
    "\n",
    "membership_labels_synthetic = np.array([k['challenge_bit'] for k in challenge_bits_synthetic])\n",
    "members_synthetic = quantiles_synthetic[membership_labels_synthetic == 1]\n",
    "non_members_synthetic = quantiles_synthetic[membership_labels_synthetic == 0]\n",
    "\n",
    "# Create the main plot\n",
    "fig, ax = plt.subplots(figsize=(5, 5))\n",
    "\n",
    "# Plot the members and non-members\n",
    "scatter_members = ax.scatter(members_no_synthetic, members_synthetic, s=s, alpha=alpha,\n",
    "                             color='darkcyan', label='Members')\n",
    "scatter_non_members = ax.scatter(non_members_no_synthetic, non_members_synthetic, s=s, alpha=alpha,\n",
    "                                 color='slategrey', label='Non-members')\n",
    "\n",
    "# Set labels: the axes show score quantiles, not raw RMIA scores\n",
    "ax.set_xlabel('RMIA score quantile - Model (loss) - AUC=%0.2f' % roc_auc_no_synthetic, fontsize = 12)\n",
    "ax.set_ylabel('RMIA score quantile - synthetic (2-gram loss) AUC=%0.2f' % roc_auc_synthetic, fontsize = 12)\n",
    "\n",
    "ax.grid()\n",
    "\n",
    "ax.legend(loc=\"lower right\", fontsize = 12)\n",
    "\n",
    "# Show the plot\n",
    "ax.set_title('Score quantiles per challenge point', fontsize=12)\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
