{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ac51f8c-5abe-4d67-b6b9-506ab713d998",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tabular_datasets import HealthHeritage\n",
    "from constraints import ConstraintEvaluator\n",
    "from query import query_marginal\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from programmable_synthesizer import ProgrammableSynthesizer\n",
    "from utils import evaluate_sampled_dataset "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81bc5c81-d98b-46cb-8d75-6865dc8adc4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the NumPy and PyTorch global RNGs with the same fixed value (42)\n",
    "# so that repeated runs start from the same random state.\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb03dcca-0428-4fe9-a261-dab6253fd8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the HealthHeritage dataset. CUDA is used whenever PyTorch can\n",
    "# see a GPU (same behaviour as before); otherwise fall back to the CPU rather\n",
    "# than failing on a machine without CUDA.\n",
    "# NOTE(review): assumes HealthHeritage accepts device='cpu' - confirm in tabular_datasets.\n",
    "dataset = HealthHeritage(\n",
    "    device='cuda' if torch.cuda.is_available() else 'cpu'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f43ef36b-4911-45ae-aec2-1cc43db3e28d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the training data through get_Dtrain_full_one_hot; return_torch=True\n",
    "# presumably requests torch tensors rather than NumPy arrays (verify in tabular_datasets).\n",
    "full_one_hot_train = dataset.get_Dtrain_full_one_hot(\n",
    "    return_torch=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a0f8b1-6d1a-44be-bc2f-9672785e0fb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# feature_name: the column referenced in the constraint program and in the\n",
    "# marginal queries further down.\n",
    "# param: the number substituted after PARAM in the constraint program.\n",
    "feature_name = 'PrimaryConditionGroup'\n",
    "param = 8e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21010f67-3691-4329-94a8-2cd6e9334b71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constraint program text for the synthesizer: one MAXIMIZE / STATISTICAL\n",
    "# clause over H[feature_name], with param interpolated after PARAM.\n",
    "# NOTE(review): a float such as 0.00008 is rendered as 8e-05 by the f-string;\n",
    "# confirm that the program parser accepts scientific notation.\n",
    "program = f\"\"\"\n",
    "SYNTHESIZE: HealthHeritage;\n",
    "    MAXIMIZE: STATISTICAL: PARAM {param}:\n",
    "        H[{feature_name}];\n",
    "END;\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9addff9-6c55-4191-9cce-5bbbab7627c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the synthesizer from the constraint program text and the workload\n",
    "# identified by the name 'all_three_with_labels'.\n",
    "progsyn = ProgrammableSynthesizer(constraint_program=program, workload='all_three_with_labels')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "701683ad-8602-4ca8-8f2e-d557e6521d65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the synthesizer with verbose output switched on.\n",
    "progsyn.fit(verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2da93a28-817c-4b61-b696-d3a3242ba35d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the fitted synthesizer for len(dataset) synthetic records.\n",
    "synthetic_data = progsyn.generate_data(\n",
    "    len(dataset)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38722f73-ada6-48c8-a359-b624dd263222",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-way marginal over feature_name, queried on the synthetic data.\n",
    "primaryconditiongroup_synth_marginal = query_marginal(\n",
    "    synthetic_data,\n",
    "    (feature_name,),\n",
    "    dataset.full_one_hot_index_map,\n",
    "    input_torch=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df5b1160-b268-489b-8c8d-176453122f54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The same one-way marginal over feature_name, queried on the training data.\n",
    "primaryconditiongroup_marginal = query_marginal(\n",
    "    full_one_hot_train,\n",
    "    (feature_name,),\n",
    "    dataset.full_one_hot_index_map,\n",
    "    input_torch=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "755c9d03-bb83-41ba-aed0-199d2d4c2f9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the synthetic data against the single-feature workload, supplying the\n",
    "# marginal measured on the training data as the reference. Six values come\n",
    "# back; only the fourth (xgbacc) is kept.\n",
    "_, _, _, xgbacc, _, _ = evaluate_sampled_dataset(\n",
    "    synthetic_dataset=synthetic_data,\n",
    "    workload=[(feature_name,)],\n",
    "    true_measured_workload={(feature_name,): primaryconditiongroup_marginal},\n",
    "    dataset=dataset,\n",
    "    max_slice=1000,\n",
    "    random_seed=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee33b6e4-db7f-4889-b66d-8bc7436ac3e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the first entry of xgbacc, scaled by 100 and shown with one decimal place.\n",
    "print(f'XGB accuracy: {100 * xgbacc[0]:.1f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6c8b326-f613-4e5a-8b0f-c1576588b81e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grouped bar chart: original vs. synthetic marginal of feature_name.\n",
    "fig, ax = plt.subplots(figsize=(12, 8))\n",
    "\n",
    "# One bar group per entry of dataset.features[feature_name]; computed once and\n",
    "# shifted left/right by half a bar width for the two series.\n",
    "bar_positions = np.arange(len(dataset.features[feature_name]))\n",
    "bar_width = 0.4\n",
    "ax.bar(bar_positions - bar_width / 2, primaryconditiongroup_marginal.cpu().numpy(), width=bar_width, color='pink', alpha=1, label='Original')\n",
    "ax.bar(bar_positions + bar_width / 2, primaryconditiongroup_synth_marginal.cpu().numpy(), width=bar_width, color='indigo', alpha=1, label='ProgSyn')\n",
    "\n",
    "# The frame and all tick marks are hidden, so no tick-label rotation is needed.\n",
    "ax.set_frame_on(False)\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "\n",
    "ax.legend(fontsize=35, loc='upper left')\n",
    "ax.set_xlabel('Patient Condition Distribution', fontsize=50, labelpad=15)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d6166bb-cb5c-404a-9652-842aa00824e9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
