{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import datasets\n",
    "from privacy_estimates.experiments.aml import JobList, DatastoreURI\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "from tempfile import TemporaryDirectory\n",
    "from datasets import load_from_disk\n",
    "from latex import Project"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Azure ML run URLs, keyed by the prefix_length of each group of runs.\n",
    "run_urls = {\n",
    "    0: [\n",
    "        'https://ml.azure.com/runs/careful_curtain_wwgj9l5v6x?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "        'https://ml.azure.com/runs/silver_oxygen_y3crykxm7h?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "        'https://ml.azure.com/runs/quiet_dress_vl262j25my?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "        'https://ml.azure.com/runs/frosty_raisin_0vghsl6hd7?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "        'https://ml.azure.com/runs/honest_rocket_tmsgx803cg?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "        'https://ml.azure.com/runs/boring_pepper_lwvwfd7xz4?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "        'https://ml.azure.com/runs/funny_muscle_nd3qqmzhj1?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "        'https://ml.azure.com/runs/olden_sponge_b74p3t2tbw?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "        'https://ml.azure.com/runs/sharp_crowd_qg1brt2j46?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "    ],\n",
    "    10: [\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/funny_clock_62nz8l14fb?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/busy_malanga_4zygv56497?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/brave_shelf_cfjz5w26fx?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/tough_scooter_rrjfzqph52?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/salmon_root_mmtgk0sj89?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/zen_school_nv4j2tlg73?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/gifted_collar_tr4xlmq33y?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#\",\n",
    "    ],\n",
    "    20: [\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/teal_wing_s78t22zg2l?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/joyful_deer_w6b3rbmmhm?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/blue_sun_spbbf1t57s?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/amusing_library_wq66vflwzj?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
    "    ],\n",
    "    30: [\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/gifted_jewel_w50pc8cd89?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/nice_foot_d6fwd2cxx1?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
    "        \"https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/lemon_battery_ws0704ly2n?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/ppml/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
    "    ],\n",
    "}\n",
    "\n",
    "# Flatten in insertion order so the jobs keep the same ordering as the URL groups above.\n",
    "jobs = JobList.from_urls([url for group in run_urls.values() for url in group])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per job: the environment variables of its \"get_ood_canaries\" node,\n",
    "# with \"AZUREML_PARAMETER_\" removed from the variable names.\n",
    "parameter_rows = []\n",
    "for job in jobs:\n",
    "    env_vars = job.get_node(\"get_ood_canaries\").environment_variables\n",
    "    parameter_rows.append({name.replace(\"AZUREML_PARAMETER_\", \"\"): value for name, value in env_vars.items()})\n",
    "data = pd.DataFrame(parameter_rows)\n",
    "\n",
    "# prefix_length 0 experiments don't have that parameter so set it to 0 if it's missing\n",
    "data[\"prefix_length\"] = data[\"prefix_length\"].fillna(0).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast the parameter columns to numeric dtypes in one step.\n",
    "data = data.astype({\"min_ppl\": float, \"max_ppl\": float, \"prefix_length\": int})\n",
    "# Summarise each job's perplexity range by its midpoint.\n",
    "data[\"ppl\"] = (data[\"min_ppl\"] + data[\"max_ppl\"]) / 2\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_from_uri(uri: DatastoreURI):\n",
    "    \"\"\"Download the content behind ``uri`` and load it with ``load_from_disk``.\n",
    "\n",
    "    The content is downloaded into a temporary directory that is deleted as\n",
    "    soon as this function returns. The dataset is therefore loaded with\n",
    "    ``keep_in_memory=True`` so that the returned object holds its own copy of\n",
    "    the data instead of referring to files that no longer exist.\n",
    "\n",
    "    Args:\n",
    "        uri: Datastore location to download; its content must be loadable\n",
    "            by ``load_from_disk``.\n",
    "\n",
    "    Returns:\n",
    "        The object returned by ``load_from_disk`` for the downloaded content.\n",
    "    \"\"\"\n",
    "    with TemporaryDirectory() as tmpdir:\n",
    "        uri.download_content(tmpdir)\n",
    "        # Copy into memory before the temporary directory is removed on exit.\n",
    "        return load_from_disk(tmpdir, keep_in_memory=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(job):\n",
    "    \"\"\"Compute ROC-based metrics for one job.\n",
    "\n",
    "    Loads the \"scores\" and \"challenge_bits\" inputs of the job's\n",
    "    \"estimate_privacy\" node and returns a dict holding the ROC curve arrays\n",
    "    (\"FPR\", \"TPR\"), the area under that curve (\"AuC\"), and the TPR\n",
    "    interpolated at FPR 0.01, 0.05 and 0.1 (\"TPR@FPR=<value>\").\n",
    "    \"\"\"\n",
    "    node = job.get_node(\"estimate_privacy\")\n",
    "    score_data = load_from_uri(node.inputs[\"scores\"])\n",
    "    bit_data = load_from_uri(node.inputs[\"challenge_bits\"])\n",
    "\n",
    "    false_pos, true_pos, _thresholds = roc_curve(bit_data[\"challenge_bit\"], score_data[\"score\"])\n",
    "    tpr_at_fpr = {\n",
    "        f\"TPR@FPR={level}\": np.interp(level, false_pos, true_pos)\n",
    "        for level in (0.01, 0.05, 0.1)\n",
    "    }\n",
    "    return {\"FPR\": false_pos, \"TPR\": true_pos, \"AuC\": auc(false_pos, true_pos), **tpr_at_fpr}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row of metrics per job, in the same order as `jobs`.\n",
    "metrics = pd.DataFrame([compute_metrics(job) for job in jobs])\n",
    "# Drop metric columns left over from an earlier run of this cell, so that re-running it\n",
    "# does not append duplicate columns (a duplicated \"AuC\" would make data[\"AuC\"] a DataFrame).\n",
    "data = pd.concat([data.drop(columns=metrics.columns, errors=\"ignore\"), metrics], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order the rows by perplexity so the line plots below connect points left to right.\n",
    "data = data.sort_values(by=\"ppl\", ascending=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AUC as a function of canary perplexity, one line per prefix length.\n",
    "plt.figure(figsize=(6,6))\n",
    "plot_options = {\n",
    "    0: {\"color\": \"darkred\", \"label\": \"In-distribution prefix length: 0\"},\n",
    "    10: {\"color\": \"darkblue\", \"label\": \"In-distribution prefix length: 10\"},\n",
    "    20: {\"color\": \"darkgreen\", \"label\": \"In-distribution prefix length: 20\"},\n",
    "    30: {\"color\": \"darkorange\", \"label\": \"In-distribution prefix length: 30\"},\n",
    "}\n",
    "for prefix in plot_options:\n",
    "    rows = data[data[\"prefix_length\"] == prefix]\n",
    "    # Do not bind the name `auc` here: it would overwrite sklearn.metrics.auc,\n",
    "    # which compute_metrics calls, and break any later re-run of that function.\n",
    "    plt.plot(rows[\"ppl\"], rows[\"AuC\"], \"-o\", **plot_options[prefix])\n",
    "\n",
    "plt.axhline(y=0.5, color='black', linestyle='--', alpha=0.5, label = 'Random guess baseline')\n",
    "\n",
    "\n",
    "# NOTE(review): plt.xscale('log') further down may reset these tick positions -\n",
    "# confirm the x ticks render as intended.\n",
    "plt.xticks([10**k for k in (0, 1, 2, 3, 4, 5)])\n",
    "plt.yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0], labels=['0.5', '0.6', '0.7', '0.8', '0.9', '1.0'])\n",
    "\n",
    "# Enable the grid\n",
    "plt.grid(True, which=\"major\", ls=\"--\", alpha=0.8)\n",
    "\n",
    "plt.legend(loc='upper right', fontsize=14)\n",
    "plt.xlabel('Canary perplexity', fontsize=16)\n",
    "plt.ylabel('AUC', fontsize=16)\n",
    "plt.ylim(0.4, 1.02)\n",
    "plt.xscale('log')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each prefix length, a two-column table (ppl, auc) of that prefix length's rows.\n",
    "df_auc = {\n",
    "    f\"prefix_{length}\": pd.DataFrame({\n",
    "        'ppl': data.loc[data['prefix_length'] == length, 'ppl'].values,\n",
    "        'auc': data.loc[data['prefix_length'] == length, 'AuC'].values,\n",
    "    })\n",
    "    for length in (0, 10, 20, 30)\n",
    "}\n",
    "df_auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ROC curves at one target perplexity: linear axes on the left, log-log axes on the right.\n",
    "perp = 31\n",
    "fig, ax = plt.subplots(1, 2, figsize=(12, 6))\n",
    "plot_options = {\n",
    "    0: {\"color\": \"darkred\", \"label\": \"In-distribution prefix length: 0\", \"linestyle\": \"-\"},\n",
    "    10: {\"color\": \"darkblue\", \"label\": \"In-distribution prefix length: 10\", \"linestyle\": \"--\"},\n",
    "    20: {\"color\": \"darkgreen\", \"label\": \"In-distribution prefix length: 20\", \"linestyle\": \"-.\"},\n",
    "    30: {\"color\": \"darkorange\", \"label\": \"In-distribution prefix length: 30\", \"linestyle\": \":\"},\n",
    "}\n",
    "\n",
    "# Rows whose perplexity lies within 10% of the target (bounds inclusive).\n",
    "near_target = data[data[\"ppl\"].between(perp*0.9, perp*1.1)]\n",
    "\n",
    "for prefix, options in plot_options.items():\n",
    "    # Take the first matching row for this prefix length.\n",
    "    roc_row = near_target[near_target[\"prefix_length\"] == prefix].iloc[0]\n",
    "    for axis in ax:\n",
    "        axis.plot(roc_row[\"FPR\"], roc_row[\"TPR\"], **options)\n",
    "\n",
    "# Settings shared by both panels.\n",
    "for axis in ax:\n",
    "    axis.plot([0, 1], [0, 1], \"--\", color=\"black\", alpha=0.5, label=\"Random guess baseline\")\n",
    "    axis.set_xlabel(\"False positive rate\", fontsize=16)\n",
    "    axis.set_ylabel(\"True positive rate\", fontsize=16)\n",
    "    axis.grid(True, which=\"major\", ls=\"--\", alpha=0.8)\n",
    "\n",
    "ax[1].legend(loc=\"lower right\", fontsize=14)\n",
    "ax[1].set_xscale(\"log\")\n",
    "ax[1].set_yscale(\"log\")\n",
    "ax[0].set_xlim(0, 1)\n",
    "ax[0].set_ylim(0, 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows whose perplexity lies within 10% of the target value `perp`.\n",
    "data_ppl = data[(perp*0.9 <= data[\"ppl\"]) & (data[\"ppl\"] <= perp*1.1)]\n",
    "# For each prefix length, the ROC curve (fpr, tpr) of the first matching row.\n",
    "df_roc = {\n",
    "    f\"prefix_{length}\": pd.DataFrame({\n",
    "        'fpr': data_ppl.loc[data_ppl['prefix_length'] == length, 'FPR'].values[0],\n",
    "        'tpr': data_ppl.loc[data_ppl['prefix_length'] == length, 'TPR'].values[0],\n",
    "    })\n",
    "    for length in (0, 10, 20, 30)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "overleaf = Project(url=\"https://git@git.overleaf.com/667bde737a03ee4008a9359f\")\n",
    "\n",
    "# Collect every (table, destination path) pair first: AUC tables, then ROC tables.\n",
    "uploads = [(df, f\"data/prefix/sst2/auc/{name}.tsv\") for name, df in df_auc.items()]\n",
    "uploads += [(df, f\"data/prefix/sst2/roc/{name}.tsv\") for name, df in df_roc.items()]\n",
    "\n",
    "for table, remote_path in uploads:\n",
    "    overleaf.push_dataframe(table, remote_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
