{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "27e4326d-e4fe-4a3e-a6cc-cc55025e9676",
   "metadata": {},
   "outputs": [],
   "source": [
    "import bnlearn as bn\n",
    "import os\n",
    "import sys\n",
    "parent_dir = os.path.abspath('../..')\n",
    "sys.path.append(parent_dir)\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check,  transform_label\n",
    "from openai import AzureOpenAI\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "from model_glm import MultiAgentGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7dc98702-9f55-4d3b-ad5c-05556fdbca3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load bnlearn's 'asia' example dataset.\n",
    "data = bn.import_example(data=\"asia\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4ee56d4-0be7-4b67-a235-33cdfe99484b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of rows in the raw dataset.\n",
    "data.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d1f77259-d750-47de-aede-5a786a55033d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional rename: one human-readable name per column, in the original column order.\n",
    "asia_columns = [\n",
    "    \"visit to Asia\", \"tuberculosis\", \"smoke\", \"lung cancer\",\n",
    "    \"bronchitis\", \"either tuberculosis or lung cancer\", \"chest X-ray\", \"Dyspnea\",\n",
    "]\n",
    "data.columns = asia_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5547109c-4743-45fa-b2af-6063ebeeff39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the positional rename above needs exactly one name per column.\n",
    "a = [\"visit to Asia\", \"tuberculosis\", \"smoke\", \"lung cancer\", \"bronchitis\", \"either tuberculosis or lung cancer\", \"chest X-ray\", \"Dyspnea\"]\n",
    "assert len(a) == data.shape[1], f\"{len(a)} names for {data.shape[1]} columns\"\n",
    "len(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4d8e202b-191e-4ca5-8fc6-1c10c9abfcb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out a fixed, seeded 200-row test set.\n",
    "data_test = data.sample(n=200, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6e82e836-9c26-4ea9-9917-86596fc20bdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything not in the test set is the pool the training subsets are drawn from.\n",
    "data_curated = data.drop(index=data_test.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3b224350-7a08-4fd7-b707-9d1baf29a228",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one training subset per size from the non-test pool, each with the same seed.\n",
    "subset_sizes = (100, 200, 400, 800)\n",
    "subsets = {size: data_curated.sample(n=size, replace=False, random_state=42) for size in subset_sizes}\n",
    "data100, data200, data400, data800 = (subsets[size] for size in subset_sizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e773f8e1-d837-4d68-bae0-0ea3acd71ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the sampled row labels and renumber each subset from 0.\n",
    "data100 = data100.reset_index(drop=True)\n",
    "data200 = data200.reset_index(drop=True)\n",
    "data400 = data400.reset_index(drop=True)\n",
    "data800 = data800.reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f1368eb2-5e39-4566-b6e8-d7e1b93509b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the split so later sessions can reload it instead of re-sampling.\n",
    "# Create the folder first: to_csv does not create missing directories.\n",
    "sample_dir = \"../sample/Asia\"\n",
    "os.makedirs(sample_dir, exist_ok=True)\n",
    "data_test.to_csv(f\"{sample_dir}/data_test.csv\", index=False)\n",
    "for size, frame in [(100, data100), (200, data200), (400, data400), (800, data800)]:\n",
    "    frame.to_csv(f\"{sample_dir}/data{size}.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d07316f5-4e05-4fad-9f36-fa44ec5b6872",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the persisted test set and training subsets.\n",
    "sample_dir = \"../sample/Asia\"\n",
    "data_test = pd.read_csv(f\"{sample_dir}/data_test.csv\")\n",
    "data100, data200, data400, data800 = (\n",
    "    pd.read_csv(f\"{sample_dir}/data{size}.csv\") for size in (100, 200, 400, 800)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "595e4f4e-5c41-4dc5-af38-cb6608b53099",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All eight columns; the last one, Dyspnea, is the prediction target.\n",
    "y_col = \"Dyspnea\"\n",
    "cols = [\"visit to Asia\", \"tuberculosis\", \"smoke\", \"lung cancer\", \"bronchitis\", \"either tuberculosis or lung cancer\", \"chest X-ray\", y_col]\n",
    "x_cols = [c for c in cols if c != y_col]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "54f59a0e-c291-4c5a-9428-7eda63185718",
   "metadata": {},
   "outputs": [],
   "source": [
    "api_version = \"2023-05-15\"\n",
    "\n",
    "# Credentials are read from environment variables so they are never stored in the notebook.\n",
    "# Unset variables fall back to \"\", the same placeholder values used before.\n",
    "api_key1 = os.environ.get(\"AZURE_OPENAI_API_KEY_GEN\", \"\")  # generator client key\n",
    "azure_endpoint1 = os.environ.get(\"AZURE_OPENAI_ENDPOINT_GEN\", \"\")  # generator client endpoint\n",
    "\n",
    "api_key2 = os.environ.get(\"AZURE_OPENAI_API_KEY_OPT\", \"\")  # optimizer client key\n",
    "azure_endpoint2 = os.environ.get(\"AZURE_OPENAI_ENDPOINT_OPT\", \"\")  # optimizer client endpoint\n",
    "\n",
    "gen_client = AzureOpenAI(\n",
    "    api_key=api_key1,\n",
    "    api_version=api_version,\n",
    "    azure_endpoint=azure_endpoint1,\n",
    ")\n",
    "\n",
    "opt_client = AzureOpenAI(\n",
    "    api_key=api_key2,\n",
    "    api_version=api_version,\n",
    "    azure_endpoint=azure_endpoint2,\n",
    ")\n",
    "\n",
    "gen_model_nm = 'generator4'\n",
    "opt_model_nm = 'yaobin_gpt4'\n",
    "\n",
    "# Parameters forwarded to MultiAgentGAN (XGBoost parameter names).\n",
    "params = {\n",
    "    'max_depth': 3,  # the maximum depth of each tree\n",
    "    'eta': 0.3,  # the training step for each iteration\n",
    "    'objective': 'binary:logistic',  # logistic regression for binary classification, output probability\n",
    "    'eval_metric': 'logloss'  # evaluation metric for binary classification\n",
    "}\n",
    "# data_desc, num_var, cate_var and bool_var for the Asia data are defined in the cells below;\n",
    "# computing num_var here would reference cate_var/bool_var before they exist."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b4774325-675b-400c-bbab-588ee15585e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset description handed to MultiAgentGAN.\n",
    "data_desc = (\n",
    "    \"The dataset is to describe potential factors that might cause dyspnea. \"\n",
    "    \"All the variables are boolean variables with yes and no.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dedf9f4c-3a15-46fc-8aff-13d1aa857298",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the 0/1 training subsets as 'yes'/'no' strings for the generator.\n",
    "# Values that are already 'yes' stay 'yes', so re-running this cell is harmless\n",
    "# (testing only `x == 1` turned every 'yes' into 'no' on a second run).\n",
    "# data_test is not re-encoded here; the evaluation compares against it in 0/1 form.\n",
    "def _to_yes_no(x):\n",
    "    return 'yes' if x == 1 or x == 'yes' else 'no'\n",
    "\n",
    "data100 = data100.applymap(_to_yes_no)\n",
    "data200 = data200.applymap(_to_yes_no)\n",
    "data400 = data400.applymap(_to_yes_no)\n",
    "data800 = data800.applymap(_to_yes_no)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "852a6634-ab83-4ed0-93d6-d1222c21e1a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every Asia variable is a yes/no flag: no numeric or categorical columns.\n",
    "num_var, cate_var = [], []\n",
    "bool_var = cols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8847d06-9b49-460d-8920-fa5f46d75666",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "res = {}\n",
    "model_dict = {}\n",
    "num_samples = [100, 200, 400, 800]\n",
    "sampled_datasets = [data100, data200, data400, data800]\n",
    "epochs = [5, 4, 3, 2]\n",
    "batch_size = 50\n",
    "seed_lst = list(range(5))\n",
    "df_test = data_test.reset_index(drop=True).copy()\n",
    "\n",
    "# data_test is never mapped to 'yes'/'no' and the synthetic rows are mapped back to 0/1\n",
    "# below, so the MLE evaluation runs on the 0/1 encoding. Its feature layout and test\n",
    "# matrices depend only on df_test, so they are built once, outside the loop.\n",
    "dummy_cols = list(pd.get_dummies(df_test[x_cols]).columns)\n",
    "df_test_dummy = col_check(pd.get_dummies(df_test[x_cols]), dummy_cols)\n",
    "X_test = df_test_dummy.to_numpy()\n",
    "y_test = df_test[y_col].to_numpy()\n",
    "# Profiling must see the same 'yes'/'no' encoding as the training rows; concatenating\n",
    "# the raw 0/1 test rows would mix two encodings in every column.\n",
    "df_test_yes_no = df_test.applymap(lambda x: 'yes' if x == 1 else 'no')\n",
    "os.makedirs('log', exist_ok=True)\n",
    "\n",
    "# NOTE(review): index 0 (100 rows) is skipped here - confirm that is intended.\n",
    "for i in range(1, 4):\n",
    "    num = num_samples[i]\n",
    "    sampled_data = sampled_datasets[i]\n",
    "    epoch = epochs[i]\n",
    "    res[str(num)] = {}\n",
    "    df_comb = pd.concat([sampled_data, df_test_yes_no]).reset_index(drop=True)\n",
    "    metadata = data_profiling(df_comb, cate_var, bool_var, cols)\n",
    "    cate_desc = cate_info(df_comb, cate_var)\n",
    "    log_file = f'log/asia_{num}.txt'\n",
    "    magan = MultiAgentGAN(gen_client, opt_client, gen_model_nm, opt_model_nm, params, sampled_data, cols, y_col, num_var, metadata, cate_desc, data_desc, log_file, opt_temperature=1, real_samples_num=4)\n",
    "    print(magan.gen_temperature)\n",
    "    print(magan.opt_temperature)\n",
    "    magan._run(batch_size, epoch)\n",
    "    magan.run_with_fixed_discriminator(0, epoch, batch_size, 1)\n",
    "    model_dict[str(num)] = magan\n",
    "\n",
    "    gen_dir = f'gen/asia/{num}'\n",
    "    os.makedirs(gen_dir, exist_ok=True)\n",
    "    # `seed` only labels the repetition; it is not passed to the generator.\n",
    "    for seed in seed_lst:\n",
    "        res[str(num)][str(seed)] = {}\n",
    "        mallm_syn_df = magan.gen_without_optimization(num_folds=1)\n",
    "        mallm_syn_df = mallm_syn_df.applymap(lambda x: 1 if x == 'yes' else 0)\n",
    "        mallm_syn_df.to_csv(f'{gen_dir}/df_{seed}.csv', index=False)\n",
    "        # Score every sample here (same procedure as the num-shots loop) rather than\n",
    "        # only the last one in a separate cell.\n",
    "        mallm_syn_df_dummy = col_check(pd.get_dummies(mallm_syn_df[x_cols]), dummy_cols)\n",
    "        X_train = mallm_syn_df_dummy.to_numpy()\n",
    "        y_train = mallm_syn_df[y_col].to_numpy()\n",
    "        res[str(num)][str(seed)]['mallm'] = compare_MLE(X_test, y_test, X_train, y_train, task='classification')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "ffc8d798-486f-4ed7-9329-91d4bc5b7b00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the most recent synthetic sample. The loop above has already mapped it to 0/1,\n",
    "# so accept both encodings: mapping only 'yes' -> 1 here turned every value into 0.\n",
    "mallm_syn_df = mallm_syn_df.applymap(lambda x: 1 if x == 'yes' or x == 1 else 0)\n",
    "mallm_syn_df_dummy = pd.get_dummies(mallm_syn_df[x_cols])\n",
    "mallm_syn_df_dummy = col_check(mallm_syn_df_dummy, dummy_cols)\n",
    "\n",
    "y_train = mallm_syn_df[y_col].to_numpy()\n",
    "X_train = mallm_syn_df_dummy.to_numpy()\n",
    "mallm_performance = compare_MLE(X_test, y_test, X_train, y_train, task='classification')\n",
    "res[str(num)][str(seed)]['mallm'] = mallm_performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "352387f2-4623-4504-8147-9d9565de22ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw nested results: res[size][seed]['mallm'] holds the compare_MLE output.\n",
    "display(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca24c5e2-4c06-48da-b010-3b3db70998f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the latest synthetic sample. It is already 0/1-encoded by the loop above,\n",
    "# so re-applying the 'yes' -> 1 mapping (as this cell used to) would show only zeros.\n",
    "mallm_syn_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "bc14197a-9544-477b-8e43-5304fc269c7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_summary(res):\n",
    "    \"\"\"Summarise MLE results per training size and generator model.\n",
    "\n",
    "    Expects ``res[num][seed][model][eval_model]`` to be a metrics dict with\n",
    "    'accuracy' and 'f1 score' entries (assumed to be the compare_MLE output -\n",
    "    confirm in eval_utils). A 'baseline' entry at the seed level is skipped,\n",
    "    and metric dicts missing either score are ignored.\n",
    "\n",
    "    For every (num, model, seed) the evaluation model with the highest\n",
    "    (accuracy + f1) / 2 is kept; ties keep the first one seen.\n",
    "\n",
    "    Returns:\n",
    "        dict: ``{num: {model: {'highest_avg_mean', 'highest_avg_std',\n",
    "        'eval_model_selected'}}}`` where mean/std are taken across seeds and\n",
    "        ``eval_model_selected`` maps each seed to the chosen evaluation model.\n",
    "    \"\"\"\n",
    "    res_average = {}\n",
    "    for num, results in res.items():\n",
    "        # best[model][seed] = (avg_value, eval_model) of the best-scoring evaluation model\n",
    "        best = {}\n",
    "        for seed, results0 in results.items():\n",
    "            if seed == 'baseline':\n",
    "                continue\n",
    "            for model, results00 in results0.items():\n",
    "                for eval_model, results000 in results00.items():\n",
    "                    accuracy = results000.get('accuracy', None)\n",
    "                    f1_score = results000.get('f1 score', None)\n",
    "                    if accuracy is None or f1_score is None:\n",
    "                        continue\n",
    "                    # Average of both metrics (previously f1 was added to itself).\n",
    "                    avg_value = (accuracy + f1_score) / 2\n",
    "                    by_seed = best.setdefault(model, {})\n",
    "                    if seed not in by_seed or avg_value > by_seed[seed][0]:\n",
    "                        by_seed[seed] = (avg_value, eval_model)\n",
    "\n",
    "        res_average[num] = {}\n",
    "        for model, by_seed in best.items():\n",
    "            scores = [avg_value for avg_value, _ in by_seed.values()]\n",
    "            res_average[num][model] = {\n",
    "                'highest_avg_mean': np.mean(scores),\n",
    "                'highest_avg_std': np.std(scores),\n",
    "                'eval_model_selected': {seed: eval_model for seed, (_, eval_model) in by_seed.items()},\n",
    "            }\n",
    "\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f22eca32-6274-4ff9-a12c-3ebdef6449b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the raw results before summarising them.\n",
    "display(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce3f6f04-6d9d-470f-ae8a-9d42e3a3e894",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per size and generator: mean/std of the best evaluation-model score across seeds.\n",
    "mle_summary_table = mle_summary(res)\n",
    "mle_summary_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04e97d80-aac2-467c-859b-f21f0651a854",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The size loop above starts at index 1, so '100' is only present if it was filled\n",
    "# elsewhere; .get() avoids a KeyError on Restart & Run All.\n",
    "res.get(\"100\", {})"
   ]
  },
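  {
   "cell_type": "markdown",
   "id": "efb00bf9-b774-4260-a5fd-7b94fa1f2591",
   "metadata": {},
   "source": [
    "# Asia: number-of-shots check\n",
    "\n",
    "Re-runs the pipeline on the 100-row training subset while varying the `real_samples_num` argument of `MultiAgentGAN` (the \"shots\") from 1 to 5, and records the MLE score of every generated sample."
   ]
  },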
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5701f0ee-e4ad-475c-9142-05f504813a59",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3f8a382-627f-4dbe-9b47-c9fcb7a92349",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_shots = [1, 2, 3, 4, 5]\n",
    "res = {}\n",
    "model_dict = {}\n",
    "batch_size = 50\n",
    "num = 100\n",
    "epoch = 5\n",
    "df_test = data_test.reset_index(drop=True).copy()\n",
    "\n",
    "# The test set is the same for every shot count, so the evaluation features are built\n",
    "# once. data_test stays 0/1 and the synthetic rows are mapped back to 0/1 below, so the\n",
    "# evaluation runs on the 0/1 encoding.\n",
    "dummy_cols = list(pd.get_dummies(df_test[x_cols]).columns)\n",
    "df_test_dummy = col_check(pd.get_dummies(df_test[x_cols]), dummy_cols)\n",
    "X_test = df_test_dummy.to_numpy()\n",
    "y_test = df_test[y_col].to_numpy()\n",
    "# Profiling must see the same 'yes'/'no' encoding as the training rows.\n",
    "df_test_yes_no = df_test.applymap(lambda x: 'yes' if x == 1 else 'no')\n",
    "\n",
    "gen_dir = f'gen/asia/{num}'\n",
    "os.makedirs('log', exist_ok=True)\n",
    "os.makedirs(gen_dir, exist_ok=True)\n",
    "for n in num_shots:\n",
    "    # Key by shot count: keying by `num` (always 100) overwrote every previous run.\n",
    "    key = f'{n}shots'\n",
    "    res[key] = {}\n",
    "    sampled_data = data100.copy()\n",
    "    df_comb = pd.concat([sampled_data, df_test_yes_no]).reset_index(drop=True)\n",
    "    metadata = data_profiling(df_comb, cate_var, bool_var, cols)\n",
    "    cate_desc = cate_info(df_comb, cate_var)\n",
    "    # One log file per shot count instead of a single file shared by all runs.\n",
    "    log_file = f'log/asia_{num}_{n}shots.txt'\n",
    "    magan = MultiAgentGAN(gen_client, opt_client, gen_model_nm, opt_model_nm, params, sampled_data, cols, y_col, num_var, metadata, cate_desc, data_desc, log_file, opt_temperature=1, real_samples_num=n)\n",
    "    print(magan.gen_temperature)\n",
    "    print(magan.opt_temperature)\n",
    "    magan._run(batch_size, epoch)\n",
    "    magan.run_with_fixed_discriminator(0, epoch, batch_size, 1)\n",
    "    model_dict[key] = magan\n",
    "\n",
    "    for seed in range(5):\n",
    "        res[key][str(seed)] = {}\n",
    "        mallm_syn_df = magan.gen_without_optimization(num_folds=1)\n",
    "        mallm_syn_df.to_csv(f'{gen_dir}/df_{seed}_{n}shots.csv', index=False)\n",
    "        mallm_syn_df = mallm_syn_df.applymap(lambda x: 1 if x == \"yes\" else 0)\n",
    "        mallm_syn_df_dummy = col_check(pd.get_dummies(mallm_syn_df[x_cols]), dummy_cols)\n",
    "        X_train = mallm_syn_df_dummy.to_numpy()\n",
    "        y_train = mallm_syn_df[y_col].to_numpy()\n",
    "        mallm_performance = compare_MLE(X_test, y_test, X_train, y_train, task='classification')\n",
    "        res[key][str(seed)]['mallm'] = mallm_performance\n",
    "        with open(f'log/asia_{n}shots_mle.log', 'a') as f:\n",
    "            f.write(str(mallm_performance) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a33e674-216f-4bd7-9149-6d6e8bb218cb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
