{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cea4674e-219e-4884-92ec-133d3d649779",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "parent_dir = os.path.abspath('../..')\n",
    "sys.path.append(parent_dir)\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check,  transform_label\n",
    "from openai import AzureOpenAI\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "from model_glm import MultiAgentGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ecdacd4c-8dce-4ae2-9a35-2a94bc1b86d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv('../../../data/adult_imputed.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "de489592-d5c7-4264-922d-82dd7eaf7a42",
   "metadata": {},
   "outputs": [],
   "source": [
    "data100 = pd.read_csv(\"../sample/Adult/data100.csv\")\n",
    "data200 = pd.read_csv(\"../sample/Adult/data200.csv\")\n",
    "data400 = pd.read_csv(\"../sample/Adult/data400.csv\")\n",
    "data800 = pd.read_csv(\"../sample/Adult/data800.csv\")\n",
    "data_test = pd.read_csv(\"../sample/Adult/data_test.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97a19a09-43a9-4b07-85a2-6af4646225dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d234256-19ac-49ad-9803-3bb83c40c005",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "de8b7743-f48b-4475-b0b6-61d736654582",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.drop(columns=['Unnamed: 0'], inplace=True)\n",
    "data = data.rename(columns = {'label': 'Income'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "97a184ce-e7fc-4dff-876a-2f74ddc72f93",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_cols = ['age', 'workclass', 'education', 'education-num',\n",
    "          'marital-status', 'occupation', 'relationship', 'race', 'sex',\n",
    "          'capital-gain', 'capital-loss', 'hours-per-week', 'native-country']\n",
    "y_col = 'Income'\n",
    "cate_var = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'Income']\n",
    "bool_var = []\n",
    "cols = x_cols + [y_col]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59db1e97-c9da-4f1e-8ac4-f8e0f8760088",
   "metadata": {},
   "outputs": [],
   "source": [
    "api_version = \"2023-05-15\"\n",
    "\n",
    "api_key1 = \"\"# API KEY\n",
    "azure_endpoint1 = \"\"# END point\n",
    "\n",
    "api_key2 = \"\"# API KEY\n",
    "azure_endpoint2 = \"\" # END POINT\n",
    "\n",
    "gen_client = AzureOpenAI(\n",
    "            api_key = api_key1,\n",
    "            api_version = api_version,\n",
    "            azure_endpoint = azure_endpoint1\n",
    "        )\n",
    "\n",
    "opt_client = AzureOpenAI(\n",
    "    api_key = api_key2,\n",
    "    api_version = api_version, \n",
    "    azure_endpoint = azure_endpoint2\n",
    ")\n",
    "\n",
    "gen_model_nm = 'generator4'\n",
    "\n",
    "opt_model_nm = 'yaobin_gpt4'\n",
    "\n",
    "params = {\n",
    "    'max_depth': 3,  # the maximum depth of each tree\n",
    "    'eta': 0.3,  # the training step for each iteration\n",
    "    'objective': 'binary:logistic',  # logistic regression for binary classification, output probability\n",
    "    'eval_metric': 'logloss'  # evaluation metric for binary classification\n",
    "}\n",
    "data_desc = \"The dataset include people's social economic factors and demographics with the label that indicates whether their income is higher than 50k.\"\n",
    "num_var = list((set(cols) - set(cate_var)) - set(bool_var))\n",
    "num_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0d02ef27-8133-46e2-a787-95b1abb5e7a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = data[cols].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b29b4eed-d7d4-4166-a27b-bd6b7c9c7541",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_test = data.sample(200, random_state=42)\n",
    "data_curated = data.drop(data_test.index)\n",
    "data100 = data_curated.sample(n = 100, replace=False, random_state=42)\n",
    "data200 = data_curated.sample(n = 200, replace=False, random_state=42)\n",
    "data400 = data_curated.sample(n = 400, replace=False, random_state=42)\n",
    "data800 = data_curated.sample(n = 800, replace=False, random_state=42)\n",
    "print(set(data100.index).intersection(set(data_test.index)))\n",
    "data100.reset_index(inplace=True, drop=True)\n",
    "data200.reset_index(inplace=True, drop=True)\n",
    "data400.reset_index(inplace=True, drop=True)\n",
    "data800.reset_index(inplace=True, drop=True)\n",
    "\n",
    "data_test.to_csv(\"../sample/Adult/data_test.csv\", index=False)\n",
    "data100.to_csv(\"../sample/Adult/data100.csv\", index=False)\n",
    "data200.to_csv(\"../sample/Adult/data200.csv\", index=False)\n",
    "data400.to_csv(\"../sample/Adult/data400.csv\", index=False)\n",
    "data800.to_csv(\"../sample/Adult/data800.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2576ecc3-f574-4afa-ab07-e3f8003d7d02",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = {}  \n",
    "model_dict = {}\n",
    "num_samples = [100, 200, 400, 800]\n",
    "sampled_datasets = [data100, data200, data400, data800]\n",
    "epochs = [5, 4, 3, 2]\n",
    "batch_size = 50\n",
    "seed_lst = list(range(5))\n",
    "df_test = data_test.reset_index(drop=True).copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56c7739-fee1-40a9-b316-fe8fa34da05d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for i in range(1, 4):\n",
    "    num = num_samples[i]\n",
    "    sampled_data = sampled_datasets[i]\n",
    "    epoch = epochs[i]\n",
    "    res[str(num)] = {}\n",
    "    df_comb = pd.concat([sampled_data, df_test])\n",
    "    df_comb.reset_index(inplace=True, drop = True)\n",
    "    metadata = data_profiling(df_comb, cate_var, bool_var, cols)\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "    cate_desc = cate_info(df_comb, cate_var)\n",
    "    df_test_dummy = pd.get_dummies(df_test[x_cols])\n",
    "    df_test_dummy = col_check(df_test_dummy, dummy_cols)\n",
    "    X_test = df_test_dummy.to_numpy()\n",
    "    y_test = df_test[y_col].to_numpy()\n",
    "    log_file = f'log/adult_{num}.txt'\n",
    "    magan = MultiAgentGAN(gen_client, opt_client, gen_model_nm, opt_model_nm, params, sampled_data, cols, y_col, num_var, metadata, cate_desc, data_desc, log_file, opt_temperature=1, real_samples_num=5)\n",
    "    magan._run(batch_size, epoch)\n",
    "    magan.run_with_fixed_discriminator(0, epoch, batch_size, 1)\n",
    "    model_dict[str(num)] = magan\n",
    "    \n",
    "    for j in range(len(seed_lst)):\n",
    "        seed = seed_lst[j]\n",
    "        res[str(num)][str(seed)] = {}\n",
    "        mallm_syn_df = magan.gen_without_optimization(num_folds=1)\n",
    "        mallm_syn_df.to_csv(f'gen/adult/{num}/df_{seed}.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5557c18-9fc5-4885-bdc9-f88418546c3e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
