{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c74b828-be54-4e52-88a1-200ad017036e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "from datasets import load_dataset, Dataset, DatasetDict\n",
    "from huggingface_hub.hf_api import HfFolder\n",
    "HfFolder.save_token('')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c2629cd-72ec-4d9f-ac33-def202ae6a57",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset('strongpear/BaKaTeQe_ver2')\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78b43be2-75c8-4ac6-9b8e-de027045c186",
   "metadata": {},
   "outputs": [],
   "source": [
    "fewshot_dataset = load_dataset('TeeA/BaKaTeQe', name='few-shot')\n",
    "fewshot_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0994d2a4-2a57-498d-9887-c2c80e27bb36",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = dataset['train'].to_pandas()\n",
    "fewshot_df = fewshot_dataset['train'].to_pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7f4b703-5bc4-4b13-a8be-6857d15fa8ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['tmp_key'] = df['question_syll'] + '$$$$$' + df['schema_syll']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87cc4602-c0b4-4611-85dc-58cb822431ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "fewshot_df['fewshot_questions_syll'], fewshot_df['fewshot_queries_syll'] = None, None\n",
    "for i, row in fewshot_df.iterrows():\n",
    "    fewshot_questions_syll, fewshot_queries_syll = [], []\n",
    "    for idx in row['few_shot_idx_syll']:\n",
    "        fewshot_question = fewshot_df.iloc[idx]['question_syll']\n",
    "        fewshot_questions_syll.append(fewshot_question)\n",
    "\n",
    "        fewshot_queries = fewshot_df.iloc[idx]['query_syll']\n",
    "        fewshot_queries_syll.append(fewshot_queries)\n",
    "\n",
    "    fewshot_df.at[i, 'fewshot_questions_syll'] = fewshot_questions_syll\n",
    "    fewshot_df.at[i, 'fewshot_queries_syll'] = fewshot_queries_syll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75bc1ce0-1a78-4eca-ac26-1b5bb4dc336d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fewshot_df[['question_syll', 'fewshot_questions_syll', 'fewshot_queries_syll']].head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bee00dd8-462a-4c0c-92fe-466cf1f1fcbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 100\n",
    "\n",
    "print(fewshot_df.iloc[i]['question_syll'])\n",
    "print()\n",
    "print(fewshot_df.iloc[i]['fewshot_questions_syll'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c23586c-b0b6-49cb-8a77-e8da61ad16f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 100\n",
    "\n",
    "print(fewshot_df.iloc[i]['query_syll'])\n",
    "print()\n",
    "print(fewshot_df.iloc[i]['fewshot_queries_syll'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1d0554b-7fbc-4336-b045-033a3c4a6930",
   "metadata": {},
   "outputs": [],
   "source": [
    "fewshot_df['tmp_key'] = fewshot_df['question_syll'] + '$$$$$' + fewshot_df['schema_syll']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f718d63-a8e8-43e3-a5cc-0f8aad511ad4",
   "metadata": {},
   "outputs": [],
   "source": [
    "fewshot_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eb35d4c-19fb-450b-92a5-c2bb597b8ce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7039e163-40fc-400d-8ae0-86f1ba232697",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.merge(fewshot_df[['tmp_key', 'fewshot_questions_syll', 'few_shot_idx_syll']], how='left', on='tmp_key')\n",
    "df = df.drop_duplicates(subset=['tmp_key'])\n",
    "df = df.reset_index(drop=True)\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51c9ea6f-6884-413e-82aa-68b6e5d35e26",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[df['tmp_key'].duplicated()].index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfd09954-a4e6-4ade-9da1-c7120f3d978c",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "641fab66-95e7-4a58-aa78-5b75ccb93d7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['fewshot_idx'] = None\n",
    "for i, row in tqdm(df.iterrows()):\n",
    "    fewshot_idx = []\n",
    "    for question in row['fewshot_questions_syll']:\n",
    "        try:\n",
    "            idx = df[df['question_syll'] == question].index[0]\n",
    "            fewshot_idx.append(idx)\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "    df.at[i, 'fewshot_idx'] = fewshot_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbf3236d-f354-4f9f-ad86-107835be58f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['fewshot_count'] = df['fewshot_idx'].apply(lambda x: len(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51458955-3b7c-43ab-9f9c-b754d44d241e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76daf9f5-c775-4784-a42e-f5e924203587",
   "metadata": {},
   "outputs": [],
   "source": [
    "backup = df[['schema_syll', 'question_syll', 'query_syll', 'prompt_syll', 'prompt_question_NLP', 'mini_schema_syll', 'question_NLP_syll', 'tmp_key', 'fewshot_idx', 'fewshot_count']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7673aa59-7922-4f98-9a82-83f2f951b409",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34bc8afe-0fb5-477f-bae3-cc05c7fdba89",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = Dataset.from_pandas(backup)\n",
    "ds_dict = DatasetDict({\n",
    "    'train': ds\n",
    "})\n",
    "\n",
    "ds_dict.push_to_hub('strongpear/BaKaTeQe_ver2', 'fewshot')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f54bf187-e5a9-4db9-a382-ae2955a91d7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "backup.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "074f425f-7823-47bc-a352-1422e4bddf90",
   "metadata": {},
   "outputs": [],
   "source": [
    "backup['fewshot_count'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c360af52-258c-411b-bd60-67403a8891a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(sample):\n",
    "    if sample['fewshot_count'] >= 3:\n",
    "        shots = []\n",
    "        for id in sample['fewshot_idx'][:3]:\n",
    "            similar_sample = df.iloc[id]\n",
    "            shots.append(f\"\"\"###schema: {similar_sample['schema_syll']}, ###câu hỏi: {similar_sample['question_syll']}, ###câu sql: {similar_sample['query_syll']}\"\"\")\n",
    "        shots.append(f\"\"\"[INST] Sinh ra câu sql từ câu hỏi tương ứng với schema được cung cấp [/INST] ###schema: {sample['schema_syll']}, ###câu hỏi: {sample['question_syll']}, ###câu sql: {sample['query_syll']}\"\"\")\n",
    "        text = \"<s> \" + \"\\n\".join(shots) + \" </s>\"\n",
    "        return text\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1350d668-d585-4792-9ce4-b5cd324e0afc",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['text'] = df.apply(preprocess, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31a282f6-9452-431b-ac3b-03443a89d123",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74a62f82-ea1d-45e5-9134-43e679c450dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "final = df.dropna(subset=['text'])\n",
    "final = final.reset_index(drop=True)\n",
    "final.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c76bd89-ca05-44ce-8fed-af0d357ec269",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 100\n",
    "\n",
    "print(df.iloc[i]['question_syll'])\n",
    "print()\n",
    "print(df.iloc[i]['query_syll'])\n",
    "print()\n",
    "print(df.iloc[i]['text'])\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "057d8779-4f8a-472d-bf34-aca205d8e1d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = Dataset.from_pandas(final)\n",
    "ds_dict = DatasetDict({\n",
    "    'train': ds\n",
    "})\n",
    "\n",
    "ds_dict.push_to_hub('')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f25be5d2-4904-4e26-82c0-610f20049fe7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
