{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5a227e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy import optimize as opt\n",
    "import os\n",
    "from pathlib import Path\n",
    "import warnings\n",
    "import pickle\n",
    "import re\n",
    "from tqdm import tqdm\n",
    "from functools import partial\n",
    "\n",
    "from mbi import Dataset, Domain, LinearMeasurement, estimation\n",
    "\n",
    "import src.mst as mst\n",
    "import src.brownian_mechanism as bm\n",
    "from src.basic_mechanisms import gaussian_mechanism_rdp_variance\n",
    "from src.measurement_wrapper import MeasurementWrapper\n",
    "\n",
    "from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.base import clone\n",
    "from sklearn.metrics import brier_score_loss\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tueplots import figsizes, fontsizes, fonts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b91fc4f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All experiment artefacts (per-repeat results, synthetic data, selected cliques) live under one base directory.\n",
    "result_base_path = \"results/accuracy-first/adult\"\n",
    "syn_df_base_path = f\"{result_base_path}/synthetic-data\"\n",
    "cliques_file_path = f\"{result_base_path}/cliques.p\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34506c3f",
   "metadata": {},
   "source": [
    "# Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8bbb600",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessed Adult census data; the `income` column is the prediction target used throughout.\n",
    "adult_csv_path = Path(\"datasets\") / \"adult-preprocessed.csv\"\n",
    "df = pd.read_csv(adult_csv_path, index_col=False)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9050bf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def X_y_split(df):\n",
    "    \"\"\"Split `df` into (features, target): every column except `income`, and the `income` column.\"\"\"\n",
    "    target_col = \"income\"\n",
    "    features = df.drop(columns=[target_col])\n",
    "    return features, df[target_col]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "506e698e",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(54787529)\n",
    "\n",
    "# With random_state=None, train_test_split draws from NumPy's global RNG, so the seed above fixes the split.\n",
    "# Resulting proportions: 40% train, 20% test, 40% validation (only the second split is stratified on income).\n",
    "train_df, test_validation_df = train_test_split(df, test_size=0.6)\n",
    "test_df, validation_df = train_test_split(test_validation_df, test_size=2/3, stratify=test_validation_df[\"income\"])\n",
    "\n",
    "# Domain size of each column = number of distinct values in the full data\n",
    "# (presumably columns are integer-coded 0..k-1 - TODO confirm against the preprocessing).\n",
    "domain_sizes = {col_name: df[col_name].nunique() for col_name in df.columns}\n",
    "domain = Domain.fromdict(domain_sizes)\n",
    "real_dataset = Dataset(train_df, domain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14bcf6b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Real-data feature/target splits, used by the non-synthetic baseline and for test-set scoring.\n",
    "(train_X, train_y), (test_X, test_y) = X_y_split(train_df), X_y_split(test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "802a44d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_cliques(filename, cliques):\n",
    "    \"\"\"Pickle the selected marginal cliques to `filename`.\"\"\"\n",
    "    with open(filename, \"wb\") as file:\n",
    "        pickle.dump(cliques, file)\n",
    "\n",
    "def load_cliques(filename):\n",
    "    \"\"\"Load cliques previously written by `save_cliques`.\n",
    "\n",
    "    NOTE: pickle.load can execute arbitrary code; only load files produced by this notebook.\n",
    "    \"\"\"\n",
    "    with open(filename, \"rb\") as file:\n",
    "        # Unpickle from the open file object (pickle.load does not accept a path).\n",
    "        return pickle.load(file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f838187",
   "metadata": {},
   "source": [
    "# Privacy Calculations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bffc523b",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 20.0\n",
    "epsilon_query_selection = 0.01\n",
    "epsilon_generation_list = list(np.logspace(-2, 0, 7)) # 0.01 to 1\n",
    "epsilon_accuracy_checking = 0.01\n",
    "accuracy_threshold = 0.825\n",
    "\n",
    "# Validation accuracy is a mean over validation_df, so changing one example moves it by at most 1 / n.\n",
    "utility_sensitivity = 1 / validation_df.shape[0]\n",
    "\n",
    "def accuracy_check_noise_variance(split):\n",
    "    \"\"\"Noise level of the SVT (Gaussian + Laplace) accuracy check for a budget split in (0, 1).\n",
    "\n",
    "    Returns sqrt(sigma_1**2 + 2 * b**2). Assuming sigma_1 is a Gaussian std and b a Laplace scale\n",
    "    (variance 2 * b**2), this is the standard deviation of the combined noise; its minimiser is the\n",
    "    same as that of the variance.\n",
    "    \"\"\"\n",
    "    sigma_1, b = bm.calculate_mechanism_scales_for_svt_gauss_laplace(epsilon_accuracy_checking, split, alpha, utility_sensitivity)\n",
    "    return (sigma_1**2 + 2 * b**2)**0.5\n",
    "\n",
    "# Request the bounded solver explicitly so `split` is only ever evaluated inside (0, 1).\n",
    "res = opt.minimize_scalar(accuracy_check_noise_variance, bounds=(0, 1), method=\"bounded\")\n",
    "\n",
    "accuracy_checking_variance_split = res.x # type: ignore\n",
    "print(f\"Accuracy Checking Variance Split: {accuracy_checking_variance_split}\")\n",
    "# res.fun is a standard deviation (see the docstring above); report it together with the variance.\n",
    "print(f\"Accuracy Checking Noise Std: {res.fun}\")\n",
    "print(f\"Accuracy Checking Variance: {res.fun**2}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a1b0025",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Noise scale for the plain-Gaussian (no SVT) accuracy check; the third argument is presumably the\n",
    "# maximum number of checks, i.e. one per entry of epsilon_generation_list.\n",
    "n_accuracy_checks = len(epsilon_generation_list)\n",
    "bm.calculate_mechanism_scale_no_svt(epsilon_accuracy_checking, utility_sensitivity, n_accuracy_checks, alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "104d681d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rdp_to_adp_conversion(rdp_eps, alpha, delta):\n",
    "    \"\"\"Convert an (alpha, rdp_eps)-RDP guarantee into an (eps, delta)-DP epsilon.\n",
    "\n",
    "    Uses eps = rdp_eps + log(1 / delta) / (alpha - 1); works elementwise on arrays and Series.\n",
    "    \"\"\"\n",
    "    return rdp_eps + np.log(1 / delta) / (alpha - 1)\n",
    "\n",
    "rdp_eps_col = \"RDP Epsilon @ $\\\\alpha = 20$\"\n",
    "eps_conversion_df = pd.DataFrame({rdp_eps_col: epsilon_generation_list})\n",
    "eps_conversion_df[\"ADP Epsilon @ $\\\\delta = 10^{-5}$\"] = rdp_to_adp_conversion(eps_conversion_df[rdp_eps_col], alpha, delta=1e-5)\n",
    "\n",
    "# The column headers contain LaTeX math, so disable escaping explicitly (pandas < 2.0 escaped by\n",
    "# default, which mangles the dollar signs and backslashes). Create the output directory if missing.\n",
    "os.makedirs(\"figures/adult\", exist_ok=True)\n",
    "eps_conversion_df.to_latex(\"figures/adult/epsilon_conversion_table.tex\", index=False, escape=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "930af96d",
   "metadata": {},
   "source": [
    "# Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba950089",
   "metadata": {},
   "outputs": [],
   "source": [
    "class IntermediateResultTracker:\n",
    "    \"\"\"Records every synthetic-data release produced while running one variant for one repeat.\n",
    "\n",
    "    `add_result` appends one entry to each of seven parallel lists, so index i of every list\n",
    "    describes release i. `save_results` persists only syn_dfs, the accuracies, epsilons and\n",
    "    noise variances; query releases and trained models are kept in memory only.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.syn_dfs = []\n",
    "        self.query_releases = []\n",
    "        self.trained_models = []\n",
    "        self.validation_accuracies = []\n",
    "        self.test_accuracies = []\n",
    "        self.epsilons = []\n",
    "        self.noise_variances = []\n",
    "\n",
    "    def add_result(self, syn_df, query_release, trained_model, validation_accuracy, test_accuracy, epsilon, noise_variance):\n",
    "        \"\"\"Append one release's artefacts and scores to the parallel lists.\"\"\"\n",
    "        self.syn_dfs.append(syn_df)\n",
    "        self.query_releases.append(query_release)\n",
    "        self.trained_models.append(trained_model)\n",
    "        self.validation_accuracies.append(validation_accuracy)\n",
    "        self.test_accuracies.append(test_accuracy)\n",
    "        self.epsilons.append(epsilon)\n",
    "        self.noise_variances.append(noise_variance)\n",
    "\n",
    "    def save_results(self, filename):\n",
    "        \"\"\"Pickle the persistable fields to `filename` (query releases and models are not saved).\"\"\"\n",
    "        with open(filename, \"wb\") as file:\n",
    "            pickle.dump({\n",
    "                \"syn_dfs\": self.syn_dfs,\n",
    "                \"validation_accuracies\": self.validation_accuracies,\n",
    "                \"test_accuracies\": self.test_accuracies,\n",
    "                \"epsilons\": self.epsilons,\n",
    "                \"noise_variances\": self.noise_variances,\n",
    "            }, file)\n",
    "\n",
    "    @staticmethod\n",
    "    def load_results(filename):\n",
    "        \"\"\"Rebuild a tracker from a file written by `save_results`.\n",
    "\n",
    "        query_releases and trained_models are not stored, so they come back as empty lists.\n",
    "        NOTE: pickle.load can execute arbitrary code; only load files produced by this notebook.\n",
    "        \"\"\"\n",
    "        with open(filename, \"rb\") as file:\n",
    "            obj = pickle.load(file)\n",
    "\n",
    "        res = IntermediateResultTracker()\n",
    "        res.syn_dfs = obj[\"syn_dfs\"]\n",
    "        res.validation_accuracies = obj[\"validation_accuracies\"]\n",
    "        res.test_accuracies = obj[\"test_accuracies\"]\n",
    "        res.epsilons = obj[\"epsilons\"]\n",
    "        res.noise_variances = obj[\"noise_variances\"]\n",
    "        return res\n",
    "\n",
    "\n",
    "def generate_and_evaluate_synthetic_data(\n",
    "        release_vector: np.ndarray, validation_data: pd.DataFrame,\n",
    "        release_variance: float, epsilon: float, test_data: pd.DataFrame, measurement_wrapper, log1, compressed_domain,\n",
    "        undo_compress_fn, intermediate_result_tracker, pgm_iters: int = 10000,\n",
    "        ) -> float:\n",
    "    \"\"\"Fit a PGM to one noisy release, sample synthetic data and score a classifier trained on it.\n",
    "\n",
    "    Used as the utility function of the Brownian-mechanism runs, with everything from `test_data`\n",
    "    onwards bound via functools.partial.\n",
    "\n",
    "    Args:\n",
    "        release_vector: noisy values of the selected marginal queries, presumably in the layout of\n",
    "            `measurement_wrapper.get_concatenated_values()`.\n",
    "        validation_data: DataFrame with an `income` column; its accuracy is the returned utility.\n",
    "        release_variance: noise variance of `release_vector`; its square root is the measurement std.\n",
    "        epsilon: privacy level of this release (only recorded in the tracker).\n",
    "        test_data: DataFrame with an `income` column; its accuracy is recorded, not returned.\n",
    "        measurement_wrapper: converts `release_vector` back into per-clique measurements.\n",
    "        log1: measurements from the selection phase, combined with the new ones for estimation.\n",
    "        compressed_domain: domain the PGM is estimated on.\n",
    "        undo_compress_fn: maps the compressed synthetic dataset back to the original domain.\n",
    "        intermediate_result_tracker: receives every release together with its scores.\n",
    "        pgm_iters: number of mirror-descent iterations.\n",
    "\n",
    "    Returns:\n",
    "        Validation accuracy of a GradientBoostingClassifier trained on the synthetic data.\n",
    "    \"\"\"\n",
    "    log2 = measurement_wrapper.convert_vector_to_measurements(release_vector)\n",
    "    # Attach the release's noise std to every measurement (presumably used by estimation as its weight).\n",
    "    log2 = [LinearMeasurement(m.noisy_measurement, m.clique, release_variance**0.5, m.query) for m in log2]\n",
    "\n",
    "    est = estimation.mirror_descent(compressed_domain, log1+log2, iters=pgm_iters)\n",
    "    synth = est.synthetic_data()\n",
    "    synth = undo_compress_fn(synth)\n",
    "\n",
    "    syn_df = synth.df\n",
    "    X_syn, y_syn = X_y_split(syn_df)\n",
    "\n",
    "    model = GradientBoostingClassifier()\n",
    "    model.fit(X_syn, y_syn)\n",
    "    validation_accuracy = model.score(*X_y_split(validation_data))\n",
    "    test_accuracy = model.score(*X_y_split(test_data))\n",
    "\n",
    "    intermediate_result_tracker.add_result(syn_df, log2, model, validation_accuracy, test_accuracy, epsilon, release_variance)\n",
    "\n",
    "    return float(validation_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4693670a",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_repeats = 50\n",
    "threshold_checking_variants = [\"plain-gaussian\", \"svt\"]\n",
    "\n",
    "# One list of IntermediateResultTracker objects (one per repeat) for each variant.\n",
    "intermediate_results = dict((variant, []) for variant in threshold_checking_variants)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9739742",
   "metadata": {},
   "source": [
    "## Marginal Query Selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e6bd8aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(4234275)\n",
    "# Presumably mst.run_selection takes a zCDP rho; rho * alpha equals the selection epsilon at order alpha.\n",
    "rho_selection = epsilon_query_selection / alpha\n",
    "cliques, log1, compressed_data, undo_compress_fn = mst.run_selection(real_dataset, rho_selection)\n",
    "\n",
    "# Pair every feature with the income target, skipping pairs MST already selected (in either order).\n",
    "income_cliques = [\n",
    "    (col, \"income\")\n",
    "    for col in df.columns\n",
    "    if col != \"income\" and not ((col, \"income\") in cliques or (\"income\", col) in cliques)\n",
    "]\n",
    "cliques += income_cliques\n",
    "save_cliques(cliques_file_path, cliques)\n",
    "\n",
    "# L2 sensitivity of the concatenated marginals: each clique contributes 1 (add-remove neighbourhood).\n",
    "cliques_sensitivity = np.sqrt(len(cliques)) * 1\n",
    "\n",
    "# Exact (sigma=0) marginals; noise is presumably added later by the Brownian mechanism.\n",
    "log2 = mst.measure(compressed_data, cliques, sigma=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e152d195",
   "metadata": {},
   "source": [
    "## Brownian Mechanism + Synthetic Data Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b15078",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(69674586)\n",
    "# Ignore FutureWarnings attributed to line 102 of whichever module emits them (lineno filter).\n",
    "warnings.simplefilter(\"ignore\", FutureWarning, 102)\n",
    "# Every tracker is pickled below; make sure the directory exists before the long loop starts.\n",
    "os.makedirs(result_base_path, exist_ok=True)\n",
    "# NOTE: re-running this cell appends to intermediate_results again; re-run the cell defining it first.\n",
    "for repeat_ind in range(n_repeats):\n",
    "    print(f\"Repeat {repeat_ind + 1} / {n_repeats}\")\n",
    "\n",
    "    # Exact (sigma=0) marginal values of the selected cliques, presumably concatenated into one vector.\n",
    "    measurement_wrapper = MeasurementWrapper(log2)\n",
    "    real_query_values = measurement_wrapper.get_concatenated_values()\n",
    "\n",
    "    for variant in threshold_checking_variants:\n",
    "        print(f\"    Variant {variant}\")\n",
    "\n",
    "        # Fresh tracker per (repeat, variant); the utility function records every release into it.\n",
    "        intermediate_result_tracker = IntermediateResultTracker()\n",
    "        utility_function = partial(\n",
    "            generate_and_evaluate_synthetic_data,\n",
    "            test_data=test_df,\n",
    "            measurement_wrapper=measurement_wrapper, log1=log1,\n",
    "            compressed_domain=compressed_data.domain,\n",
    "            undo_compress_fn=undo_compress_fn,\n",
    "            intermediate_result_tracker=intermediate_result_tracker\n",
    "        )\n",
    "\n",
    "        if variant == \"svt\":\n",
    "            # Accuracy checks via SVT with Gaussian + Laplace noise, using the budget split computed above.\n",
    "            final_values, final_utility = bm.run_until_private_utility_above_threshold_gauss_laplace(\n",
    "                real_query_values, validation_df, utility_function, accuracy_threshold,\n",
    "                epsilon_generation_list, epsilon_accuracy_checking,\n",
    "                accuracy_checking_variance_split, alpha, cliques_sensitivity,\n",
    "                utility_sensitivity\n",
    "            )\n",
    "        elif variant == \"plain-gaussian\":\n",
    "            # Accuracy checks with plain Gaussian noise (no SVT).\n",
    "            final_values, final_utility = bm.run_until_private_utility_above_threshold_no_svt(\n",
    "                real_query_values, validation_df, utility_function, accuracy_threshold,\n",
    "                epsilon_accuracy_checking, utility_sensitivity, epsilon_generation_list,\n",
    "                alpha, cliques_sensitivity,\n",
    "            )\n",
    "        else:\n",
    "            raise RuntimeError(f\"Unknown variant {variant}\")\n",
    "\n",
    "        intermediate_results[variant].append(intermediate_result_tracker)\n",
    "        intermediate_result_tracker.save_results(f\"{result_base_path}/repeat_{repeat_ind}_{variant}.p\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d67b367",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set to True to reload the per-repeat trackers saved by the experiment loop instead of using the\n",
    "# ones currently in memory (e.g. after a kernel restart, skipping the expensive loop above).\n",
    "load_saved_results = False\n",
    "if load_saved_results:\n",
    "    for variant in threshold_checking_variants:\n",
    "        # Replace rather than append, so results are not duplicated if the loop already ran.\n",
    "        intermediate_results[variant] = [\n",
    "            IntermediateResultTracker.load_results(f\"{result_base_path}/repeat_{repeat_ind}_{variant}.p\")\n",
    "            for repeat_ind in range(n_repeats)\n",
    "        ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a358abc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per release; the last release of each run is the accepted one.\n",
    "records = []\n",
    "\n",
    "for variant in threshold_checking_variants:\n",
    "    for repeat_ind, tracker in enumerate(intermediate_results[variant]):\n",
    "        final_step = len(tracker.epsilons) - 1\n",
    "        per_release = zip(tracker.epsilons, tracker.validation_accuracies, tracker.test_accuracies)\n",
    "        for step, (epsilon, validation_accuracy, test_accuracy) in enumerate(per_release):\n",
    "            records.append({\n",
    "                \"Epsilon\": epsilon,\n",
    "                \"Validation Accuracy\": validation_accuracy,\n",
    "                \"Test Accuracy\": test_accuracy,\n",
    "                \"repeat_ind\": repeat_ind,\n",
    "                \"Variant\": variant,\n",
    "                \"Is Final\": step == final_step,\n",
    "            })\n",
    "\n",
    "result_df = pd.DataFrame.from_records(records)\n",
    "# Rounded epsilon serves as a categorical x-axis in the plots below.\n",
    "result_df[\"Epsilon Rounded\"] = [round(eps, 3) for eps in result_df.Epsilon]\n",
    "result_df[\"Status\"] = result_df[\"Is Final\"].map({True: \"Accepted Output\", False: \"Rejected Output\"})\n",
    "result_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3492150f",
   "metadata": {},
   "source": [
    "# No Synthetic Data Baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baef77fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(562219)\n",
    "# Reference point: the same classifier trained directly on the real training split.\n",
    "model = GradientBoostingClassifier().fit(train_X, train_y)\n",
    "nonsyn_test_accuracy = model.score(test_X, test_y)\n",
    "validation_X, validation_y = X_y_split(validation_df)\n",
    "nonsyn_validation_accuracy = model.score(validation_X, validation_y)\n",
    "\n",
    "print(nonsyn_test_accuracy, nonsyn_validation_accuracy, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4a458df",
   "metadata": {},
   "source": [
    "# Non-DP Synthetic Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "183ce711",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-DP reference: fit the PGM to the exact (noise-free) marginals and train on the resulting synthetic data.\n",
    "warnings.simplefilter(\"ignore\", FutureWarning, 102)\n",
    "np.random.seed(54673568)\n",
    "\n",
    "n_nondp_repeats = 10\n",
    "nondp_pgm_iters = 10000\n",
    "# log2 was measured with sigma=0; give every measurement a nominal std of 1 instead (presumably\n",
    "# estimation cannot use a zero std - TODO confirm). It does not change between repeats, so build it once.\n",
    "log2_nonzero_noise = [LinearMeasurement(m.noisy_measurement, m.clique, 1, m.query) for m in log2]\n",
    "\n",
    "nondp_validation_accuracies = []\n",
    "nondp_test_accuracies = []\n",
    "for _ in tqdm(range(n_nondp_repeats)):\n",
    "    est = estimation.mirror_descent(compressed_data.domain, log1 + log2_nonzero_noise, iters=nondp_pgm_iters)\n",
    "\n",
    "    synth = est.synthetic_data()\n",
    "    synth = undo_compress_fn(synth)\n",
    "\n",
    "    syn_df = synth.df\n",
    "    X_syn, y_syn = X_y_split(syn_df)\n",
    "\n",
    "    model = GradientBoostingClassifier()\n",
    "    model.fit(X_syn, y_syn)\n",
    "    nondp_validation_accuracy = model.score(*X_y_split(validation_df))\n",
    "    nondp_test_accuracy = model.score(test_X, test_y)\n",
    "\n",
    "    nondp_validation_accuracies.append(nondp_validation_accuracy)\n",
    "    nondp_test_accuracies.append(nondp_test_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7820cdfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_accs = np.asarray(nondp_validation_accuracies)\n",
    "test_accs = np.asarray(nondp_test_accuracies)\n",
    "# Means over the non-DP repeats; these replace the last per-repeat values left by the loop above.\n",
    "nondp_validation_accuracy = validation_accs.mean()\n",
    "nondp_test_accuracy = test_accs.mean()\n",
    "# Means, a blank line, then standard deviations.\n",
    "print(nondp_validation_accuracy, nondp_test_accuracy, \"\", validation_accs.std(), test_accs.std(), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07356363",
   "metadata": {},
   "source": [
    "# Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2309e492",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ICLR 2024 tueplots presets: the size factory is used per figure; font sizes, then TeX fonts, apply globally.\n",
    "figsize_factory = figsizes.iclr2024\n",
    "for rc_factory in (fontsizes.iclr2024, fonts.iclr2024_tex):\n",
    "    plt.rcParams.update(rc_factory())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86d8e5d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display names for the internal variant keys, plus the legend order and colours shared by all plots.\n",
    "variant_names = {\n",
    "    \"svt\": \"SVT\",\n",
    "    \"plain-gaussian\": \"Plain Gaussian\",\n",
    "}\n",
    "hue_order = [variant_names[\"plain-gaussian\"], variant_names[\"svt\"]]\n",
    "\n",
    "eps_label = \"Epsilon (RDP, $\\\\alpha = 20$)\"\n",
    "variant_colors = dict(zip(hue_order, [\"tan\", \"turquoise\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d940280",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only accepted releases and use display names for the variants.\n",
    "final_result_df = result_df[result_df[\"Is Final\"]].copy()\n",
    "final_result_df[\"Variant\"] = [variant_names[name] for name in final_result_df[\"Variant\"]]\n",
    "# Share of runs per variant that stopped at each (rounded) epsilon.\n",
    "final_epsilon_count_df = (\n",
    "    final_result_df\n",
    "    .groupby(\"Variant\")\n",
    "    .value_counts(subset=[\"Epsilon Rounded\"], normalize=True)\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=figsize_factory(nrows=1, ncols=1, rel_width=0.5)[\"figure.figsize\"])\n",
    "sns.barplot(data=final_epsilon_count_df, x=\"Epsilon Rounded\", y=\"proportion\", hue=\"Variant\", hue_order=hue_order, palette=variant_colors, ax=ax)\n",
    "ax.set(ylabel=\"Proportion\", xlabel=eps_label, title=\"Distribution of Accepted Epsilon\")\n",
    "fig.savefig(\"figures/adult/final-epsilon-distribution.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80c2591f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_legend_handles_labels_fix(ax):\n",
    "    \"\"\"Return (handles, labels) read directly from the legend already drawn on `ax`.\n",
    "\n",
    "    Presumably needed because seaborn's legend entries are not reported by\n",
    "    `ax.get_legend_handles_labels()` (TODO confirm).\n",
    "    \"\"\"\n",
    "    legend = ax.legend_\n",
    "    return legend.legend_handles, [text.get_text() for text in legend.get_texts()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a1e9e4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=figsize_factory(nrows=1, ncols=1, rel_width=0.5)[\"figure.figsize\"])\n",
    "sns.ecdfplot(data=final_result_df, x=\"Validation Accuracy\", hue=\"Variant\", hue_order=hue_order, palette=variant_colors, legend=True, ax=ax)\n",
    "# Grab seaborn's hue entries before adding the reference lines.\n",
    "sns_handles, sns_labels = get_legend_handles_labels_fix(ax)\n",
    "ax.axvline(accuracy_threshold, color=\"grey\", linestyle=\"dashed\", label=\"Accuracy Threshold\")\n",
    "ax.axvline(nondp_validation_accuracy, color=\"C2\", linestyle=\"dashed\", label=\"Non-DP Validation Accuracy\")\n",
    "line_handles, line_labels = ax.get_legend_handles_labels()\n",
    "# Replace seaborn's legend with one combining the hue entries and the reference lines.\n",
    "ax.get_legend().remove()\n",
    "ax.legend(sns_handles + line_handles, sns_labels + line_labels)\n",
    "ax.set(title=\"eCDF of Accepted Validation Accuracy\")\n",
    "fig.savefig(\"figures/adult/final-validation-accuracy-distribution.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "149fb915",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=figsize_factory(nrows=1, ncols=1, rel_width=0.5)[\"figure.figsize\"])\n",
    "sns.ecdfplot(data=final_result_df, x=\"Test Accuracy\", hue=\"Variant\", hue_order=hue_order, palette=variant_colors, legend=True, ax=ax)\n",
    "# Grab seaborn's hue entries before adding the reference lines.\n",
    "sns_handles, sns_labels = get_legend_handles_labels_fix(ax)\n",
    "ax.axvline(accuracy_threshold, color=\"grey\", linestyle=\"dashed\", label=\"Accuracy Threshold\")\n",
    "ax.axvline(nondp_test_accuracy, color=\"C2\", linestyle=\"dashed\", label=\"Non-DP Test Accuracy\")\n",
    "line_handles, line_labels = ax.get_legend_handles_labels()\n",
    "# Replace seaborn's legend with one combining the hue entries and the reference lines.\n",
    "ax.get_legend().remove()\n",
    "ax.legend(sns_handles + line_handles, sns_labels + line_labels)\n",
    "ax.set(title=\"eCDF of Accepted Test Accuracy\")\n",
    "fig.savefig(\"figures/adult/final-test-accuracy-distribution.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b093b41e",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=figsize_factory(nrows=1, ncols=1, rel_width=0.5)[\"figure.figsize\"])\n",
    "sns.swarmplot(data=final_result_df, x=\"Epsilon Rounded\", y=\"Validation Accuracy\", hue=\"Variant\", hue_order=hue_order, palette=variant_colors, dodge=True, size=2.6, ax=ax)\n",
    "reference_lines = [\n",
    "    (accuracy_threshold, \"grey\", \"Accuracy Threshold\"),\n",
    "    (nondp_validation_accuracy, \"C2\", \"Non-DP Validation Accuracy\"),\n",
    "]\n",
    "for y_value, line_color, line_label in reference_lines:\n",
    "    ax.axhline(y_value, color=line_color, linestyle=\"dashed\", label=line_label)\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "# Swap seaborn's default legend for one placed below the axes.\n",
    "ax.get_legend().remove()\n",
    "ax.legend(handles, labels, loc=\"upper center\", ncol=2, bbox_to_anchor=(0.4, -0.25))\n",
    "ax.set(title=\"Accepted Validation Accuracy by Epsilon\", xlabel=eps_label)\n",
    "fig.savefig(\"figures/adult/final-validation-accuracy-by-epsilon.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dfe33a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=figsize_factory(nrows=1, ncols=1, rel_width=0.5)[\"figure.figsize\"])\n",
    "sns.swarmplot(data=final_result_df, x=\"Epsilon Rounded\", y=\"Test Accuracy\", hue=\"Variant\", hue_order=hue_order, palette=variant_colors, dodge=True, size=2.6, ax=ax)\n",
    "reference_lines = [\n",
    "    (accuracy_threshold, \"grey\", \"Accuracy Threshold\"),\n",
    "    (nondp_test_accuracy, \"C2\", \"Non-DP Test Accuracy\"),\n",
    "]\n",
    "for y_value, line_color, line_label in reference_lines:\n",
    "    ax.axhline(y_value, color=line_color, linestyle=\"dashed\", label=line_label)\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "# Swap seaborn's default legend for one placed below the axes.\n",
    "ax.get_legend().remove()\n",
    "ax.legend(handles, labels, loc=\"upper center\", ncol=2, bbox_to_anchor=(0.4, -0.25))\n",
    "ax.set(title=\"Accepted Test Accuracy by Epsilon\", xlabel=eps_label)\n",
    "fig.savefig(\"figures/adult/final-test-accuracy-by-epsilon.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d4215c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=figsize_factory(nrows=1, ncols=1, rel_width=0.5)[\"figure.figsize\"])\n",
    "# hue_order keeps the legend order consistent with the other figures.\n",
    "sns.scatterplot(data=final_result_df, x=\"Validation Accuracy\", y=\"Test Accuracy\", hue=\"Variant\", hue_order=hue_order, palette=variant_colors, s=7.5, ax=ax)\n",
    "# Dashed y = x reference line spanning the full range of plotted accuracies.\n",
    "accuracy_values = final_result_df[[\"Validation Accuracy\", \"Test Accuracy\"]]\n",
    "diagonal_lo = accuracy_values.min().min()\n",
    "diagonal_hi = accuracy_values.max().max()\n",
    "ax.plot((diagonal_lo, diagonal_hi), (diagonal_lo, diagonal_hi), color=\"grey\", linestyle=\"dashed\")\n",
    "ax.set_title(\"Accepted Validation vs. Test Accuracy\")\n",
    "plt.savefig(\"figures/adult/final-test-validation-accuracy-scatter.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3c3a491",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=figsize_factory(nrows=1, ncols=2)[\"figure.figsize\"], sharey=True)\n",
    "\n",
    "# One panel per variant showing every release (rejected and accepted) by epsilon.\n",
    "for panel_ind, variant in enumerate(threshold_checking_variants):\n",
    "    ax = axes[panel_ind]\n",
    "    variant_df = result_df[result_df.Variant == variant]\n",
    "\n",
    "    sns.swarmplot(data=variant_df, x=\"Epsilon Rounded\", y=\"Validation Accuracy\", hue=\"Status\", hue_order=[\"Rejected Output\", \"Accepted Output\"], ax=ax, legend=True, size=2.6, dodge=True)\n",
    "    for y_value, line_color, line_label in ((accuracy_threshold, \"grey\", \"Accuracy Threshold\"), (nondp_validation_accuracy, \"C2\", \"Non-DP Validation Accuracy\")):\n",
    "        ax.axhline(y_value, color=line_color, linestyle=\"dashed\", label=line_label)\n",
    "    ax.set_xlabel(eps_label)\n",
    "    # Both panels carry the same entries, so the last panel's handles feed the shared figure legend.\n",
    "    leg_info = ax.get_legend_handles_labels()\n",
    "    ax.get_legend().remove()\n",
    "    ax.set_title(f\"All Releases - {variant_names[variant]}\")\n",
    "\n",
    "fig.legend(*leg_info, ncol=2, loc=\"upper center\", bbox_to_anchor=(0.5, -0.10)) # type: ignore\n",
    "fig.savefig(\"figures/adult/output-plots.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37012767",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
