{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4dd64abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import copy\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import prompt_utils as utils\n",
    "from optimization_utils import fill_missing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ade56cfc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def _make_feature_bags(columns, bag_size=10, min_columns=20):\n",
    "    \"\"\"Return the column subsets (feature bags) handed to the diversity prompt.\n",
    "\n",
    "    Tables with at least ``min_columns`` columns get ``len(columns) // bag_size``\n",
    "    bags of ``bag_size`` columns each; narrower tables get one bag holding every column.\n",
    "    NOTE(review): the column list is reshuffled (global ``random`` state) before every\n",
    "    bag, so bags are independent random subsets that may overlap, not a partition.\n",
    "    The ``i * bag_size`` slice offset is kept so earlier runs are reproduced exactly.\n",
    "    \"\"\"\n",
    "    columns = list(columns)\n",
    "    if len(columns) < min_columns:\n",
    "        return [columns]\n",
    "    bags = []\n",
    "    for i in range(len(columns) // bag_size):\n",
    "        shuffled = columns.copy()\n",
    "        random.shuffle(shuffled)\n",
    "        bags.append(shuffled[i * bag_size:(i + 1) * bag_size])\n",
    "    return bags\n",
    "\n",
    "\n",
    "def _extract_function_source(response):\n",
    "    \"\"\"Return the code between ``<start>`` and ``<end>`` in an LLM reply, from its first ``def`` on.\n",
    "\n",
    "    Raises ValueError when the ``<start>`` tag or the ``def`` keyword is missing.\n",
    "    \"\"\"\n",
    "    if '<start>' not in response:\n",
    "        raise ValueError('no <start> tag in response')\n",
    "    body = response.split('<start>')[1].split('<end>')[0].strip()\n",
    "    def_pos = body.find('def')\n",
    "    if def_pos < 0:\n",
    "        raise ValueError('no function definition in response')\n",
    "    return body[def_pos:]\n",
    "\n",
    "\n",
    "def _guard_statements(fct_str):\n",
    "    \"\"\"Wrap every body line of a generated function in its own try/except.\n",
    "\n",
    "    The first line (the ``def`` header) and the last line (expected to be the\n",
    "    ``return``) are kept verbatim. Each line in between with more than two\n",
    "    non-blank characters becomes ``try: <line>`` / ``except Exception: pass``;\n",
    "    shorter lines are dropped. A feature whose code fails is therefore skipped\n",
    "    instead of discarding the whole function. Nested or multi-line statements are\n",
    "    flattened and usually no longer compile, which counts as a failed attempt.\n",
    "    \"\"\"\n",
    "    lines = fct_str.split('\\n')\n",
    "    guarded = '\\n'.join(\n",
    "        '    try: ' + line.strip() + '\\n    except Exception: pass'\n",
    "        for line in lines[1:-1]\n",
    "        if len(line.strip()) > 2\n",
    "    )\n",
    "    return '\\n'.join([lines[0], guarded, lines[-1]])\n",
    "\n",
    "\n",
    "def _parse_rule_descriptions(result):\n",
    "    \"\"\"Map column name -> description from the ``| name | description |`` rows of an LLM reply.\n",
    "\n",
    "    Lines with fewer than two ``|`` characters are not table rows and are skipped;\n",
    "    when a name repeats, its first description wins.\n",
    "    \"\"\"\n",
    "    descriptions = {}\n",
    "    for row in result.split('\\n'):\n",
    "        fields = row.split('|')\n",
    "        if len(fields) < 3:\n",
    "            continue\n",
    "        descriptions.setdefault(fields[1].strip(), fields[2].strip())\n",
    "    return descriptions\n",
    "\n",
    "\n",
    "def run(_DATA, _SEED, _RULE_DIR, _NUM_COL, _NUM_QUERY, _API_KEY, _MAX_TRIES=10):\n",
    "    \"\"\"Collect ``_NUM_QUERY`` LLM-generated feature functions for one dataset and seed.\n",
    "\n",
    "    Each round asks the LLM for ``_NUM_COL`` new columns (when the previous accepted\n",
    "    round created columns, the diversity template is used and is given those columns),\n",
    "    then asks up to ``_MAX_TRIES`` times for a ``column_appender`` function implementing\n",
    "    them, keeping the first one that runs on the imputed training set. Accepted rule\n",
    "    texts and functions are written, joined by ``---DIVIDER---``, to\n",
    "    ``./LLM_results/{_RULE_DIR}/columns-{_DATA}-{_SEED}.out`` and\n",
    "    ``./LLM_results/{_RULE_DIR}/function-{_DATA}-{_SEED}.out``; entry i of both files\n",
    "    comes from the same round. Returns immediately, without loading the data, when the\n",
    "    function file already exists.\n",
    "\n",
    "    NOTE(review): a round whose function never runs is re-queried without being counted,\n",
    "    so this keeps spending API calls until ``_NUM_QUERY`` rounds succeed.\n",
    "    \"\"\"\n",
    "    saved_file_name = f'./LLM_results/{_RULE_DIR}/function-{_DATA}-{_SEED}.out'\n",
    "    if os.path.isfile(saved_file_name):\n",
    "        return  # already generated for this dataset/seed\n",
    "\n",
    "    utils.set_seed(_SEED)\n",
    "    _, X_train, X_test, _, _, _, label_list, is_cat = utils.get_dataset(_DATA, _SEED)\n",
    "    X_train_org, _ = fill_missing(X_train.copy(), X_test.copy())\n",
    "    total_column_list = _make_feature_bags(X_train.columns)\n",
    "\n",
    "    meta_data_name = f\"./data/{_DATA}-metadata.json\"\n",
    "    function_file_name = './template/ask_for_function.txt'\n",
    "    rule_file_name = f'./LLM_results/{_RULE_DIR}/columns-{_DATA}-{_SEED}.out'\n",
    "\n",
    "    results = []\n",
    "    fct_strs = []\n",
    "    prev_modules_list = []\n",
    "    current_query_num = 0\n",
    "    while current_query_num < _NUM_QUERY:\n",
    "        print(f\"Extracting columns {_DATA}/{_SEED} - {current_query_num}/{_NUM_QUERY}\")\n",
    "        # Ask the LLM for new columns; after a productive round the diversity\n",
    "        # template also receives that round's columns and the feature bags.\n",
    "        if not prev_modules_list:\n",
    "            ask_file_name = './template/ask_columns.txt'\n",
    "            templates, feature_desc = utils.get_prompt_for_asking(\n",
    "                _DATA, X_train, label_list, ask_file_name, meta_data_name, is_cat, num_col=_NUM_COL, num_query=1\n",
    "            )\n",
    "            template = templates[0]\n",
    "        else:\n",
    "            ask_file_name = './template/ask_columns_diversity.txt'\n",
    "            template, feature_desc = utils.get_prompt_for_asking_with_diversity(\n",
    "                _DATA, X_train, prev_modules_list, label_list, ask_file_name, meta_data_name,\n",
    "                is_cat, total_column_list, current_query_num, num_col=_NUM_COL\n",
    "            )\n",
    "        result = utils.query_gpt([template], api_key=_API_KEY, max_tokens=1500, temperature=0.5, verbose=False)[0]\n",
    "\n",
    "        # Turn the answer into code; keep the first version that runs on the training data.\n",
    "        fct_str_handled, new_columns = None, []\n",
    "        for _ in range(_MAX_TRIES):\n",
    "            fct_template = utils.get_prompt_for_generating_function([result], feature_desc, function_file_name)\n",
    "            fct_result = utils.query_gpt(fct_template, api_key=_API_KEY, max_tokens=1500, temperature=0.5, verbose=False)\n",
    "            try:\n",
    "                candidate = _guard_statements(_extract_function_source(fct_result[0]))\n",
    "                # SECURITY: this runs LLM-generated code with full interpreter privileges.\n",
    "                # A private copy of the notebook globals gives it np/pd; looking the function up\n",
    "                # via locals() after a bare exec() stops working on Python 3.13+ (PEP 667).\n",
    "                namespace = dict(globals())\n",
    "                exec(candidate, namespace)\n",
    "                # Pass a copy so an in-place column_appender cannot hide its new columns\n",
    "                # or leak them into the baseline used by later rounds.\n",
    "                X_train_aug = namespace['column_appender'](X_train_org.copy())\n",
    "                # Keep the appender's column order (a set difference gave a per-process order).\n",
    "                new_columns = list(dict.fromkeys(c for c in X_train_aug.columns if c not in X_train_org.columns))\n",
    "            except Exception:\n",
    "                continue\n",
    "            fct_str_handled = candidate\n",
    "            break\n",
    "        if fct_str_handled is None:\n",
    "            continue  # every attempt failed: re-query this round without counting it\n",
    "\n",
    "        descriptions = _parse_rule_descriptions(result)\n",
    "        prev_modules_list = [[col, descriptions.get(col, col)] for col in new_columns]\n",
    "        results.append(result)\n",
    "        fct_strs.append(fct_str_handled)\n",
    "        current_query_num += 1\n",
    "\n",
    "    os.makedirs(f'./LLM_results/{_RULE_DIR}', exist_ok=True)\n",
    "    with open(rule_file_name, 'w') as f:\n",
    "        f.write(\"\\n\\n---DIVIDER---\\n\\n\".join(results))\n",
    "    with open(saved_file_name, 'w') as f:\n",
    "        f.write(\"\\n\\n---DIVIDER---\\n\\n\".join(fct_strs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d0a2900",
   "metadata": {},
   "outputs": [],
   "source": [
    "_NUM_COL = 5\n",
    "_NUM_QUERY = 40\n",
    "_RULE_DIR = f'diversity_{_NUM_COL}rules_{_NUM_QUERY}trials'\n",
    "# Read the key from the environment instead of hardcoding it in the notebook\n",
    "# (variable name assumes query_gpt talks to OpenAI; adjust if not).\n",
    "_API_KEY = os.environ.get('OPENAI_API_KEY', '<Get your own API key>')\n",
    "\n",
    "# One entry per dataset; run() skips any dataset/seed whose function file already exists.\n",
    "_DATASETS = [\n",
    "    'adult', 'blood', 'tic-tac-toe', 'sequence-type', 'insurance', 'heart',\n",
    "    'car', 'communities', 'credit-g', 'diabetes', 'bank', 'myocardial',\n",
    "    'junglechess', 'housing', 'solution-mix', 'forest-fires', 'eucalyptus', 'balance-scale', 'vehicle',\n",
    "]\n",
    "_SEEDS = [0, 1, 2]\n",
    "\n",
    "for _DATA in _DATASETS:\n",
    "    for _SEED in _SEEDS:\n",
    "        run(_DATA, _SEED, _RULE_DIR, _NUM_COL, _NUM_QUERY, _API_KEY)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "NFSim",
   "language": "python",
   "name": "nfsim"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
