{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "42b718e6-bd3d-46dc-832e-1d8968d18f55",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "parent_dir = os.path.abspath('../..')\n",
    "sys.path.append(parent_dir)\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check,  transform_label\n",
    "from openai import AzureOpenAI\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "from model_glm import MultiAgentGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fd174354-d52e-435a-895e-149c4e4bb862",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the imputed source table (path is relative to the kernel's working directory).\n",
    "data = pd.read_csv(filepath_or_buffer='../../../data/df_imputed.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3e819b3a-c271-4530-9617-e073fbef772a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the unnamed leading column (presumably a saved index). errors='ignore' keeps the\n",
    "# cell re-runnable: a second run no longer raises KeyError once the column is gone.\n",
    "data = data.drop(columns=[\"Unnamed: 0\"], errors=\"ignore\")\n",
    "# Merge treatment arm 2 into arm 1 (presumably to binarize `treatment`, which is listed\n",
    "# among the boolean columns later on).\n",
    "data.loc[data[\"treatment\"] == 2, \"treatment\"] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "56e6b60b-18d9-4e43-873d-d97f054d731a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename all columns positionally (pandas raises if the count does not match).\n",
    "# NOTE(review): 'Myocardial fraction' is used as a boolean comorbidity below and may be meant\n",
    "# as 'Myocardial infarction'; 'CD' is ambiguous. Confirm against the source data dictionary.\n",
    "data.columns = [\n",
    "    'treatment', 'age', 'mRS score after 90 days',\n",
    "    'mRS score after 30 days', 'adverse events', 'ICH volume',\n",
    "    'ICH Location', 'IVH volume', 'GCS score', 'NIHSS score',\n",
    "    'Systolic blood pressure', 'Diastolic Blood Pressure',\n",
    "    'Hypertension', 'Hyperlipidemia',\n",
    "    'Type I Diabetes', 'Type II Diabetes',\n",
    "    'Congestive heart failure', 'Atrial Fibrillation',\n",
    "    'PTCA', 'Peripheral Vascular Disease',\n",
    "    'Myocardial fraction', 'Anti-diabetic',\n",
    "    'Antihypertensives', 'White blood count', 'Hemoglobin',\n",
    "    'Hematocrit', 'Platelet count', 'APTT', 'INR', 'Glucose', 'Sodium',\n",
    "    'Potassium', 'Chloride', 'CD', 'Blood urea nitrogen', 'Creatinine',\n",
    "    'race', 'sex', 'ethnicity',\n",
    "]\n",
    "\n",
    "# Target is the 30-day mRS score; features are every other column except the 90-day mRS\n",
    "# score and adverse events, kept in the original column order.\n",
    "y_col = 'mRS score after 30 days'\n",
    "x_cols = [\n",
    "    c for c in data.columns\n",
    "    if c not in ('mRS score after 90 days', y_col, 'adverse events')\n",
    "]\n",
    "cols = x_cols + [y_col]\n",
    "\n",
    "cate_var = ['race', 'ICH Location', 'sex', 'ethnicity']\n",
    "bool_var = [\n",
    "    'treatment', 'Hypertension', 'Hyperlipidemia',\n",
    "    'Type I Diabetes', 'Type II Diabetes',\n",
    "    'Congestive heart failure', 'Atrial Fibrillation',\n",
    "    'PTCA', 'Peripheral Vascular Disease',\n",
    "    'Myocardial fraction', 'Anti-diabetic',\n",
    "    'Antihypertensives',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4a92471-3ea0-45bd-aeb9-7c734882698e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: number of model columns that are neither categorical nor boolean.\n",
    "len(set(cols).difference(cate_var, bool_var))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8687ec52-b8a3-410a-8011-b45bfc97a289",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = data[cols].copy()\n",
    "\n",
    "# Hold out a fixed 200-row test set; training subsets are drawn only from the remaining rows.\n",
    "data_test = data.sample(n=200, random_state=42)\n",
    "data_curated = data.drop(data_test.index)\n",
    "\n",
    "# NOTE(review): every subset uses random_state=42, so they are presumably nested\n",
    "# (each smaller sample a prefix of the larger one); confirm this is intended.\n",
    "data100 = data_curated.sample(n=100, replace=False, random_state=42).reset_index(drop=True)\n",
    "data200 = data_curated.sample(n=200, replace=False, random_state=42).reset_index(drop=True)\n",
    "data400 = data_curated.sample(n=400, replace=False, random_state=42).reset_index(drop=True)\n",
    "data800 = data_curated.sample(n=800, replace=False, random_state=42).reset_index(drop=True)\n",
    "\n",
    "# to_csv does not create missing directories, so make sure the output folder exists first.\n",
    "sample_dir = '../sample/ATACH'\n",
    "os.makedirs(sample_dir, exist_ok=True)\n",
    "for split_name, split_df in [('data_test', data_test), ('data100', data100),\n",
    "                             ('data200', data200), ('data400', data400),\n",
    "                             ('data800', data800)]:\n",
    "    split_df.to_csv(f'{sample_dir}/{split_name}.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0ab7977d-9083-425e-95d7-d750860a7990",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result and fitted-model containers, keyed by str(training size).\n",
    "res = {}\n",
    "model_dict = {}\n",
    "\n",
    "# Experiment grid: the i-th training size uses the i-th pre-drawn subset and i-th epoch count.\n",
    "num_samples = [100, 200, 400, 800]\n",
    "sampled_datasets = [data100, data200, data400, data800]\n",
    "epochs = [5, 4, 3, 2]\n",
    "batch_size = 50\n",
    "seed_lst = [*range(5)]  # seeds 0..4\n",
    "\n",
    "# Hold-out test set re-indexed to 0..n-1.\n",
    "df_test = data_test.copy().reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ced63817-add7-4dae-872e-2e8350defda1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Azure OpenAI credentials come from the environment instead of being hard-coded;\n",
    "# an unset variable falls back to an empty string (the previous placeholder value).\n",
    "api_version = \"2023-05-15\"\n",
    "\n",
    "api_key1 = os.environ.get('AZURE_OPENAI_GEN_API_KEY', '')  # generator client key\n",
    "azure_endpoint1 = os.environ.get('AZURE_OPENAI_GEN_ENDPOINT', '')  # generator client endpoint\n",
    "\n",
    "api_key2 = os.environ.get('AZURE_OPENAI_OPT_API_KEY', '')  # optimizer client key\n",
    "azure_endpoint2 = os.environ.get('AZURE_OPENAI_OPT_ENDPOINT', '')  # optimizer client endpoint\n",
    "\n",
    "gen_client = AzureOpenAI(\n",
    "    api_key=api_key1,\n",
    "    api_version=api_version,\n",
    "    azure_endpoint=azure_endpoint1,\n",
    ")\n",
    "\n",
    "opt_client = AzureOpenAI(\n",
    "    api_key=api_key2,\n",
    "    api_version=api_version,\n",
    "    azure_endpoint=azure_endpoint2,\n",
    ")\n",
    "\n",
    "# Model names for the generator / optimizer clients (presumably Azure deployment names).\n",
    "gen_model_nm = 'generator4'\n",
    "opt_model_nm = 'yaobin_gpt4'\n",
    "\n",
    "# NOTE(review): binary objective even though the MLE task below is regression; presumably\n",
    "# these configure the GAN's real-vs-synthetic discriminator. Confirm in model_glm.\n",
    "params = {\n",
    "    'max_depth': 3,  # the maximum depth of each tree\n",
    "    'eta': 0.3,  # the training step for each iteration\n",
    "    'objective': 'binary:logistic',  # logistic regression for binary classification, output probability\n",
    "    'eval_metric': 'logloss'  # evaluation metric for binary classification\n",
    "}\n",
    "# Dataset description handed to MultiAgentGAN (presumably used in its LLM prompts).\n",
    "# Previously this described an income (>50k) dataset, which does not match these columns.\n",
    "data_desc = \"The dataset includes treatment assignment, demographics, comorbidities, medications, blood pressure, hemorrhage volumes and location, neurological scores (GCS, NIHSS) and laboratory values of patients with intracerebral hemorrhage (ICH), with the label being the modified Rankin Scale (mRS) score after 30 days.\"\n",
    "\n",
    "# Numeric columns = everything neither categorical nor boolean (this includes y_col).\n",
    "# A comprehension keeps `cols` order; a set difference would reorder between kernel\n",
    "# sessions because string hashing is randomized per process.\n",
    "num_var = [c for c in cols if c not in cate_var and c not in bool_var]\n",
    "num_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78c0303f-3a8c-4aa8-be49-5a5c9eee12d8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# For each training size: fit the MultiAgentGAN generator, then draw one synthetic dataset per\n",
    "# seed and score it with compare_MLE (train on synthetic, evaluate on the real hold-out set).\n",
    "res = {}\n",
    "os.makedirs('log', exist_ok=True)  # log files below are opened for writing\n",
    "# NOTE(review): the [1:4] slices skip the 100-row setting; confirm this is intended.\n",
    "for num, sampled_data, epoch in zip(num_samples[1:4], sampled_datasets[1:4], epochs[1:4]):\n",
    "    key = str(num)\n",
    "    res[key] = {}\n",
    "    gen_dir = f'gen/atach/{num}'\n",
    "    os.makedirs(gen_dir, exist_ok=True)  # to_csv does not create missing directories\n",
    "\n",
    "    # NOTE(review): metadata and cate_desc are profiled on train + hold-out test rows, so\n",
    "    # test-set statistics may reach the generator; confirm this is intended. (Using the union\n",
    "    # for dummy_cols is reasonable, as it presumably only aligns one-hot columns.)\n",
    "    df_comb = pd.concat([sampled_data, df_test]).reset_index(drop=True)\n",
    "    metadata = data_profiling(df_comb, cate_var, bool_var, cols)\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "    cate_desc = cate_info(df_comb, cate_var)\n",
    "\n",
    "    df_test_dummy = col_check(pd.get_dummies(df_test[x_cols]), dummy_cols)\n",
    "    X_test = df_test_dummy.to_numpy()\n",
    "    y_test = df_test[y_col].to_numpy()\n",
    "\n",
    "    log_file = f'log/atach_{num}.txt'\n",
    "    magan = MultiAgentGAN(gen_client, opt_client, gen_model_nm, opt_model_nm, params, sampled_data, cols, y_col, num_var, metadata, cate_desc, data_desc, log_file, opt_temperature=1, real_samples_num=5)\n",
    "    print(magan.gen_temperature)\n",
    "    print(magan.opt_temperature)\n",
    "    magan._run(batch_size, epoch)\n",
    "    magan.run_with_fixed_discriminator(0, epoch, batch_size, 1)\n",
    "    model_dict[key] = magan\n",
    "\n",
    "    for seed in seed_lst:\n",
    "        # Previously `seed` was only a label. Seeding numpy's global RNG makes any numpy-based\n",
    "        # randomness downstream reproducible per seed; LLM sampling is presumably unaffected.\n",
    "        np.random.seed(seed)\n",
    "        res[key][str(seed)] = {}\n",
    "        mallm_syn_df = magan.gen_without_optimization(num_folds=1)\n",
    "        mallm_syn_df.to_csv(f'{gen_dir}/df_{seed}.csv', index=False)\n",
    "        mallm_syn_df_dummy = col_check(pd.get_dummies(mallm_syn_df[x_cols]), dummy_cols)\n",
    "\n",
    "        y_train = mallm_syn_df[y_col].to_numpy()\n",
    "        X_train = mallm_syn_df_dummy.to_numpy()\n",
    "        mallm_performance = compare_MLE(X_test, y_test, X_train, y_train, task='regression')\n",
    "        res[key][str(seed)]['mallm'] = mallm_performance\n",
    "        with open(f'log/atach_{num}_mle.log', 'a') as f:\n",
    "            f.write(str(mallm_performance) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9a78fef8-9f22-4e84-b043-5149ad98f045",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _r2_beats(candidate, current):\n",
    "    \"\"\"Return True if R2 `candidate` should replace the running best `current`.\n",
    "\n",
    "    NaN ranks below every real value, so a NaN from one evaluator never masks a valid\n",
    "    score from another (a plain `>` is always False when NaN is involved).\n",
    "    \"\"\"\n",
    "    if current != current:  # NaN is the only value that is unequal to itself\n",
    "        return candidate == candidate\n",
    "    return candidate > current\n",
    "\n",
    "\n",
    "def mle_summary(res):\n",
    "    \"\"\"Summarize the best per-seed R2 for every training size and generator.\n",
    "\n",
    "    Args:\n",
    "        res: nested dict res[num][seed][model][eval_model] -> metrics dict. Only metrics\n",
    "            dicts with a non-None 'R2' entry contribute; the 'baseline' seed key is skipped.\n",
    "\n",
    "    Returns:\n",
    "        dict res_average[num][model] with\n",
    "          'highest_R2_mean' / 'highest_R2_std': mean and population std (ddof=0) over\n",
    "              seeds of each seed's best R2 across evaluation models;\n",
    "          'eval_model_selected': {seed: eval_model} that produced each seed's best R2.\n",
    "        A num without any R2 value maps to an empty dict.\n",
    "    \"\"\"\n",
    "    res_average = {}\n",
    "    for num, results in res.items():\n",
    "        # best[model][seed] = (r2_value, eval_model); the first maximum wins on ties.\n",
    "        best = {}\n",
    "        for seed, results0 in results.items():\n",
    "            if seed == 'baseline':\n",
    "                continue\n",
    "            for model, results00 in results0.items():\n",
    "                for eval_model, results000 in results00.items():\n",
    "                    r2_value = results000.get('R2', None)\n",
    "                    if r2_value is None:\n",
    "                        continue\n",
    "                    by_seed = best.setdefault(model, {})\n",
    "                    if seed not in by_seed or _r2_beats(r2_value, by_seed[seed][0]):\n",
    "                        by_seed[seed] = (r2_value, eval_model)\n",
    "\n",
    "        res_average[num] = {}\n",
    "        for model, by_seed in best.items():\n",
    "            best_values = [r2 for r2, _ in by_seed.values()]\n",
    "            res_average[num][model] = {\n",
    "                'highest_R2_mean': np.mean(best_values),\n",
    "                'highest_R2_std': np.std(best_values),\n",
    "                'eval_model_selected': {seed: eval_model for seed, (_, eval_model) in by_seed.items()},\n",
    "            }\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b70b6cae-6caf-4ae4-8639-a0f3168fee67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per training size: mean/std over seeds of the best R2 across evaluation models.\n",
    "mle_summary(res=res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6f7cf66-0bdc-488e-9a40-574c12f79b59",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
