{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Environment"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "28ff31b2e6f74cd9"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20e65ac6-6daa-4491-ad6f-89f642d722e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Number GPUs by PCI bus id and expose only device 0 to this kernel.\n",
    "# CUDA reads these when it initialises, so run this cell before anything touches the GPU.\n",
    "os.environ.update({\n",
    "    \"CUDA_DEVICE_ORDER\": \"PCI_BUS_ID\",\n",
    "    \"CUDA_VISIBLE_DEVICES\": \"0\",\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "643152b1-c706-4fc8-81c1-5514aa7d929f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from datasets import load_dataset\n",
    "import torch\n",
    "import transformers\n",
    "from peft import PeftModel\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from huggingface_hub import snapshot_download\n",
    "from huggingface_hub.hf_api import HfFolder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d673dc9-ee53-467f-a816-42f756984329",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Credentials: read from the environment when set so real secrets are never committed with\n",
    "# the notebook; the placeholder remains as a fallback to fill in locally.\n",
    "HF_TOKEN = {\n",
    "    \"<hf_token_name>\": os.environ.get(\"HF_TOKEN\", \"<token>\")\n",
    "}\n",
    "WANDB_KEY = {\n",
    "    \"<wdb_key_name>\": os.environ.get(\"WANDB_API_KEY\", \"<key>\")\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40ab2460-b64a-41a3-affe-32d764ee5525",
   "metadata": {},
   "outputs": [],
   "source": [
    "LEVEL = \"syllable\"  # @param ['syllable', 'word']; selects the {LEVEL}-level dataset files\n",
    "LORA_RANK = 128  # @param [0, 4, 64, 128], 0 meaning \"full\" (not used when loading below)\n",
    "LORA_DIR = \"\"  # LoRA adapter: local directory or Hugging Face repo id\n",
    "RETRAIN_QLORA = True  # True: load the LoRA adapter from LORA_DIR on top of the base model for inference\n",
    "BASE_MODEL_NAME = \"codellama/CodeLlama-7b-Instruct-hf\"  # Hugging Face path to the base model\n",
    "MAX_MEMORY = 40960  # per-GPU memory cap in MB, e.g. {\"16GB\": 16384, \"40GB\": 40960}\n",
    "\n",
    "# Adapter name (last path component), used to name the predictions folder.\n",
    "# rstrip('/') keeps a trailing slash in LORA_DIR from producing an empty name.\n",
    "LORA_VERSION = LORA_DIR.rstrip('/').split('/')[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "178b173e-fe25-4f04-a9b5-e61b05ce22b7",
   "metadata": {},
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1bac453-f5ee-4554-a5cc-8db2bbfdbf61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The dataset needs access approval; request it at\n",
    "# https://huggingface.co/datasets/TeeA/VinAIResearch-ViText2SQL\n",
    "# and then save your token locally so the downloads below can authenticate.\n",
    "hf_token = HF_TOKEN[\"<hf_token_name>\"]  # change to your hf token\n",
    "HfFolder.save_token(hf_token)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "405ed659-f596-4a41-847e-1901b966a8fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test split for the selected LEVEL, loaded directly from the parquet file on the Hub.\n",
    "test_parquet_url = f\"https://huggingface.co/datasets/TeeA/VinAIResearch-ViText2SQL/resolve/main/{LEVEL}-level/test-00000-of-00001.parquet\"\n",
    "dataset = load_dataset(\"parquet\", data_files={\"test\": test_parquet_url})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b85430-05cd-4633-87e7-fe8c3c7eea6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-database table metadata for the same LEVEL (exposed as a single 'train' split).\n",
    "tables_parquet_url = f\"https://huggingface.co/datasets/TeeA/VinAIResearch-ViText2SQL/resolve/main/{LEVEL}-level--tables/train-00000-of-00001.parquet\"\n",
    "tables_dataset = load_dataset(\"parquet\", data_files={\"train\": tables_parquet_url})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65f0b608-b2ac-49b5-99c3-54e212b3d940",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_create_table_sql(database_info):\n",
    "    \"\"\"Add database_info['schema']: one CREATE TABLE statement per table, joined by '; '.\n",
    "\n",
    "    Reads the parallel lists 'column_indices', 'column_names' and 'column_types'; a column is\n",
    "    emitted for the table whose position in 'table_names' equals its 'column_indices' value.\n",
    "    Mutates and returns database_info, so it can be passed directly to Dataset.map.\n",
    "    \"\"\"\n",
    "    columns = list(zip(database_info['column_indices'],\n",
    "                       database_info['column_names'],\n",
    "                       database_info['column_types']))\n",
    "\n",
    "    statements = []\n",
    "    for position, table_name in enumerate(database_info['table_names']):\n",
    "        column_defs = \", \".join(f'\"{name}\" {col_type}'\n",
    "                                for owner, name, col_type in columns if owner == position)\n",
    "        statements.append(f'CREATE TABLE \"{table_name}\" ({column_defs})')\n",
    "\n",
    "    database_info['schema'] = \"; \".join(statements) + \";\"\n",
    "    return database_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f07cd74d-2788-4757-b5ce-130c30c50fd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the 'schema' column (CREATE TABLE statements) from the table metadata unless the\n",
    "# parquet already ships one; later cells read tables_dataset['train']['schema'].\n",
    "if 'schema' not in tables_dataset['train'].column_names:\n",
    "    tables_dataset = tables_dataset.map(generate_create_table_sql)\n",
    "\n",
    "# Sanity check: rendered schema for one database.\n",
    "tables_dataset['train'].filter(lambda x: x['db_id'] == 'perpetrator')[0]['schema']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "169eb84b-54f7-4bbd-886a-d396bf0af4f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup from db_id to its full schema text, used when building prompts.\n",
    "train_tables = tables_dataset['train']\n",
    "global_schema = dict(zip(train_tables['db_id'], train_tables['schema']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8750d265-ce26-4dd2-a028-f0c8d79da388",
   "metadata": {},
   "source": [
    "# Data Processing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f33c653e-7a43-4cbb-94e4-0b054f8d1370",
   "metadata": {},
   "source": [
    "**Retrieve Mini Schema**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78287630-eb55-4cbc-8b93-53c0a4276405",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentence-embedding model that ranks a database's tables against the question.\n",
    "RETRIEVE_MODEL = \"\"  # huggingface/local path to retrieve model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06a2a1f0-9f04-4631-bbf0-d4d165b2234d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# HF_TOKEN only defines the \"<hf_token_name>\" key (the one used to save the token earlier);\n",
    "# looking up any other key raises KeyError.\n",
    "retrive_model = SentenceTransformer(RETRIEVE_MODEL, token=HF_TOKEN[\"<hf_token_name>\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b54dd5be-aae9-425c-9f3c-5b0360e5bc90",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_schema(sample):\n",
    "    \"\"\"Split a schema string into a list of 'CREATE TABLE ...' statements.\n",
    "\n",
    "    Removes a Markdown ```sql ... ``` fence (as in ChatGPT-generated schemas) and all newlines,\n",
    "    then splits on the 'CREATE TABLE' keyword; whitespace-only fragments are dropped.\n",
    "    \"\"\"\n",
    "    # clean schema chatgpt gen: opening fence, closing fence, then newlines\n",
    "    sample = sample.replace('```sql\\n', '').replace('```', '').replace('\\n', '')\n",
    "\n",
    "    # re-format list of table\n",
    "    tables = sample.split('CREATE TABLE')\n",
    "    return ['CREATE TABLE ' + table.strip() for table in tables if table.strip()]\n",
    "\n",
    "\n",
    "def extract_mini_schema(question, schema, top_k=2):\n",
    "    \"\"\"Return the top_k tables of `schema` most similar to `question`, joined by spaces.\n",
    "\n",
    "    Scores come from the sentence-embedding model `retrive_model`; torch.topk orders the\n",
    "    selected tables by descending score. top_k is capped at the number of tables so that\n",
    "    databases with fewer tables than top_k do not make torch.topk raise.\n",
    "    \"\"\"\n",
    "    tables = format_schema(schema)\n",
    "    if not tables:\n",
    "        return ''\n",
    "\n",
    "    results = retrive_model.similarity(retrive_model.encode(question), retrive_model.encode(tables))[0]\n",
    "    idx = torch.topk(results, min(top_k, len(tables))).indices.tolist()\n",
    "    return ' '.join(tables[i] for i in idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e29a261a-5d36-42dd-bc47-8438caf6efdb",
   "metadata": {},
   "source": [
    "**Processing Data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f9f3f6e-9dd4-4429-9ff3-cbfe91162b47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing data without retrieve a.k.a full schema\n",
    "def preprocess(sample):\n",
    "    \"\"\"Store the full-schema prompt for this sample in sample['text'] and return the sample.\"\"\"\n",
    "    schema = global_schema[sample['db_id']]\n",
    "    question = sample['question']\n",
    "    sample['text'] = f\"\"\"[INST] Sinh ra câu sql từ câu hỏi tương ứng với schema được cung cấp [/INST] ###schema: {schema}, ###câu hỏi: {question}, ###câu sql:\"\"\"\n",
    "    return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a5ccab3-6793-414f-8536-e4bc45256f66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the full-schema prompt column 'text' to the test split.\n",
    "dataset = dataset.map(preprocess)\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42bc553d-dd3f-4398-a606-d2c4fb1aa89b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing data with retrieve a.k.a mini schema\n",
    "def preprocess_2(sample):\n",
    "    \"\"\"Store a prompt built from the retrieved mini schema in sample['text_2'] and return the sample.\"\"\"\n",
    "    full_schema = global_schema[sample['db_id']]\n",
    "    question = sample['question']\n",
    "    mini_schema = extract_mini_schema(question, full_schema)\n",
    "    prompt = f\"\"\"[INST] Sinh ra câu sql từ câu hỏi tương ứng với schema được cung cấp [/INST] ###schema: {mini_schema}, ###câu hỏi: {question}, ###câu sql:\"\"\"\n",
    "    sample['text_2'] = prompt.strip()\n",
    "    return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49fe4314-6c31-4b36-8402-e99fef960ba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the retrieved-mini-schema prompt column 'text_2' (embeds every question with retrive_model).\n",
    "dataset = dataset.map(preprocess_2)\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Preprocessing data FewShot without retrieve a.k.a FewShot full schema\n",
    "# TODO: not implemented yet; this cell currently does nothing."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "cb0fa6cebddbf157"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Preprocessing data FewShot with retrieve a.k.a FewShot mini schema\n",
    "# TODO: not implemented yet; this cell currently does nothing."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f82c21c28058f334"
  },
  {
   "cell_type": "markdown",
   "id": "a36b1a19-72d3-4ac2-9fb7-c64429ea2ce8",
   "metadata": {},
   "source": [
    "# LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68de67e0-f8ee-45cf-90df-172fa4f42ce1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lower-case working copies of the configuration constants (level and lora_rank are currently unused).\n",
    "level, lora_rank, lora_dir, retrain_qlora, model_name = (\n",
    "    LEVEL, LORA_RANK, LORA_DIR, RETRAIN_QLORA, BASE_MODEL_NAME\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a075448f-8cef-4b3d-b2b3-8120d6e2ac80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quantize to 4-bit NF4 so the large model loads with less GPU memory\n",
    "# (requires the bitsandbytes library).\n",
    "bnb_config = transformers.BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_quant_type='nf4',\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16\n",
    ")\n",
    "\n",
    "n_gpus = torch.cuda.device_count()\n",
    "if n_gpus == 0:\n",
    "    raise RuntimeError(\"No CUDA device is visible; 4-bit loading and the inference below require a GPU.\")\n",
    "max_memory = f'{MAX_MEMORY}MB'\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    quantization_config=bnb_config,\n",
    "    device_map=\"auto\",  # dispatch the model efficiently across the available resources\n",
    "    max_memory={i: max_memory for i in range(n_gpus)},  # same cap for every visible GPU\n",
    ")\n",
    "# `token=True` replaces the deprecated `use_auth_token=True` and uses the locally saved HF token.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e860c84f-79de-4cab-802b-03f2c1496ed2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left padding for batched generation with a decoder-only model, so each prompt ends\n",
    "# right where the generated tokens begin.\n",
    "tokenizer.padding_side = 'left'\n",
    "# Reuse the unk token for padding (presumably the tokenizer has no pad token of its own - verify).\n",
    "tokenizer.pad_token = tokenizer.unk_token\n",
    "model.config.pad_token_id = tokenizer.pad_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a42804d-ed15-446f-82d5-a001333c5bb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the tokenizer settings that generation relies on.\n",
    "for attribute in ('padding_side', 'add_eos_token', 'pad_token', 'unk_token', 'eos_token_id'):\n",
    "    print(getattr(tokenizer, attribute))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c32d5d06-cb28-4d01-a902-76bce8a65feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally wrap the base model with the LoRA adapter. LORA_DIR may be a local directory\n",
    "# or a Hub repo id; snapshot_download only accepts the latter.\n",
    "vi_model_merge = model\n",
    "if retrain_qlora:\n",
    "    if os.path.isdir(lora_dir):\n",
    "        peft_model_id_visquad_lora = lora_dir\n",
    "    else:\n",
    "        peft_model_id_visquad_lora = snapshot_download(lora_dir)\n",
    "    # Inference only: load the adapter frozen (is_trainable=False), so newly injected LoRA\n",
    "    # layers are not left in training mode (e.g. active dropout) during generation.\n",
    "    vi_model_merge = PeftModel.from_pretrained(vi_model_merge, peft_model_id_visquad_lora, is_trainable=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4bcf845-776c-4088-840b-130e40112f4f",
   "metadata": {},
   "source": [
    "# Infer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "\n",
    "def extract_sql(text):\n",
    "    \"\"\"Return the SQL generated after the '###câu sql:' prompt marker.\n",
    "\n",
    "    The answer ends at the first '</s>' or, if generation stopped without emitting one\n",
    "    (e.g. max_new_tokens reached), at the end of the text. If the marker is missing,\n",
    "    `text` is returned unchanged.\n",
    "    \"\"\"\n",
    "    pattern = r\"###câu sql:(.*?)(?:</s>|$)\"\n",
    "    match = re.search(pattern, text, re.DOTALL)\n",
    "    if match:\n",
    "        return match.group(1).strip()\n",
    "    return text"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9d122b132f284de2"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e418d22d-8ef1-464e-9ffa-38d96e997406",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(sample, text_column='text', max_new_tokens=1000):\n",
    "    \"\"\"Greedily generate SQL for one prompt and return the extracted query.\n",
    "\n",
    "    The prompt is read from sample[text_column] ('text' = full schema, 'text_2' = mini schema).\n",
    "    The raw decoded generation (prompt included) is stored in sample['predict'], and the\n",
    "    wall-clock time of the call is printed.\n",
    "    \"\"\"\n",
    "    start_time = time.time()\n",
    "\n",
    "    encodeds = tokenizer(sample[text_column], return_tensors=\"pt\", padding=True).to('cuda')\n",
    "\n",
    "    # Greedy decoding: with do_sample=False, temperature/top_k are ignored, so they are not passed.\n",
    "    generated_ids = vi_model_merge.generate(\n",
    "        inputs=encodeds[\"input_ids\"],\n",
    "        attention_mask=encodeds[\"attention_mask\"],\n",
    "        do_sample=False,\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        eos_token_id=tokenizer.eos_token_id,\n",
    "        pad_token_id=tokenizer.pad_token_id\n",
    "    )\n",
    "\n",
    "    # batch_decode returns a list; for a single prompt string it holds exactly one entry.\n",
    "    decoded = tokenizer.batch_decode(generated_ids)[0]\n",
    "    sample['predict'] = decoded\n",
    "\n",
    "    print(time.time() - start_time)\n",
    "    # extract_sql needs the string itself (re.search on the list raises TypeError).\n",
    "    return extract_sql(decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2677624e-1cf8-4861-bd58-84b412e94084",
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_inference(batch, text_column='text', max_new_tokens=1000):\n",
    "    \"\"\"Greedily generate SQL for a batch of prompts and return one extracted query per prompt.\n",
    "\n",
    "    `batch` is a column-oriented slice such as dataset['test'][i:j]; prompts are read from\n",
    "    batch[text_column] ('text' = full schema, 'text_2' = retrieved mini schema). The result\n",
    "    is index-aligned with the prompts.\n",
    "    \"\"\"\n",
    "    # Relies on tokenizer.padding_side == 'left' so every prompt ends where generation starts.\n",
    "    encodeds = tokenizer(batch[text_column], return_tensors=\"pt\", padding=True).to('cuda')\n",
    "\n",
    "    # Greedy decoding: with do_sample=False, temperature/top_k are ignored, so they are not passed.\n",
    "    generated_ids = vi_model_merge.generate(\n",
    "        inputs=encodeds[\"input_ids\"],\n",
    "        attention_mask=encodeds[\"attention_mask\"],\n",
    "        do_sample=False,\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        eos_token_id=tokenizer.eos_token_id,\n",
    "        pad_token_id=tokenizer.pad_token_id\n",
    "    )\n",
    "\n",
    "    # No try/except here: silently skipping one item would shift every later prediction\n",
    "    # onto the wrong sample when the caller pairs results with rows by position.\n",
    "    return [extract_sql(text) for text in tokenizer.batch_decode(generated_ids)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ef3da5f-3fd0-425a-953a-c1ba9f83fd66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions are written as one JSON file per test sample into a folder named after the adapter.\n",
    "predictions_root = \"../benchmark/spider/benchmark/predictions\"\n",
    "saved_folder = f\"{predictions_root}/benchmark__{LORA_VERSION}\"\n",
    "saved_folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9419bff9-78db-44a7-980c-9747f5813583",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_latest_filename_index(folder=None):\n",
    "    \"\"\"Return (file_name, index) of the highest-numbered '<index>.json' file in `folder`.\n",
    "\n",
    "    `folder` defaults to saved_folder. Files whose stem is not an integer are ignored instead\n",
    "    of aborting the search. Returns (0, 0) when the folder does not exist or holds no such\n",
    "    file, so the index can be used directly as the start position of a fresh run.\n",
    "    \"\"\"\n",
    "    if folder is None:\n",
    "        folder = saved_folder\n",
    "    try:\n",
    "        file_names = os.listdir(folder)\n",
    "    except FileNotFoundError:\n",
    "        return 0, 0\n",
    "\n",
    "    suffix = '.json'\n",
    "    indexed = [\n",
    "        (int(name[:-len(suffix)]), name)\n",
    "        for name in file_names\n",
    "        if name.endswith(suffix) and name[:-len(suffix)].isdecimal()\n",
    "    ]\n",
    "    if not indexed:\n",
    "        return 0, 0\n",
    "\n",
    "    latest_index, latest_file_name = max(indexed)\n",
    "    return latest_file_name, latest_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9824a45-a468-44e8-a76c-5ee6a2c3cf9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "os.makedirs(saved_folder, exist_ok=True)\n",
    "\n",
    "start_index = 0  # 0 for a fresh run; find_latest_filename_index()[1] to continue an interrupted one\n",
    "batch_size = 8\n",
    "test_split = dataset['test']\n",
    "n_test = len(test_split)\n",
    "\n",
    "for batch_start in tqdm(range(start_index, n_test, batch_size)):\n",
    "    # Clamp the last batch instead of relying on IndexError to stop the loop.\n",
    "    batch_end = min(batch_start + batch_size, n_test)\n",
    "    inferred = batch_inference(test_split[batch_start:batch_end])\n",
    "    if len(inferred) != batch_end - batch_start:\n",
    "        raise RuntimeError(f\"Expected {batch_end - batch_start} predictions, got {len(inferred)}\")\n",
    "\n",
    "    for offset, i in enumerate(range(batch_start, batch_end)):\n",
    "        sample = test_split[i]\n",
    "        sample['predict'] = inferred[offset]\n",
    "        # One JSON file per test sample, named by its row index\n",
    "        with open(f\"{saved_folder}/{i}.json\", 'w', encoding='utf-8') as f:\n",
    "            json.dump(sample, f, ensure_ascii=False)\n",
    "\n",
    "print(\"Hết rồi\")  # \"all done\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
