{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e265fe3d-7c54-44dc-bb17-56fb22885569",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import pandas as pd\n",
    "from xgboost import XGBClassifier\n",
    "from tabular_datasets import ADULT, HealthHeritage\n",
    "from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score\n",
    "from programmable_synthesizer import ProgrammableSynthesizer\n",
    "import torch\n",
    "import pickle\n",
    "from utils import Timer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c7cbd50-e7ee-434f-b758-7ebc6c729d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Registry of dataset constructors; its key order is the order used by every per-dataset loop.\n",
    "datasets = {'ADULT': ADULT, 'HealthHeritage': HealthHeritage}\n",
    "dataset_names = list(datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70fdd06c-bf9e-4cf7-b4cc-22a37cbd35d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed and number of repetitions; both are also part of the baseline cache file name.\n",
    "random_seed, n_samples = 42, 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4f5f9bf-67dd-4c77-bfa6-b26d18055f53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the global random number generators of numpy and torch with the configured seed.\n",
    "np.random.seed(seed=random_seed)\n",
    "torch.manual_seed(seed=random_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d85ba7f0-3afb-4163-bc64-87b4a2486eb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# xgb baselines: XGBoost trained on the real data, for the 'non disc.' features and for the\n",
    "# 'disc.' (full one-hot) data. Results are cached on disk per dataset, sample count and seed.\n",
    "def _xgb_scores(Xtrain, ytrain, Xtest, ytest):\n",
    "    \"\"\"Fit a default XGBClassifier on the train split and score it on the test split.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (accuracy, balanced accuracy, F1) of the predictions on the test split.\n",
    "    \"\"\"\n",
    "    xgb = XGBClassifier()\n",
    "    xgb.fit(Xtrain, ytrain)\n",
    "    predictions = xgb.predict(Xtest)\n",
    "    return accuracy_score(ytest, predictions), balanced_accuracy_score(ytest, predictions), f1_score(ytest, predictions)\n",
    "\n",
    "\n",
    "dataset_xgb_metrics = {}\n",
    "base_path = 'experiment_data/non_dp_benchmarks/'\n",
    "os.makedirs(base_path, exist_ok=True)\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "\n",
    "    save_path = f'{base_path}{dataset_name}_xgb_accuracies_{n_samples}_{random_seed}.pickle'\n",
    "    if os.path.isfile(save_path):\n",
    "        # Reuse the cached metrics. NOTE: pickle.load can execute arbitrary code, so only\n",
    "        # load cache files that were written by this notebook.\n",
    "        with open(save_path, 'rb') as f:\n",
    "            dataset_xgb_metrics[dataset_name] = pickle.load(f)\n",
    "        continue\n",
    "\n",
    "    timer = Timer(n_samples)\n",
    "\n",
    "    dataset = datasets[dataset_name](drop_education_num=True) if dataset_name == 'ADULT' else datasets[dataset_name]()\n",
    "\n",
    "    # The splits do not change between samples, so they are extracted once, outside the loop.\n",
    "    # non disc.\n",
    "    raw_splits = (\n",
    "        dataset.get_Xtrain().numpy(), dataset.get_ytrain().numpy(),\n",
    "        dataset.get_Xtest().numpy(), dataset.get_ytest().numpy(),\n",
    "    )\n",
    "\n",
    "    # disc.: features are all but the last two columns, the target is the last column.\n",
    "    # NOTE(review): presumably the label occupies the final two one-hot columns - verify against the dataset class.\n",
    "    full_one_hot_train = dataset.get_Dtrain_full_one_hot()\n",
    "    full_one_hot_test = dataset.get_Dtest_full_one_hot()\n",
    "    disc_splits = (\n",
    "        full_one_hot_train[:, :-2], full_one_hot_train[:, -1],\n",
    "        full_one_hot_test[:, :-2], full_one_hot_test[:, -1],\n",
    "    )\n",
    "\n",
    "    # One (acc, bac, f1) tuple per sample and feature variant.\n",
    "    # NOTE(review): the classifier uses default arguments and the splits are identical for every\n",
    "    # sample, so the repeated samples may yield identical scores - confirm this is intended.\n",
    "    raw_scores = []\n",
    "    disc_scores = []\n",
    "    for sample in range(n_samples):\n",
    "\n",
    "        print(f'{dataset_name}    {timer}', end='\\r')\n",
    "        timer.start()\n",
    "\n",
    "        raw_scores.append(_xgb_scores(*raw_splits))\n",
    "        disc_scores.append(_xgb_scores(*disc_splits))\n",
    "\n",
    "        timer.end()\n",
    "\n",
    "    # Each entry is [mean non disc., mean disc., std non disc., std disc.].\n",
    "    metrics_to_include = {}\n",
    "    for col, key in enumerate(['xgb_acc', 'xgb_bac', 'xgb_f1']):\n",
    "        raw_vals = [run[col] for run in raw_scores]\n",
    "        disc_vals = [run[col] for run in disc_scores]\n",
    "        metrics_to_include[key] = [np.mean(raw_vals), np.mean(disc_vals), np.std(raw_vals), np.std(disc_vals)]\n",
    "\n",
    "    dataset_xgb_metrics[dataset_name] = metrics_to_include\n",
    "\n",
    "    with open(save_path, 'wb') as f:\n",
    "        pickle.dump(metrics_to_include, f)\n",
    "\n",
    "    timer.duration()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9877e119-d540-49cc-8426-aa415161af4b",
   "metadata": {},
   "source": [
    "## DP Benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9398f702-3c23-4cad-8d8f-efbbca97b246",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Selection of the DP experiment to plot; the seed, sample counts, workload and algorithm\n",
    "# names are substituted into the evaluation file path.\n",
    "random_seed = 42\n",
    "n_samples, n_resamples = 5, 5\n",
    "workload = 'all_three'\n",
    "algorithms = ['ProgSyn']\n",
    "\n",
    "# epsilon_map gives the position of an epsilon value along the first axis of the loaded results.\n",
    "epsilon = 1.0\n",
    "epsilon_map = {epsilon: 4}\n",
    "\n",
    "# Metric name -> index used to read that metric from the loaded results.\n",
    "chosen_metrics = dict(tv_error=0, xgb_acc=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c8ddbfd-3860-4299-87e8-9349ed28f403",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per (dataset, metric): all individual results, their mean and a +/- one std bar\n",
    "# for every algorithm, plus the XGBoost baselines on the true data where available.\n",
    "for dataset_name in dataset_names:\n",
    "\n",
    "    for metric_name, metric_idx in chosen_metrics.items():\n",
    "\n",
    "        plt.figure(figsize=(8, 6))\n",
    "        plt.title(f'{metric_name} on {dataset_name}')\n",
    "        plt.xlim([0.5, len(algorithms) + .5])\n",
    "\n",
    "        for i, algorithm in enumerate(algorithms):\n",
    "\n",
    "            evaluation_path = f'experiment_data/dp_benchmarks/evaluation_data/{dataset_name}/random_seed_{random_seed}/{algorithm}/collected_data_{algorithm}_{dataset_name}_{n_samples}_{n_resamples}_{workload}_{random_seed}.npy'\n",
    "            if not os.path.isfile(evaluation_path):\n",
    "                print(f'Experiment of {algorithm} on {dataset_name} not found under the given conditions')\n",
    "                continue\n",
    "            collected_data = np.load(evaluation_path)\n",
    "\n",
    "            # Keep only the samples whose entries sum to a positive value; all-zero samples are skipped.\n",
    "            eps_idx = epsilon_map[epsilon]\n",
    "            existing_samples = [k for k in range(n_samples) if collected_data[eps_idx, k].sum() > 0.]\n",
    "            if not existing_samples:\n",
    "                # Without this guard the mean/std below would be taken over an empty array (NaN + warning).\n",
    "                print(f'Experiment of {algorithm} on {dataset_name} contains no completed samples')\n",
    "                continue\n",
    "            collected_data = collected_data[eps_idx, existing_samples]\n",
    "\n",
    "            # Aggregate over the first two axes (samples and, presumably, resamples - TODO confirm).\n",
    "            collected_data_means, collected_data_stds = np.mean(collected_data, axis=(0, 1)), np.std(collected_data, axis=(0, 1))\n",
    "            mean_val = collected_data_means[metric_idx, 0]\n",
    "            std_val = collected_data_stds[metric_idx, 0]\n",
    "            x_pos = i + 1\n",
    "\n",
    "            plt.plot([x_pos, x_pos], [mean_val - std_val, mean_val + std_val], c='black', alpha=0.6)\n",
    "            plt.scatter(x_pos * np.ones(collected_data.shape[0] * collected_data.shape[1]), collected_data[:, :, metric_idx, 0].flatten(), alpha=0.2, s=70)\n",
    "            plt.scatter(x_pos, mean_val, c='red', s=70)\n",
    "\n",
    "            # The backslash of the LaTeX \\pm is escaped so the f-string has no invalid escape sequence.\n",
    "            print(dataset_name)\n",
    "            if metric_name == 'tv_error':\n",
    "                print(f'{algorithm} TV error:    ${mean_val:.2e} \\\\pm {std_val:.2e}$')\n",
    "            elif metric_name in ['xgb_acc', 'non_disc_xgb_acc']:\n",
    "                print(f'{algorithm} XGB accuracy:    ${100*mean_val:.1f} \\\\pm {100*std_val:.2f}$')\n",
    "\n",
    "        # Reference lines: entries 0 and 1 of the cached baseline metrics.\n",
    "        if metric_name in dataset_xgb_metrics[dataset_name]:\n",
    "            plt.axhline(dataset_xgb_metrics[dataset_name][metric_name][0], linestyle='--', color='blue', label='True Baseline')\n",
    "            plt.axhline(dataset_xgb_metrics[dataset_name][metric_name][1], linestyle='--', color='green', label='Disc. True Baseline')\n",
    "            plt.legend()\n",
    "\n",
    "        plt.grid(True, alpha=0.2)\n",
    "        plt.box(False)\n",
    "        plt.xticks(1 + np.arange(len(algorithms)), algorithms)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9c7748c-6ef3-46c1-ad7d-9a01734be090",
   "metadata": {},
   "source": [
    "## Non-DP Benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "290a2381-b598-426a-b14b-8c817e219436",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Selection of the non-DP experiment to plot; the seed, sample counts and model names are\n",
    "# substituted into the evaluation file path.\n",
    "random_seed = 42\n",
    "n_samples, n_resamples = 5, 5\n",
    "models = ['ProgSyn']\n",
    "\n",
    "# Workload name -> position along the first axis of the loaded results.\n",
    "workload = 'all_three'\n",
    "workload_index_maps = {name: idx for idx, name in enumerate(['all_two', 'all_three', 'all_three_with_labels'])}\n",
    "\n",
    "# Metric name -> index used to read that metric from the loaded results.\n",
    "chosen_metrics = dict(tv_error=0, xgb_acc=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9eb1abba-6468-466c-b9aa-d79ccb463052",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per (dataset, metric): all individual results, their mean and a +/- one std bar\n",
    "# for every model, plus the XGBoost baselines on the true data where available.\n",
    "for dataset_name in dataset_names:\n",
    "\n",
    "    for metric_name, metric_idx in chosen_metrics.items():\n",
    "\n",
    "        plt.figure(figsize=(8, 6))\n",
    "        plt.title(f'{metric_name} on {dataset_name}')\n",
    "        plt.xlim([0.5, len(models) + .5])\n",
    "\n",
    "        for i, model in enumerate(models):\n",
    "            evaluation_path = f'experiment_data/non_dp_benchmarks/evaluation_data/{dataset_name}/random_seed_{random_seed}/{model}/collected_data_{model}_{dataset_name}_{n_samples}_{n_resamples}_{random_seed}_False.npy'\n",
    "            if not os.path.isfile(evaluation_path):\n",
    "                print(f'Experiment of {model} on {dataset_name} not found under the given conditions')\n",
    "                continue\n",
    "            collected_data = np.load(evaluation_path)\n",
    "\n",
    "            # Pick the configured workload, then keep only the samples whose entries sum to a\n",
    "            # positive value; all-zero samples are skipped.\n",
    "            selected_workload_data = collected_data[workload_index_maps[workload]]\n",
    "            existing_samples = [k for k in range(n_samples) if selected_workload_data[k].sum() > 0.]\n",
    "            if not existing_samples:\n",
    "                # Without this guard the mean/std below would be taken over an empty array (NaN + warning).\n",
    "                print(f'Experiment of {model} on {dataset_name} contains no completed samples')\n",
    "                continue\n",
    "            selected_workload_data = selected_workload_data[existing_samples]\n",
    "\n",
    "            # Aggregate over the first two axes (samples and, presumably, resamples - TODO confirm).\n",
    "            collected_data_means, collected_data_stds = np.mean(selected_workload_data, axis=(0, 1)), np.std(selected_workload_data, axis=(0, 1))\n",
    "            mean_val = collected_data_means[metric_idx, 0]\n",
    "            std_val = collected_data_stds[metric_idx, 0]\n",
    "            x_pos = i + 1\n",
    "\n",
    "            plt.plot([x_pos, x_pos], [mean_val - std_val, mean_val + std_val], c='black', alpha=0.6)\n",
    "            plt.scatter(x_pos * np.ones(selected_workload_data.shape[0] * selected_workload_data.shape[1]), selected_workload_data[:, :, metric_idx, 0].flatten(), alpha=0.2, s=70)\n",
    "            plt.scatter(x_pos, mean_val, c='red', s=70)\n",
    "\n",
    "            # The backslash of the LaTeX \\pm is escaped so the f-string has no invalid escape sequence.\n",
    "            print(dataset_name)\n",
    "            if metric_name == 'tv_error':\n",
    "                print(f'{model} TV error:    ${mean_val:.2e} \\\\pm {std_val:.2e}$')\n",
    "            elif metric_name in ['xgb_acc', 'non_disc_xgb_acc']:\n",
    "                print(f'{model} XGB accuracy:    ${100*mean_val:.1f} \\\\pm {100*std_val:.2f}$')\n",
    "\n",
    "        # Reference lines: entries 0 and 1 of the cached baseline metrics.\n",
    "        if metric_name in dataset_xgb_metrics[dataset_name]:\n",
    "            plt.axhline(dataset_xgb_metrics[dataset_name][metric_name][0], linestyle='--', color='blue', label='True Baseline')\n",
    "            plt.axhline(dataset_xgb_metrics[dataset_name][metric_name][1], linestyle='--', color='green', label='Disc. True Baseline')\n",
    "            plt.legend()\n",
    "            print(f'True data XGB accuracy: ${100*dataset_xgb_metrics[dataset_name][metric_name][0]:.1f}$')\n",
    "\n",
    "        plt.grid(True, alpha=0.2)\n",
    "        plt.box(False)\n",
    "        plt.xticks(1 + np.arange(len(models)), models)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e18891a-8f25-460c-a6fe-3b920617e1e2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "992b5175-6659-473a-abef-d246c5fb73c4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baca44b5-98ce-40e8-b61f-070690c00367",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "827950e3-14ab-4e86-85af-820a7a366dd5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72845187-3feb-4560-a2a4-c30347e08e7e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
