{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import datasets\n",
    "from privacy_estimates.experiments.aml import JobList, DatastoreURI, Job\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "from tempfile import TemporaryDirectory\n",
    "from datasets import load_from_disk\n",
    "from latex import Project\n",
    "from typing import Dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All runs belong to the same Azure ML experiment and workspace, so the URLs\n",
    "# differ only in the run name. Keep the order in sync with the rows of the\n",
    "# metadata table `data`: metrics are later attached to it by position.\n",
    "run_url_template = (\n",
    "    \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/{run}\"\n",
    "    \"?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/PPML\"\n",
    "    \"/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS\"\n",
    "    \"&tid=72f988bf-86f1-41af-91ab-2d7cd011db47\"\n",
    ")\n",
    "run_names = [\n",
    "    # synthetic - sst2\n",
    "    \"yellow_ring_lx2h1d9qzk\",\n",
    "    \"serene_garden_6616pjjqqg\",\n",
    "    \"frosty_chaconia_753kyps8my\",\n",
    "    \"cool_picture_lsyzr3zmkq\",\n",
    "    # model - sst2\n",
    "    \"upbeat_calypso_6tl7xynv7x\",\n",
    "    \"epic_dolphin_1vlnxww2nf\",\n",
    "    \"modest_battery_12qjzvz5qn\",\n",
    "    # synthetic - agnews\n",
    "    \"red_chicken_lnczxkyq9x\",\n",
    "    \"loyal_sprout_kv6mx3gzhv\",\n",
    "    \"lemon_station_tyrcqctd5v\",\n",
    "    \"plum_vase_ckkng3bx5k\",\n",
    "    # model - agnews\n",
    "    \"gifted_camel_nsmpmpk678\",\n",
    "    \"stoic_boot_fqsbh32gtr\",\n",
    "    \"elated_chayote_nxw6ztghb8\",\n",
    "]\n",
    "jobs = JobList.from_urls([run_url_template.format(run=run) for run in run_names])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset whose ROC curves are plotted and exported by the cells below;\n",
    "# the metadata table contains rows for \"sst2\" and \"agnews\".\n",
    "dataset = \"agnews\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One metadata row per job, listed in the same order as the job URLs.\n",
    "# Both datasets were run with the same seven (type, n_rep) settings.\n",
    "settings_per_dataset = [\n",
    "    (\"synthetic\", 2),\n",
    "    (\"synthetic\", 4),\n",
    "    (\"synthetic\", 8),\n",
    "    (\"synthetic\", 16),\n",
    "    (\"model\", 1),\n",
    "    (\"model\", 2),\n",
    "    (\"model\", 4),\n",
    "]\n",
    "data = pd.DataFrame(\n",
    "    [\n",
    "        (attack_type, n_rep, dataset_name)\n",
    "        for dataset_name in (\"sst2\", \"agnews\")\n",
    "        for attack_type, n_rep in settings_per_dataset\n",
    "    ],\n",
    "    columns=[\"type\", \"n_rep\", \"dataset\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_from_uri(uri: DatastoreURI):\n",
    "    \"\"\"Download the content behind ``uri`` and load it with ``load_from_disk``.\n",
    "\n",
    "    The content is fetched into a temporary directory that is deleted as soon\n",
    "    as this function returns. The dataset is therefore copied into memory\n",
    "    rather than left memory-mapped onto files that are about to disappear.\n",
    "\n",
    "    Args:\n",
    "        uri: Datastore location to download; presumably a directory written\n",
    "            by ``save_to_disk`` (not verifiable from this notebook).\n",
    "\n",
    "    Returns:\n",
    "        The object returned by ``datasets.load_from_disk`` for the download.\n",
    "    \"\"\"\n",
    "    with TemporaryDirectory() as tmpdir:\n",
    "        uri.download_content(tmpdir)\n",
    "        # keep_in_memory=True: the backing files vanish together with tmpdir.\n",
    "        return load_from_disk(tmpdir, keep_in_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(job):\n",
    "    \"\"\"Derive ROC-based metrics for a single job.\n",
    "\n",
    "    Loads the ``scores`` and ``challenge_bits`` inputs of the job's\n",
    "    ``estimate_privacy`` node and computes the ROC curve of the scores\n",
    "    against the challenge bits.\n",
    "\n",
    "    Returns:\n",
    "        A dict holding the ``FPR`` and ``TPR`` arrays, the area under the\n",
    "        curve (``AuC``) and ``TPR@FPR=<x>`` for x in 0.01, 0.05 and 0.1,\n",
    "        each obtained by linear interpolation along the ROC curve.\n",
    "    \"\"\"\n",
    "    node_inputs = job.get_node(\"estimate_privacy\").inputs\n",
    "    scores = load_from_uri(node_inputs[\"scores\"])\n",
    "    challenge_bits = load_from_uri(node_inputs[\"challenge_bits\"])\n",
    "\n",
    "    false_pos, true_pos, _thresholds = roc_curve(challenge_bits[\"challenge_bit\"], scores[\"score\"])\n",
    "    tpr_at_fixed_fpr = {\n",
    "        f\"TPR@FPR={target_fpr}\": np.interp(target_fpr, false_pos, true_pos)\n",
    "        for target_fpr in (0.01, 0.05, 0.1)\n",
    "    }\n",
    "    return {\n",
    "        \"FPR\": false_pos,\n",
    "        \"TPR\": true_pos,\n",
    "        \"AuC\": auc(false_pos, true_pos),\n",
    "        **tpr_at_fixed_fpr,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics are attached to the metadata by position, so there must be exactly\n",
    "# one job per metadata row; a mismatch would silently misalign the rows.\n",
    "metrics = pd.DataFrame([compute_metrics(job) for job in jobs])\n",
    "assert len(metrics) == len(data), f\"{len(metrics)} jobs but {len(data)} metadata rows\"\n",
    "# Drop metric columns left over from an earlier run of this cell, so that\n",
    "# re-running it replaces them instead of appending duplicate columns.\n",
    "data = pd.concat([data.drop(columns=metrics.columns, errors=\"ignore\"), metrics], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the combined table: job metadata plus the metrics computed per job.\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ROC curves for the selected dataset: linear axes on the left panel,\n",
    "# log-log axes on the right panel to expose the low-FPR regime.\n",
    "fig, ax = plt.subplots(1, 2, figsize=(12, 6))\n",
    "\n",
    "# Line style per (type, n_rep) combination; the export cell iterates over\n",
    "# the same keys, so the name `plot_options` must stay available.\n",
    "plot_options = {\n",
    "    (\"synthetic\", 2): {\"color\": \"darkred\", \"label\": \"Synth, n_rep=2\"},\n",
    "    (\"synthetic\", 4): {\"color\": \"darkblue\", \"label\": \"Synth, n_rep=4\"},\n",
    "    (\"synthetic\", 8): {\"color\": \"darkgreen\", \"label\": \"Synth, n_rep=8\"},\n",
    "    (\"synthetic\", 16): {\"color\": \"darkorange\", \"label\": \"Synth, n_rep=16\"},\n",
    "    (\"model\", 1): {\"color\": \"lightgreen\", \"label\": \"Model, n_rep=1\"},\n",
    "    (\"model\", 2): {\"color\": \"lightblue\", \"label\": \"Model, n_rep=2\"},\n",
    "    (\"model\", 4): {\"color\": \"lightcoral\", \"label\": \"Model, n_rep=4\"},\n",
    "}\n",
    "\n",
    "# The loop variable is `attack_type`, not `type`, so the builtin `type` is\n",
    "# not shadowed for the rest of the kernel session.\n",
    "for (attack_type, n_rep), options in plot_options.items():\n",
    "    data_i = data[(data[\"n_rep\"] == n_rep) & (data[\"type\"] == attack_type) & (data[\"dataset\"] == dataset)]\n",
    "    fpr = data_i[\"FPR\"].values[0]\n",
    "    tpr = data_i[\"TPR\"].values[0]\n",
    "    for axis in ax:\n",
    "        axis.plot(fpr, tpr, **options)\n",
    "\n",
    "# Settings shared by both panels.\n",
    "for axis in ax:\n",
    "    axis.plot([0, 1], [0, 1], \"--\", color=\"black\", alpha=0.5, label=\"Random guess baseline\")\n",
    "    axis.set_xlabel(\"False positive rate\", fontsize=16)\n",
    "    axis.set_ylabel(\"True positive rate\", fontsize=16)\n",
    "    axis.grid(True, which=\"major\", ls=\"--\", alpha=0.8)\n",
    "\n",
    "# Panel-specific settings: legend and log scales on the right, unit square on the left.\n",
    "ax[1].legend(loc=\"lower right\", fontsize=14)\n",
    "ax[1].set_xscale(\"log\")\n",
    "ax[1].set_yscale(\"log\")\n",
    "ax[0].set_xlim(0, 1)\n",
    "ax[0].set_ylim(0, 1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one TSV file per plotted ROC curve into the LaTeX project; the\n",
    "# project path is presumably read from the LATEX_GIT_PATH environment variable.\n",
    "overleaf = Project.from_env(path_env_name=\"LATEX_GIT_PATH\")\n",
    "# Only the keys of plot_options are needed here; the loop variable is\n",
    "# `attack_type` rather than `type` to avoid shadowing the builtin.\n",
    "for attack_type, n_rep in plot_options:\n",
    "    df_i = data[(data[\"n_rep\"] == n_rep) & (data[\"type\"] == attack_type) & (data[\"dataset\"] == dataset)]\n",
    "    roc = pd.DataFrame({\"fpr\": df_i[\"FPR\"].values[0], \"tpr\": df_i[\"TPR\"].values[0]})\n",
    "    overleaf.add_dataframe(roc, f\"data/n_rep/{dataset}/roc/{attack_type}_{n_rep}.tsv\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
