{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a122c34-94ca-4d95-ac68-af207585bb6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4a02743-9ef7-48d6-b6a6-7375433204e6",
   "metadata": {},
   "source": [
    "# Adult -- Non-Private"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4c54f39-11f7-4ca0-8ab1-635acdcc30b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifiers that are interpolated into the result file paths\n",
    "dataset = 'ADULT'\n",
    "random_seed = 42\n",
    "n_samples, n_resamples = 5, 5\n",
    "\n",
    "# x-axis label -> index of that metric in the loaded result arrays\n",
    "x_metrics = dict([('XGB accuracy [%]', 3)])\n",
    "# summary statistic name -> position along the last axis of the result arrays\n",
    "statistics_map = dict(zip(('mean', 'std', 'median', 'min', 'max'), range(5)))\n",
    "statistic_over_queries = 'mean'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8654690-4dcf-4b4a-a232-f7c7bde784ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# constraint program file -> (y-axis label, index of the evaluated metric in the result arrays)\n",
    "program_names_and_eval_metrics = {\n",
    "    'eliminate_predictability_sex.txt': ('XGB bacc. predicting sex [%]', 7),\n",
    "    'fairness_downstream_sex.txt': ('Dem. Parity dist. on sex', 6),\n",
    "    'minimize_correlation_sex.txt': ('Correlation sex-salary', 6),\n",
    "    'avg_age_30.txt': ('Mean age', 6),\n",
    "    'avg_male_female_age.txt': ('Mean age difference', 7),\n",
    "    'implication1.txt': ('CSR [%]', 6),\n",
    "    'implication2.txt': ('CSR [%]', 6),\n",
    "    'implication3.txt': ('CSR [%]', 6),\n",
    "    'line_constraint1.txt': ('CSR [%]', 6),\n",
    "    'line_constraint2.txt': ('CSR [%]', 6),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a4a4677-ceef-43a8-873b-6ada4d90aeea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# constraint program file -> index along the first axis of the evaluation array\n",
    "# whose numbers are printed by the plotting cell\n",
    "program_names_and_display_indices = dict([\n",
    "    ('eliminate_predictability_sex.txt', 5),\n",
    "    ('fairness_downstream_sex.txt', 6),\n",
    "    ('minimize_correlation_sex.txt', 3),\n",
    "    ('avg_age_30.txt', 2),\n",
    "    ('avg_male_female_age.txt', 1),\n",
    "    ('implication1.txt', 3),\n",
    "    ('implication2.txt', 3),\n",
    "    ('implication3.txt', 3),\n",
    "    ('line_constraint1.txt', 1),\n",
    "    ('line_constraint2.txt', 3),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1c28569-6c7d-4e66-a765-f2b4d47f926d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-program layout: 'xticks' / 'yticks' hold [tick positions, tick labels],\n",
    "# 'legend' holds [draw legend?, legend location].\n",
    "plot_specs = {\n",
    "    'eliminate_predictability_sex.txt': dict(\n",
    "        xticks=[[0.79, 0.80, 0.81, 0.82, 0.83, 0.84, 0.85], ['79', '80', '81', '82', '83', '84', '85']],\n",
    "        yticks=[[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85], ['50', '55', '60', '65', '70', '75', '80', '85']],\n",
    "        legend=[True, 'upper left'],\n",
    "    ),\n",
    "    'fairness_downstream_sex.txt': dict(\n",
    "        xticks=[[0.79, 0.80, 0.81, 0.82, 0.83, 0.84, 0.85], ['79', '80', '81', '82', '83', '84', '85']],\n",
    "        yticks=[[0.000, 0.025, 0.050, 0.075, 0.100, 0.125, 0.150, 0.175, 0.200], ['0.000', '0.025', '0.050', '0.075', '0.100', '0.125', '0.150', '0.175', '0.200']],\n",
    "        legend=[False, 'upper left'],\n",
    "    ),\n",
    "    'minimize_correlation_sex.txt': dict(\n",
    "        xticks=[[0.830, 0.8350, 0.840, 0.845, 0.850], ['83', '83.5', '84', '84.5', '85']],\n",
    "        yticks=[[-0.04, -0.03, -0.02, -0.01, 0.0], ['-0.04', '-0.03', '-0.02', '-0.01', '0.0']],\n",
    "        legend=[False, 'upper right'],\n",
    "    ),\n",
    "    'avg_age_30.txt': dict(\n",
    "        xticks=[[0.830, 0.8350, 0.840, 0.845, 0.850], ['83', '83.5', '84', '84.5', '85']],\n",
    "        yticks=[[30, 31, 32, 33, 34, 35, 36, 37], ['30', '31', '32', '33', '34', '35', '36', '37']],\n",
    "        legend=[True, 'lower right'],\n",
    "    ),\n",
    "    'avg_male_female_age.txt': dict(\n",
    "        xticks=[[0.830, 0.8350, 0.840, 0.845, 0.850], ['83', '83.5', '84', '84.5', '85']],\n",
    "        yticks=[[0.0, 0.5, 1.0, 1.5, 2.0, 2.5], ['0.0', '0.5', '1.0', '1.5', '2.0', '2.5']],\n",
    "        legend=[False, 'lower right'],\n",
    "    ),\n",
    "    'implication1.txt': dict(\n",
    "        xticks=[[0.830, 0.8350, 0.840, 0.845, 0.850], ['83', '83.5', '84', '84.5', '85']],\n",
    "        yticks=[[0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0], ['93', '94', '95', '96', '97', '98', '99', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "    'implication2.txt': dict(\n",
    "        xticks=[[0.830, 0.8350, 0.840, 0.845, 0.850], ['83', '83.5', '84', '84.5', '85']],\n",
    "        yticks=[[0.96, 0.97, 0.98, 0.99, 1.0], ['96', '97', '98', '99', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "    'implication3.txt': dict(\n",
    "        xticks=[[0.830, 0.8350, 0.840, 0.845, 0.850], ['83', '83.5', '84', '84.5', '85']],\n",
    "        yticks=[[0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0], ['65', '70', '75', '80', '85', '90', '95', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "    'line_constraint1.txt': dict(\n",
    "        xticks=[[0.74, 0.76, 0.78, 0.8, 0.82, 0.84], ['74', '76', '78', '80', '82', '84']],\n",
    "        yticks=[[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], ['30', '40', '50', '60', '70', '80', '90', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "    'line_constraint2.txt': dict(\n",
    "        xticks=[[0.65, 0.7, 0.75, 0.8, 0.85], ['65', '70', '75', '80', '85']],\n",
    "        yticks=[[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], ['30', '40', '50', '60', '70', '80', '90', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c939f716-6402-46c0-89c9-ddf4294882b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every non-private constraint program: load the evaluation and baseline arrays,\n",
    "# print the numbers in LaTeX form and plot the eval metric against each x metric.\n",
    "base_non_dp_path = f'experiment_data/constraint_program_experiments/{dataset}/non_dp_constraints/testing_results/'\n",
    "# position of the selected statistic along the last axis of the loaded arrays\n",
    "stat_idx = statistics_map[statistic_over_queries]\n",
    "\n",
    "for program_name, (eval_metric_name, eval_metric_idx) in program_names_and_eval_metrics.items():\n",
    "\n",
    "    # paths -- all result files of one program share the same prefix\n",
    "    stripped_program_name = program_name.split('.')[0]\n",
    "    run_prefix = base_non_dp_path + f'{stripped_program_name}_all_three_with_labels_{n_samples}_{n_resamples}_{random_seed}_0.0'\n",
    "    eval_save_non_dp_path = run_prefix + '.npy'\n",
    "    baseline_save_non_dp_path = run_prefix + '_baselines.npy'\n",
    "    # built from `dataset` (previously hard-coded to 'ADULT') so it always matches the files above\n",
    "    rejection_sampling_path = run_prefix + '_rejection_sampling.npy'\n",
    "\n",
    "    if os.path.isfile(eval_save_non_dp_path):\n",
    "        eval_non_dp_data = np.load(eval_save_non_dp_path)\n",
    "    else:\n",
    "        print(f'Eval of {stripped_program_name} under the current settings not found')\n",
    "        continue\n",
    "\n",
    "    if os.path.isfile(baseline_save_non_dp_path):\n",
    "        baseline_non_dp_data = np.load(baseline_save_non_dp_path)\n",
    "    else:\n",
    "        print(f'Baselines of {stripped_program_name} under the current settings not found')\n",
    "        continue\n",
    "\n",
    "    # load the command\n",
    "    load_path = f'experiment_data/constraint_program_experiments/{dataset}/non_dp_constraints/training_constraints/{program_name}'\n",
    "    with open(load_path, 'r') as f:\n",
    "        program = f.read()\n",
    "\n",
    "    print(program)\n",
    "    # keep only the selected statistic, then aggregate over axes 1 and 2\n",
    "    display_data_non_dp = eval_non_dp_data[:, :, :, :, stat_idx]\n",
    "    display_data_non_dp_mean = np.mean(display_data_non_dp, axis=(1, 2))\n",
    "    display_data_non_dp_std = np.std(display_data_non_dp, axis=(1, 2))\n",
    "    # baseline row 0 is plotted as 'True data', row 1 as 'No Fine-tuning'\n",
    "    display_baseline_non_dp = baseline_non_dp_data[:2, :1, :, :, stat_idx]\n",
    "    display_baseline_non_dp_mean = np.mean(display_baseline_non_dp, axis=(1, 2))\n",
    "    display_baseline_non_dp_std = np.std(display_baseline_non_dp, axis=(1, 2))\n",
    "    # row of the evaluation array whose numbers are printed\n",
    "    display_idx = program_names_and_display_indices[program_name]\n",
    "\n",
    "    for x_metric_name, x_metric_idx in x_metrics.items():\n",
    "\n",
    "        plt.figure(figsize=(8, 6))\n",
    "\n",
    "        # only the implication / line constraint programs get a rejection sampling baseline\n",
    "        if program_name.startswith(('implication', 'line')):\n",
    "\n",
    "            rejection_sampling_baseline = np.load(rejection_sampling_path)\n",
    "            # non dp\n",
    "            rsampling = rejection_sampling_baseline[:, :, :, stat_idx]\n",
    "            rsampling_mean = np.mean(rsampling, axis=(0, 1))\n",
    "            rsampling_std = np.std(rsampling, axis=(0, 1))\n",
    "            # '\\\\pm' keeps the literal backslash without relying on an invalid escape sequence\n",
    "            print(f'Non-DP Rsampling: ${100*rsampling_mean[x_metric_idx]:.1f} \\\\pm {100*rsampling_std[x_metric_idx]:.2f}$')\n",
    "            plt.scatter(rsampling_mean[x_metric_idx], rsampling_mean[eval_metric_idx], c='orange', marker='*', s=300, label='Rejection Sampling, non DP')\n",
    "            plt.scatter(rsampling[:, :, x_metric_idx], rsampling[:, :, eval_metric_idx], c='orange', marker='*', s=300, alpha=0.1)\n",
    "\n",
    "        plt.scatter(display_baseline_non_dp_mean[0, x_metric_idx], display_baseline_non_dp_mean[0, eval_metric_idx], c='red', marker='X', s=200, label='True data')\n",
    "        # non DP data\n",
    "        plt.plot(display_data_non_dp_mean[:, x_metric_idx], display_data_non_dp_mean[:, eval_metric_idx], '--o', markersize=10, c='indigo', label='Fine-tuned')\n",
    "        print('\\nTrue Data')\n",
    "        print(f'{x_metric_name}: ${100*display_baseline_non_dp_mean[0, x_metric_idx]:.1f} \\\\pm {100*display_baseline_non_dp_std[0, x_metric_idx]:.2f}$')\n",
    "        print(f'{eval_metric_name}: ${display_baseline_non_dp_mean[0, eval_metric_idx]:.3f} \\\\pm {display_baseline_non_dp_std[0, eval_metric_idx]:.4f}$')\n",
    "        print('\\nNon-Private ProgSyn Fine-Tuned')\n",
    "        print(f'{x_metric_name}: ${100*display_data_non_dp_mean[display_idx, x_metric_idx]:.1f} \\\\pm {100*display_data_non_dp_std[display_idx, x_metric_idx]:.2f}$')\n",
    "        print(f'{eval_metric_name}: ${display_data_non_dp_mean[display_idx, eval_metric_idx]:.3f} \\\\pm {display_data_non_dp_std[display_idx, eval_metric_idx]:.4f}$')\n",
    "        print('\\nNon-Private ProgSyn Base')\n",
    "        print(f'{x_metric_name}: ${100*display_baseline_non_dp_mean[1, x_metric_idx]:.1f} \\\\pm {100*display_baseline_non_dp_std[1, x_metric_idx]:.2f}$')\n",
    "        print(f'{eval_metric_name}: ${display_baseline_non_dp_mean[1, eval_metric_idx]:.3f} \\\\pm {display_baseline_non_dp_std[1, eval_metric_idx]:.4f}$')\n",
    "        plt.scatter(display_baseline_non_dp_mean[1, x_metric_idx], display_baseline_non_dp_mean[1, eval_metric_idx], c='indigo', marker='X', s=200, label='No Fine-tuning')\n",
    "        # layout setups\n",
    "        spec = plot_specs[program_name]\n",
    "        plt.xlabel(x_metric_name, fontsize=20, labelpad=10)\n",
    "        plt.ylabel(eval_metric_name, fontsize=20, labelpad=10)\n",
    "        plt.grid(True, alpha=0.1)\n",
    "        plt.tick_params(axis='both', length=0)\n",
    "        if x_metric_name == 'XGB accuracy [%]':\n",
    "            plt.xticks(spec['xticks'][0], spec['xticks'][1], fontsize=20)\n",
    "            # plt.yticks(spec['yticks'][0], spec['yticks'][1], fontsize=20)\n",
    "        plt.box(False)\n",
    "        if spec['legend'][0]:\n",
    "            plt.legend(fontsize=20, loc=spec['legend'][1])\n",
    "        plt.tight_layout()\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01b7d05d-19bd-44b2-81cc-efa5b25e7fe0",
   "metadata": {},
   "source": [
    "# Adult -- Private"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "093d7bb8-3d8e-476e-a0de-befa5cd109d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifiers that are interpolated into the result file paths\n",
    "dataset = 'ADULT'\n",
    "random_seed = 42\n",
    "n_samples, n_resamples = 5, 5\n",
    "\n",
    "# x-axis label -> index of that metric in the loaded result arrays\n",
    "x_metrics = dict([('XGB accuracy [%]', 3)])\n",
    "# summary statistic name -> position along the last axis of the result arrays\n",
    "statistics_map = dict(zip(('mean', 'std', 'median', 'min', 'max'), range(5)))\n",
    "statistic_over_queries = 'mean'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e998fd28-e2dc-4cf7-b790-502549f23493",
   "metadata": {},
   "outputs": [],
   "source": [
    "# constraint program file -> (y-axis label, index of the evaluated metric in the result arrays)\n",
    "program_names_and_eval_metrics = {\n",
    "    'fairness_downstream_sex.txt': ('Dem. Parity dist. on sex', 6),\n",
    "    'implication1.txt': ('CSR [%]', 6),\n",
    "    'implication2.txt': ('CSR [%]', 6),\n",
    "    'implication3.txt': ('CSR [%]', 6),\n",
    "    'line_constraint1.txt': ('CSR [%]', 6),\n",
    "    'line_constraint2.txt': ('CSR [%]', 6),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b412890-9e3e-4532-b62a-5a715578732f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# constraint program file -> index along the first axis of the evaluation array\n",
    "# whose numbers are printed by the plotting cell\n",
    "program_names_and_display_indices = dict([\n",
    "    ('fairness_downstream_sex.txt', 6),\n",
    "    ('implication1.txt', 4),\n",
    "    ('implication2.txt', 1),\n",
    "    ('implication3.txt', 3),\n",
    "    ('line_constraint1.txt', 3),\n",
    "    ('line_constraint2.txt', 1),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6f42ff8-0d21-4863-b15f-d152eb95d7af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-program layout: 'xticks' / 'yticks' hold [tick positions, tick labels],\n",
    "# 'legend' holds [draw legend?, legend location].\n",
    "plot_specs = {\n",
    "    'fairness_downstream_sex.txt': dict(\n",
    "        xticks=[[0.79, 0.80, 0.81, 0.82, 0.83, 0.84, 0.85], ['79', '80', '81', '82', '83', '84', '85']],\n",
    "        yticks=[[0.000, 0.025, 0.050, 0.075, 0.100, 0.125, 0.150, 0.175, 0.200], ['0.000', '0.025', '0.050', '0.075', '0.100', '0.125', '0.150', '0.175', '0.200']],\n",
    "        legend=[False, 'upper left'],\n",
    "    ),\n",
    "    'implication1.txt': dict(\n",
    "        xticks=[[0.830, 0.8350, 0.840, 0.845, 0.850], ['83', '83.5', '84', '84.5', '85']],\n",
    "        yticks=[[0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0], ['93', '94', '95', '96', '97', '98', '99', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "    'implication2.txt': dict(\n",
    "        xticks=[[0.830, 0.8350, 0.840, 0.845, 0.850], ['83', '83.5', '84', '84.5', '85']],\n",
    "        yticks=[[0.96, 0.97, 0.98, 0.99, 1.0], ['96', '97', '98', '99', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "    'implication3.txt': dict(\n",
    "        xticks=[[0.830, 0.8350, 0.840, 0.845, 0.850], ['83', '83.5', '84', '84.5', '85']],\n",
    "        yticks=[[0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0], ['65', '70', '75', '80', '85', '90', '95', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "    'line_constraint1.txt': dict(\n",
    "        xticks=[[0.74, 0.76, 0.78, 0.8, 0.82, 0.84], ['74', '76', '78', '80', '82', '84']],\n",
    "        yticks=[[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], ['30', '40', '50', '60', '70', '80', '90', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "    'line_constraint2.txt': dict(\n",
    "        xticks=[[0.65, 0.7, 0.75, 0.8, 0.85], ['65', '70', '75', '80', '85']],\n",
    "        yticks=[[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], ['30', '40', '50', '60', '70', '80', '90', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "731dca60-3109-4ec6-bcc0-3007dd3b6e10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every private (DP) constraint program: load the evaluation and baseline arrays,\n",
    "# print the numbers in LaTeX form and plot the eval metric against each x metric.\n",
    "base_dp_path = f'experiment_data/constraint_program_experiments/{dataset}/dp_constraints/testing_results/'\n",
    "# position of the selected statistic along the last axis of the loaded arrays\n",
    "stat_idx = statistics_map[statistic_over_queries]\n",
    "\n",
    "for program_name, (eval_metric_name, eval_metric_idx) in program_names_and_eval_metrics.items():\n",
    "\n",
    "    # paths -- all result files of one program share the same prefix\n",
    "    stripped_program_name = program_name.split('.')[0]\n",
    "    run_prefix = base_dp_path + f'{stripped_program_name}_dp_all_three_{n_samples}_{n_resamples}_{random_seed}_1.0'\n",
    "    eval_save_dp_path = run_prefix + '.npy'\n",
    "    baseline_save_dp_path = run_prefix + '_baselines.npy'\n",
    "    # built from `dataset` (previously hard-coded to 'ADULT') so it always matches the files above\n",
    "    rejection_sampling_dp_path = run_prefix + '_rejection_sampling.npy'\n",
    "\n",
    "    if os.path.isfile(eval_save_dp_path):\n",
    "        eval_dp_data = np.load(eval_save_dp_path)\n",
    "    else:\n",
    "        print(f'Eval of {stripped_program_name} under the current settings not found')\n",
    "        continue\n",
    "\n",
    "    if os.path.isfile(baseline_save_dp_path):\n",
    "        baseline_dp_data = np.load(baseline_save_dp_path)\n",
    "    else:\n",
    "        print(f'Baselines of {stripped_program_name} under the current settings not found')\n",
    "        # skip this program: without this, the code below would hit a NameError or\n",
    "        # silently reuse the baselines loaded for the previous program\n",
    "        continue\n",
    "\n",
    "    # load the command\n",
    "    # NOTE(review): the program text is read from the non_dp_constraints directory even\n",
    "    # for the DP results -- confirm this is intended\n",
    "    load_path = f'experiment_data/constraint_program_experiments/{dataset}/non_dp_constraints/training_constraints/{program_name}'\n",
    "    with open(load_path, 'r') as f:\n",
    "        program = f.read()\n",
    "\n",
    "    print(program)\n",
    "    # keep only the selected statistic, then aggregate over axes 1 and 2\n",
    "    display_data_dp = eval_dp_data[:, :, :, :, stat_idx]\n",
    "    display_data_dp_mean = np.mean(display_data_dp, axis=(1, 2))\n",
    "    display_data_dp_std = np.std(display_data_dp, axis=(1, 2))\n",
    "    # baseline row 1 is the one plotted as 'DP: No fine-tuning'\n",
    "    display_baseline_dp = baseline_dp_data[:2, :1, :, :, stat_idx]\n",
    "    display_baseline_dp_mean = np.mean(display_baseline_dp, axis=(1, 2))\n",
    "    display_baseline_dp_std = np.std(display_baseline_dp, axis=(1, 2))\n",
    "    # row of the evaluation array whose numbers are printed\n",
    "    display_idx = program_names_and_display_indices[program_name]\n",
    "\n",
    "    for x_metric_name, x_metric_idx in x_metrics.items():\n",
    "\n",
    "        plt.figure(figsize=(8, 6))\n",
    "\n",
    "        # only the implication / line constraint programs get a rejection sampling baseline\n",
    "        if program_name.startswith(('implication', 'line')):\n",
    "\n",
    "            rejection_sampling_baseline_dp = np.load(rejection_sampling_dp_path)\n",
    "            # dp\n",
    "            rsampling_dp = rejection_sampling_baseline_dp[:, :, :, stat_idx]\n",
    "            rsampling_mean_dp = np.mean(rsampling_dp, axis=(0, 1))\n",
    "            rsampling_std_dp = np.std(rsampling_dp, axis=(0, 1))\n",
    "            # '\\\\pm' keeps the literal backslash without relying on an invalid escape sequence\n",
    "            print(f'DP Rsampling:     ${100*rsampling_mean_dp[x_metric_idx]:.1f} \\\\pm {100*rsampling_std_dp[x_metric_idx]:.2f}$')\n",
    "            plt.scatter(rsampling_mean_dp[x_metric_idx], rsampling_mean_dp[eval_metric_idx], c='green', marker='*', s=300, label='Rejection Sampling, DP')\n",
    "            plt.scatter(rsampling_dp[:, :, x_metric_idx], rsampling_dp[:, :, eval_metric_idx], c='green', marker='*', s=300, alpha=0.1)\n",
    "\n",
    "        # DP data\n",
    "        plt.plot(display_data_dp_mean[:, x_metric_idx], display_data_dp_mean[:, eval_metric_idx], '--o', markersize=10, c='cornflowerblue', label='DP: fine-tuned')\n",
    "        print('\\nPrivate ProgSyn Fine-Tuned')\n",
    "        print(f'{x_metric_name}: ${100*display_data_dp_mean[display_idx, x_metric_idx]:.1f} \\\\pm {100*display_data_dp_std[display_idx, x_metric_idx]:.2f}$')\n",
    "        print(f'{eval_metric_name}: ${display_data_dp_mean[display_idx, eval_metric_idx]:.3f} \\\\pm {display_data_dp_std[display_idx, eval_metric_idx]:.4f}$')\n",
    "        print('\\nPrivate ProgSyn Base')\n",
    "        print(f'{x_metric_name}: ${100*display_baseline_dp_mean[1, x_metric_idx]:.1f} \\\\pm {100*display_baseline_dp_std[1, x_metric_idx]:.2f}$')\n",
    "        print(f'{eval_metric_name}: ${display_baseline_dp_mean[1, eval_metric_idx]:.3f} \\\\pm {display_baseline_dp_std[1, eval_metric_idx]:.4f}$')\n",
    "        # plt.scatter(display_data_dp[:, :, :, x_metric_idx].flatten(), display_data_dp[:, :, :, eval_metric_idx].flatten(), s=50, alpha=0.1, c='cornflowerblue')\n",
    "        plt.scatter(display_baseline_dp_mean[1, x_metric_idx], display_baseline_dp_mean[1, eval_metric_idx], c='cornflowerblue', marker='X', s=200, label='DP: No fine-tuning')\n",
    "        # plt.scatter(display_data_dp_mean[0, x_metric_idx], display_data_dp_mean[0, eval_metric_idx], c='green', marker='X', s=200, label='DP: No fine-tuning')\n",
    "        # layout setups\n",
    "        spec = plot_specs[program_name]\n",
    "        plt.xlabel(x_metric_name, fontsize=20, labelpad=10)\n",
    "        plt.ylabel(eval_metric_name, fontsize=20, labelpad=10)\n",
    "        plt.grid(True, alpha=0.1)\n",
    "        plt.tick_params(axis='both', length=0)\n",
    "        if x_metric_name == 'XGB accuracy [%]':\n",
    "            plt.xticks(spec['xticks'][0], spec['xticks'][1], fontsize=20)\n",
    "            # plt.yticks(spec['yticks'][0], spec['yticks'][1], fontsize=20)\n",
    "        plt.box(False)\n",
    "        if spec['legend'][0]:\n",
    "            plt.legend(fontsize=20, loc=spec['legend'][1])\n",
    "        plt.tight_layout()\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b4fbeff-c88f-4c7d-9a7f-749634c40ff5",
   "metadata": {},
   "source": [
    "# HealthHeritage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb1eb8e-6a4e-4643-937f-2e7bf9fbd569",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifiers that are interpolated into the result file paths\n",
    "dataset = 'HealthHeritage'\n",
    "random_seed = 42\n",
    "n_samples, n_resamples = 5, 5\n",
    "\n",
    "# x-axis label -> index of that metric in the loaded result arrays\n",
    "x_metrics = dict([('XGB accuracy [%]', 3)])\n",
    "# summary statistic name -> position along the last axis of the result arrays\n",
    "statistics_map = dict(zip(('mean', 'std', 'median', 'min', 'max'), range(5)))\n",
    "statistic_over_queries = 'mean'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a043b2b7-ef1a-4739-bd38-b293052e0b75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# constraint program file -> (y-axis label, index of the evaluated metric in the result arrays)\n",
    "program_names_and_eval_metrics_for_thesis = {\n",
    "    'implication1.txt': ('CSR [%]', 6),\n",
    "    'line_constraint1.txt': ('CSR [%]', 6),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0b8ce90-fa52-45b1-94b0-534c80bbcc5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# constraint program file -> index along the first axis of the evaluation array\n",
    "# whose numbers are printed by the plotting cell\n",
    "program_names_and_display_indices = dict([\n",
    "    ('implication1.txt', 1),\n",
    "    ('line_constraint1.txt', 3),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29221bc4-4923-401c-9f8f-7893cb3f24a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-program layout: 'xticks' / 'yticks' hold [tick positions, tick labels],\n",
    "# 'legend' holds [draw legend?, legend location].\n",
    "plot_specs = {\n",
    "    'implication1.txt': dict(\n",
    "        xticks=[[0.81, 0.82, 0.83, 0.84, 0.85], ['81', '82', '83', '84', '85']],\n",
    "        yticks=[[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], ['30', '40', '50', '60', '70', '80', '90', '100']],\n",
    "        legend=[True, 'lower left'],\n",
    "    ),\n",
    "    'line_constraint1.txt': dict(\n",
    "        xticks=[[0.74, 0.76, 0.78, 0.8, 0.82, 0.84], ['74', '76', '78', '80', '82', '84']],\n",
    "        yticks=[[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], ['30', '40', '50', '60', '70', '80', '90', '100']],\n",
    "        legend=[False, 'lower left'],\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5794170a-31d5-477e-ae26-fcefa1f35fd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every selected constraint program: load the evaluation and baseline arrays,\n",
    "# print the numbers in LaTeX form and plot the eval metric against each x metric.\n",
    "base_non_dp_path = f'experiment_data/constraint_program_experiments/{dataset}/non_dp_constraints/testing_results/'\n",
    "# position of the selected statistic along the last axis of the loaded arrays\n",
    "stat_idx = statistics_map[statistic_over_queries]\n",
    "\n",
    "for program_name, (eval_metric_name, eval_metric_idx) in program_names_and_eval_metrics_for_thesis.items():\n",
    "\n",
    "    # paths -- all result files of one program share the same prefix\n",
    "    stripped_program_name = program_name.split('.')[0]\n",
    "    run_prefix = base_non_dp_path + f'{stripped_program_name}_all_three_with_labels_{n_samples}_{n_resamples}_{random_seed}_0.0'\n",
    "    eval_save_non_dp_path = run_prefix + '.npy'\n",
    "    baseline_save_non_dp_path = run_prefix + '_baselines.npy'\n",
    "    rejection_sampling_path = run_prefix + '_rejection_sampling.npy'\n",
    "\n",
    "    # skip the program when a result file is missing: without the `continue` the code\n",
    "    # below would hit a NameError or silently reuse arrays from a previous program / cell\n",
    "    if os.path.isfile(eval_save_non_dp_path):\n",
    "        eval_non_dp_data = np.load(eval_save_non_dp_path)\n",
    "    else:\n",
    "        print(f'Eval of {stripped_program_name} under the current settings not found')\n",
    "        continue\n",
    "\n",
    "    if os.path.isfile(baseline_save_non_dp_path):\n",
    "        baseline_non_dp_data = np.load(baseline_save_non_dp_path)\n",
    "    else:\n",
    "        print(f'Baselines of {stripped_program_name} under the current settings not found')\n",
    "        continue\n",
    "\n",
    "    # load the command\n",
    "    load_path = f'experiment_data/constraint_program_experiments/{dataset}/non_dp_constraints/training_constraints/{program_name}'\n",
    "    with open(load_path, 'r') as f:\n",
    "        program = f.read()\n",
    "\n",
    "    print(program)\n",
    "    # keep only the selected statistic, then aggregate over axes 1 and 2\n",
    "    display_data_non_dp = eval_non_dp_data[:, :, :, :, stat_idx]\n",
    "    display_data_non_dp_mean = np.mean(display_data_non_dp, axis=(1, 2))\n",
    "    display_data_non_dp_std = np.std(display_data_non_dp, axis=(1, 2))\n",
    "    # baseline row 0 is plotted as 'True data', row 1 as 'No Fine-tuning'\n",
    "    display_baseline_non_dp = baseline_non_dp_data[:2, :1, :, :, stat_idx]\n",
    "    display_baseline_non_dp_mean = np.mean(display_baseline_non_dp, axis=(1, 2))\n",
    "    display_baseline_non_dp_std = np.std(display_baseline_non_dp, axis=(1, 2))\n",
    "    # row of the evaluation array whose numbers are printed\n",
    "    display_idx = program_names_and_display_indices[program_name]\n",
    "\n",
    "    for x_metric_name, x_metric_idx in x_metrics.items():\n",
    "\n",
    "        plt.figure(figsize=(8, 6))\n",
    "\n",
    "        # only the implication / line constraint programs get a rejection sampling baseline\n",
    "        if program_name.startswith(('implication', 'line')):\n",
    "\n",
    "            rejection_sampling_baseline = np.load(rejection_sampling_path)\n",
    "            # non dp\n",
    "            rsampling = rejection_sampling_baseline[:, :, :, stat_idx]\n",
    "            rsampling_mean = np.mean(rsampling, axis=(0, 1))\n",
    "            rsampling_std = np.std(rsampling, axis=(0, 1))\n",
    "            # the std is scaled by 100 like the mean so both are reported in percent\n",
    "            print(f'\\nRsampling: ${100*rsampling_mean[x_metric_idx]:.1f} \\\\pm {100*rsampling_std[x_metric_idx]:.2f}$')\n",
    "            plt.scatter(rsampling_mean[x_metric_idx], rsampling_mean[eval_metric_idx], c='orange', marker='*', s=300, label='Rejection Sampling, non DP')\n",
    "            plt.scatter(rsampling[:, :, x_metric_idx], rsampling[:, :, eval_metric_idx], c='orange', marker='*', s=300, alpha=0.1)\n",
    "\n",
    "        plt.scatter(display_baseline_non_dp_mean[0, x_metric_idx], display_baseline_non_dp_mean[0, eval_metric_idx], c='red', marker='X', s=200, label='True data')\n",
    "        # non DP data\n",
    "        plt.plot(display_data_non_dp_mean[:, x_metric_idx], display_data_non_dp_mean[:, eval_metric_idx], '--o', markersize=10, c='indigo', label='Fine-tuned')\n",
    "        print('\\nTrue Data')\n",
    "        print(f'{x_metric_name}: ${100*display_baseline_non_dp_mean[0, x_metric_idx]:.1f} \\\\pm {100*display_baseline_non_dp_std[0, x_metric_idx]:.2f}$')\n",
    "        print(f'{eval_metric_name}: ${display_baseline_non_dp_mean[0, eval_metric_idx]:.3f} \\\\pm {display_baseline_non_dp_std[0, eval_metric_idx]:.4f}$')\n",
    "        print('\\nNon-Private ProgSyn Fine-Tuned')\n",
    "        print(f'{x_metric_name}: ${100*display_data_non_dp_mean[display_idx, x_metric_idx]:.1f} \\\\pm {100*display_data_non_dp_std[display_idx, x_metric_idx]:.2f}$')\n",
    "        print(f'{eval_metric_name}: ${display_data_non_dp_mean[display_idx, eval_metric_idx]:.3f} \\\\pm {display_data_non_dp_std[display_idx, eval_metric_idx]:.4f}$')\n",
    "        print('\\nNon-Private ProgSyn Base')\n",
    "        print(f'{x_metric_name}: ${100*display_baseline_non_dp_mean[1, x_metric_idx]:.1f} \\\\pm {100*display_baseline_non_dp_std[1, x_metric_idx]:.2f}$')\n",
    "        print(f'{eval_metric_name}: ${display_baseline_non_dp_mean[1, eval_metric_idx]:.3f} \\\\pm {display_baseline_non_dp_std[1, eval_metric_idx]:.4f}$')\n",
    "\n",
    "        plt.scatter(display_baseline_non_dp_mean[1, x_metric_idx], display_baseline_non_dp_mean[1, eval_metric_idx], c='indigo', marker='X', s=200, label='No Fine-tuning')\n",
    "        # layout setups\n",
    "        spec = plot_specs[program_name]\n",
    "        plt.xlabel(x_metric_name, fontsize=20, labelpad=10)\n",
    "        plt.ylabel(eval_metric_name, fontsize=20, labelpad=10)\n",
    "        plt.grid(True, alpha=0.1)\n",
    "        plt.tick_params(axis='both', length=0)\n",
    "        plt.box(False)\n",
    "        if spec['legend'][0]:\n",
    "            plt.legend(fontsize=20, loc=spec['legend'][1])\n",
    "        plt.tight_layout()\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fba7b6e0-3278-4b78-b05f-aaa3e2f3a7d7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
