{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e06a2105-628f-43f1-9f13-1267ecf17b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "parent_dir = os.path.abspath('../..')\n",
    "sys.path.append(parent_dir)\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check,  transform_label\n",
    "from openai import AzureOpenAI\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "from model_glm import MultiAgentGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d70b0fb1-a4db-4a0a-b3b6-5e39009876c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv('../../../data/erich.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5b6491f8-9f04-4042-86a5-cf17bca5ee70",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.drop(columns=['ID'], inplace=True)\n",
    "data = data.dropna()\n",
    "data.reset_index(inplace=True, drop=True)\n",
    "data.rename(columns={'OUTCOME_mRS90': 'mRS score after 90 days'}, inplace=True)\n",
    "x_cols = ['sex', 'age', 'race', 'ICH volume', 'IVH volume',\n",
    "       'ICH location', 'GCS score', 'Systolic blood pressure',\n",
    "       'Diastolic blood pressure', 'Hypertension', 'Hyperlipidemia',\n",
    "       'Type 2 diabetes', 'Type 1 diabetes', 'Atrial Fibrillation',\n",
    "       'Congestive Heart Failure', 'Vascular disease', 'Cocaine use', 'Smoke',\n",
    "       'PTCA', 'Myocardial infraction', 'anti-diabetic', 'Anti-hypertensives',\n",
    "       'White blood cell count', 'Hemoglobin', 'Platelet count', 'APTT', 'INR',\n",
    "       'Glucose']\n",
    "y_col = 'mRS score after 90 days'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1a8bf841-315c-47c5-b3f0-f25ec77dfdcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "cols = x_cols + [y_col]\n",
    "cate_var = ['race', 'sex', 'ICH location']\n",
    "bool_var = ['Hypertension', 'Hyperlipidemia',\n",
    "       'Type 2 diabetes', 'Type 1 diabetes', 'Atrial Fibrillation',\n",
    "       'Congestive Heart Failure', 'Vascular disease', 'Cocaine use', 'Smoke',\n",
    "       'PTCA', 'Myocardial infraction', 'anti-diabetic', 'Anti-hypertensives']\n",
    "num_var = list((set(cols) - set(cate_var)) - set(bool_var))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d7254f09-dde3-49e2-bbcb-718b27590815",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = data[cols].copy()\n",
    "\n",
    "data_test = data.sample(n=200, random_state=42)\n",
    "data_curated = data.drop(data_test.index)\n",
    "data100 = data_curated.sample(n = 100, replace=False, random_state=42)\n",
    "data200 = data_curated.sample(n = 200, replace=False, random_state=42)\n",
    "data400 = data_curated.sample(n = 400, replace=False, random_state=42)\n",
    "data800 = data_curated.sample(n = 800, replace=False, random_state=42)\n",
    "data100.reset_index(inplace=True, drop=True)\n",
    "data200.reset_index(inplace=True, drop=True)\n",
    "data400.reset_index(inplace=True, drop=True)\n",
    "data800.reset_index(inplace=True, drop=True)\n",
    "\n",
    "data_test.to_csv(\"../sample/ERICH/data_test.csv\", index=False)\n",
    "data100.to_csv(\"../sample/ERICH/data100.csv\", index=False)\n",
    "data200.to_csv(\"../sample/ERICH/data200.csv\", index=False)\n",
    "data400.to_csv(\"../sample/ERICH/data400.csv\", index=False)\n",
    "data800.to_csv(\"../sample/ERICH/data800.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "055bce1e-de2e-4844-a56d-47aa9fda5b3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "api_version = \"2023-05-15\"\n",
    "\n",
    "api_key1 = \"\"# API KEY\n",
    "azure_endpoint1 = \"\"# END point\n",
    "\n",
    "api_key2 = \"\"# API KEY\n",
    "azure_endpoint2 = \"\" # END POINT\n",
    "\n",
    "gen_client = AzureOpenAI(\n",
    "            api_key = api_key1,\n",
    "            api_version = api_version,\n",
    "            azure_endpoint = azure_endpoint1\n",
    "        )\n",
    "\n",
    "opt_client = AzureOpenAI(\n",
    "    api_key = api_key2,\n",
    "    api_version = api_version, \n",
    "    azure_endpoint = azure_endpoint2\n",
    ")\n",
    "\n",
    "gen_model_nm = 'generator4'\n",
    "\n",
    "opt_model_nm = 'yaobin_gpt4'\n",
    "\n",
    "params = {\n",
    "    'max_depth': 3,  # the maximum depth of each tree\n",
    "    'eta': 0.3,  # the training step for each iteration\n",
    "    'objective': 'binary:logistic',  # logistic regression for binary classification, output probability\n",
    "    'eval_metric': 'logloss'  # evaluation metric for binary classification\n",
    "}\n",
    "data_desc = \"The dataset include people's social economic factors and demographics with the label that indicates whether their income is higher than 50k.\"\n",
    "num_var = list((set(cols) - set(cate_var)) - set(bool_var))\n",
    "num_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "359d3058-129e-49fa-aa4c-f922feae1443",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_desc = \"The data is from a case-control study of Intracerebral Hemorrhage study which aims to investigate in the Ethnic/Racial variations.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "22aaca1d-cf42-4f84-96e0-7d3d4161a872",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = {}  \n",
    "model_dict = {}\n",
    "num_samples = [100, 200, 400, 800]\n",
    "sampled_datasets = [data100, data200, data400, data800]\n",
    "epochs = [5, 4, 3, 2]\n",
    "batch_size = 50\n",
    "seed_lst = list(range(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63c32fbd-9adf-4b2d-96f9-ea55785214a0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "df_test = pd.read_csv(\"../sample/ERICH/data_test.csv\")\n",
    "for i in range(1, 2):\n",
    "    num = num_samples[i]\n",
    "    sampled_data = sampled_datasets[i]\n",
    "    epoch = epochs[i]\n",
    "    res[str(num)] = {}\n",
    "    df_comb = pd.concat([sampled_data, df_test])\n",
    "    df_comb.reset_index(inplace=True, drop = True)\n",
    "    metadata = data_profiling(df_comb, cate_var, bool_var, cols)\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "    cate_desc = cate_info(df_comb, cate_var)\n",
    "    df_test_dummy = pd.get_dummies(df_test[x_cols])\n",
    "    df_test_dummy = col_check(df_test_dummy, dummy_cols)\n",
    "    log_file = f'log/erich_{num}.txt'\n",
    "    magan = MultiAgentGAN(gen_client, opt_client, gen_model_nm, opt_model_nm, params, sampled_data, cols, y_col, num_var, metadata, cate_desc, data_desc, log_file, opt_temperature=1, real_samples_num=1)\n",
    "    magan._run(batch_size, epoch)\n",
    "    magan.run_with_fixed_discriminator(0, epoch, batch_size, 1)\n",
    "    model_dict[str(num)] = magan\n",
    "    \n",
    "    for j in range(len(seed_lst)):\n",
    "        seed = seed_lst[j]\n",
    "        res[str(num)][str(seed)] = {}\n",
    "        mallm_syn_df = magan.gen_without_optimization(num_folds=1)\n",
    "        mallm_syn_df.to_csv(f'gen/erich/{num}/df_{seed}.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cc761c0-0d69-4ee9-ad94-22ffb5f52880",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
