{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1ced8cba-91f1-4760-9871-54ea8a3167e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import bnlearn as bn\n",
    "import os\n",
    "import sys\n",
    "parent_dir = os.path.abspath('../..')\n",
    "sys.path.append(parent_dir)\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check,  transform_label\n",
    "from openai import AzureOpenAI\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "from model_glm import MultiAgentGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e65b78fb-6fb7-47fb-9120-5b3802ba905f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the medical-insurance CSV into `data`.\n",
    "# The path is relative, so it is resolved against the kernel's working directory.\n",
    "data = pd.read_csv(\n",
    "    '../../../data/medical_insurance.csv'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0a0ea2d6-b84e-4811-acef-6e113a53e802",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column roles for the insurance table.\n",
    "x_cols = ['age', 'sex', 'bmi', 'children', 'smoker', 'region']  # predictor columns\n",
    "y_col = 'charges'  # target column\n",
    "cols = x_cols + [y_col]\n",
    "\n",
    "cate_var = ['sex', 'smoker', 'region']  # categorical columns\n",
    "bool_var = []  # no boolean columns are declared for this dataset\n",
    "\n",
    "# Column names produced by one-hot encoding the predictors.\n",
    "dummy_cols = list(pd.get_dummies(data[x_cols]).columns)\n",
    "\n",
    "# Numeric columns, kept in the order they appear in `cols`.\n",
    "# A set difference would yield an order that depends on string hashing and\n",
    "# can therefore change from one kernel start to the next.\n",
    "num_var = [c for c in cols if c not in cate_var]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "92c0d335-4882-4463-90a4-752222739d11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 200 rows for testing; the training subsets are drawn only from the remaining rows.\n",
    "data_test = data.sample(200, random_state=42)\n",
    "data_curated = data.drop(data_test.index)\n",
    "\n",
    "# Training subsets of increasing size, each re-indexed from 0.\n",
    "# reset_index is chained (not inplace) so re-running the cell always rebuilds the frames from data_curated.\n",
    "data100 = data_curated.sample(n=100, replace=False, random_state=42).reset_index(drop=True)\n",
    "data200 = data_curated.sample(n=200, replace=False, random_state=42).reset_index(drop=True)\n",
    "data400 = data_curated.sample(n=400, replace=False, random_state=42).reset_index(drop=True)\n",
    "data800 = data_curated.sample(n=800, replace=False, random_state=42).reset_index(drop=True)\n",
    "\n",
    "# to_csv does not create missing directories, so make sure the target folder exists first.\n",
    "os.makedirs(\"../sample/Insurance\", exist_ok=True)\n",
    "data_test.to_csv(\"../sample/Insurance/data_test.csv\", index=False)\n",
    "data100.to_csv(\"../sample/Insurance/data100.csv\", index=False)\n",
    "data200.to_csv(\"../sample/Insurance/data200.csv\", index=False)\n",
    "data400.to_csv(\"../sample/Insurance/data400.csv\", index=False)\n",
    "data800.to_csv(\"../sample/Insurance/data800.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0c8530ea-6f72-4416-83d1-cccde7c8dcc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "api_version = \"2023-05-15\"\n",
    "\n",
    "# Credentials and endpoints are read from the environment so that secrets are never saved in the notebook.\n",
    "# An unset variable falls back to the empty string.\n",
    "api_key1 = os.environ.get(\"AZURE_OPENAI_API_KEY_GEN\", \"\")  # API key used by gen_client\n",
    "azure_endpoint1 = os.environ.get(\"AZURE_OPENAI_ENDPOINT_GEN\", \"\")  # endpoint used by gen_client\n",
    "\n",
    "api_key2 = os.environ.get(\"AZURE_OPENAI_API_KEY_OPT\", \"\")  # API key used by opt_client\n",
    "azure_endpoint2 = os.environ.get(\"AZURE_OPENAI_ENDPOINT_OPT\", \"\")  # endpoint used by opt_client\n",
    "\n",
    "gen_client = AzureOpenAI(\n",
    "    api_key=api_key1,\n",
    "    api_version=api_version,\n",
    "    azure_endpoint=azure_endpoint1,\n",
    ")\n",
    "\n",
    "opt_client = AzureOpenAI(\n",
    "    api_key=api_key2,\n",
    "    api_version=api_version,\n",
    "    azure_endpoint=azure_endpoint2,\n",
    ")\n",
    "\n",
    "# Model names handed to MultiAgentGAN together with the matching client\n",
    "# (presumably Azure deployment names - TODO confirm).\n",
    "gen_model_nm = 'generator4'\n",
    "opt_model_nm = 'yaobin_gpt4'\n",
    "\n",
    "# Settings passed to MultiAgentGAN as `params`.\n",
    "# NOTE(review): these look like XGBoost options for a binary real-vs-synthetic discriminator - confirm against model_glm.\n",
    "params = {\n",
    "    'max_depth': 3,  # the maximum depth of each tree\n",
    "    'eta': 0.3,  # the training step for each iteration\n",
    "    'objective': 'binary:logistic',  # logistic regression for binary classification, output probability\n",
    "    'eval_metric': 'logloss'  # evaluation metric for binary classification\n",
    "}\n",
    "\n",
    "# Free-text dataset description passed to MultiAgentGAN; it must describe the insurance data loaded above.\n",
    "data_desc = \"The dataset includes people's demographics and health-related factors (age, sex, BMI, number of children, smoking status and region) together with the medical insurance charges billed to each person.\"\n",
    "\n",
    "# Numeric columns in the order of `cols`; a set difference would give a hash-dependent, non-reproducible order.\n",
    "num_var = [c for c in cols if c not in cate_var and c not in bool_var]\n",
    "num_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b3014d6d-de62-422c-8cbd-480658945e4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment grid: entry k of the three lists below describes one training run.\n",
    "num_samples = [100, 200, 400, 800]  # training-set sizes\n",
    "sampled_datasets = [data100, data200, data400, data800]  # matching training frames\n",
    "epochs = [5, 4, 3, 2]  # epochs per run, indexed like num_samples\n",
    "batch_size = 50\n",
    "seed_lst = [0, 1, 2, 3, 4]\n",
    "\n",
    "# Containers filled by the training loop in the next cell.\n",
    "res = dict()\n",
    "model_dict = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a25c7b30-b986-4096-8436-b16e0cea04af",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train one MultiAgentGAN per training-set size, then write one synthetic CSV per entry of seed_lst.\n",
    "# NOTE(review): the range starts at 1, so the 100-row configuration (index 0) is skipped - confirm this is intended.\n",
    "for i in range(1, 4):\n",
    "    num = num_samples[i]\n",
    "    sampled_data = sampled_datasets[i]\n",
    "    epoch = epochs[i]\n",
    "    res[str(num)] = {}\n",
    "\n",
    "    # Profile the training rows together with the test rows so both are reflected in the metadata.\n",
    "    df_comb = pd.concat([sampled_data, data_test]).reset_index(drop=True)\n",
    "    metadata = data_profiling(df_comb, cate_var, bool_var, cols)\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "    cate_desc = cate_info(df_comb, cate_var)\n",
    "\n",
    "    # Create the output locations before training: a missing directory would otherwise\n",
    "    # only surface when to_csv runs, after the expensive model calls have finished.\n",
    "    log_file = f'log/insurance_{num}.txt'\n",
    "    out_dir = f'gen/Insurance/{num}'\n",
    "    os.makedirs('log', exist_ok=True)\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    magan = MultiAgentGAN(gen_client, opt_client, gen_model_nm, opt_model_nm, params, sampled_data, cols, y_col, num_var, metadata, cate_desc, data_desc, log_file, opt_temperature=1, real_samples_num=4)\n",
    "    magan._run(batch_size, epoch)\n",
    "    magan.run_with_fixed_discriminator(0, epoch, batch_size, 1)\n",
    "    model_dict[str(num)] = magan\n",
    "\n",
    "    # `seed` only labels the result entry and the output file; it is not passed to the generator.\n",
    "    for seed in seed_lst:\n",
    "        res[str(num)][str(seed)] = {}\n",
    "        mallm_syn_df = magan.gen_without_optimization(num_folds=1)\n",
    "        mallm_syn_df.to_csv(f'{out_dir}/df_{seed}.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dd33709-1fcb-496b-891c-f457576ddddd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
