{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ea470859-8e66-45f0-829d-57b738189204",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/minhnh/python_venv/nlp/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "2024-04-05 18:17:03.898867: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-04-05 18:17:03.898937: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-04-05 18:17:03.899814: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-04-05 18:17:03.906050: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-04-05 18:17:04.731787: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "import pandas as pd\n",
    "\n",
    "from huggingface_hub.hf_api import HfFolder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "32a59368-1186-4e0e-b5b1-5ba7ce277689",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Never hardcode the Hugging Face token in the notebook; read it from the environment.\n",
    "# When HF_TOKEN is unset or empty, save_token is skipped rather than called with an empty string.\n",
    "hf_token = os.environ.get(\"HF_TOKEN\", \"\")\n",
    "if hf_token:\n",
    "    HfFolder.save_token(hf_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c7fef87f-cb47-4d08-8b74-16b2bf5e4b12",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/minhnh/python_venv/nlp/lib/python3.9/site-packages/datasets/load.py:2483: FutureWarning: 'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n",
      "You can remove this warning by passing 'token=<use_auth_token>' instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# `token` replaces `use_auth_token`, which is deprecated since datasets 2.14.0 and removed in 3.0.0.\n",
    "dataset = load_dataset(\"hoangphu7122002ai/text2sql_vi\", token=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "28f1566b-6e53-4b20-a145-0eb43ff2ddcc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'schema_syll': 'CREATE TABLE lab(subject_id text,hadm_id text,itemid text,charttime text,flag text,value_unit text,label text,fluid text) CREATE TABLE thủ tục(subject_id text,hadm_id text,icd9_code text,short_title text,long_title text) CREATE TABLE nhân khẩu học(subject_id text,hadm_id text,name text,marital_status text,age text,dob text,giới tính text,ngôn ngữ text,tôn giáo text,loại_nhập viện text,ngày_ở text,bảo hiểm text,dân tộc text,hết hạn_cờ text,vị trí_nhập viện text,vị trí xuất viện text,chẩn đoán text,dod text,dob_year text,dod_year text,thời gian nhập viện text,dischtime text,admityear text) CREATE TABLE đơn thuốc(subject_id text,hadm_id text,icustay_id text,drug_type text,drug text,formulary_drug_cd text,route text,drug_dose text) CREATE TABLE chẩn đoán(subject_id text,hadm_id text,icd9_code text,short_title text,long_title text)',\n",
       " 'schema_word': 'CREATE TABLE lab(subject_id text,hadm_id text,itemid text,charttime text,flag text,value_unit text,label text,fluid text) CREATE TABLE thủ_tục(subject_id text,hadm_id text,icd9_code text,short_title text,long_title text) CREATE TABLE nhân_khẩu học(subject_id text,hadm_id text,name text,marital_status text,age text,dob text,giới_tính text,ngôn_ngữ text,tôn_giáo text,loại_nhập_viện text,ngày_ở text,bảo_hiểm text,dân_tộc text,hết hạn_cờ text,vị trí_nhập_viện text,vị_trí xuất_viện text,chẩn_đoán text,dod text,dob_year text,dod_year text,thời_gian nhập_viện text,dischtime text,admityear text) CREATE TABLE đơn thuốc(subject_id text,hadm_id text,icustay_id text,drug_type text,drug text,formulary_drug_cd text,route text,drug_dose text) CREATE TABLE chẩn_đoán(subject_id text,hadm_id text,icd9_code text,short_title text,long_title text)',\n",
       " 'query_syll': 'SELECT COUNT(DISTINCTnhân khẩu học.subject_id) FROM nhân khẩu học INNER JOIN thủ tục ON nhân khẩu học.hadm_id = thủ tục.hadm_id WHERE nhân khẩu học.gender = \"F\" AND thủ tục.long_title = \"chuyển nhịp nhĩ\"',\n",
       " 'source': 'mimicsql_data',\n",
       " 'question_syll': 'cho tôi xem số lượng bệnh nhân nữ đã trải qua chuyển nhịp nhĩ.',\n",
       " 'question_word': 'cho tôi xem số_lượng bệnh_nhân nữ đã trải qua chuyển nhịp nhĩ .',\n",
       " 'query_word': 'SELECT COUNT( DISTINCT nhân_khẩu học.subject_id) FROM nhân_khẩu học INNER JOIN thủ_tục ON nhân_khẩu học.hadm_id = thủ_tục.hadm_id WHERE nhân_khẩu học.gender = \"F\" AND thủ_tục.long_title = \"chuyển nhịp nhĩ\"',\n",
       " 'label': 7}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the first row of the train split to see which fields each example carries.\n",
    "dataset['train'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "304a4dc2-cf8f-42ee-9a58-d72d33a6e5f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values of the 'label' column of the train split, one entry per row; treated below as cluster ids.\n",
    "clusters = dataset['train']['label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "4a2a80be-fcfa-4817-b5d9-2de604c7178b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct label values present in `clusters`.\n",
    "list(set(clusters))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "69d71491-4c30-4538-8a98-53f3002919fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.\n",
    "# Presumably these hold the precomputed embeddings consumed by get_dict_embedding - confirm.\n",
    "with open(\"my_array.pickle\", \"rb\") as openfile:\n",
    "    list_syll = pickle.load(openfile)\n",
    "\n",
    "with open(\"my_array_word.pickle\", \"rb\") as openfile:\n",
    "    list_word = pickle.load(openfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "90284ee0-1a2e-40a5-a280-09af262f90e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dict_embedding(list_ele,list_cluster):\n",
    "    \"\"\"Group (index, embedding) pairs by the cluster label of their index.\n",
    "\n",
    "    Args:\n",
    "        list_ele: iterable of (i, emb) pairs; i is used to index list_cluster.\n",
    "        list_cluster: sequence where list_cluster[i] is the cluster label of item i.\n",
    "\n",
    "    Returns:\n",
    "        dict holding two parallel lists per label c: key '{c}_idx' maps to the\n",
    "        indices and key '{c}_emb' to the matching embeddings, both in the order\n",
    "        the pairs appear in list_ele.\n",
    "    \"\"\"\n",
    "    dict_embedding = {}\n",
    "\n",
    "    for i, emb in list_ele:\n",
    "        label = list_cluster[i]\n",
    "        # setdefault creates the list the first time a label is seen, replacing the\n",
    "        # linear `key not in list(d.keys())` scan; '_idx' is still inserted before '_emb'.\n",
    "        dict_embedding.setdefault(f'{label}_idx', []).append(i)\n",
    "        dict_embedding.setdefault(f'{label}_emb', []).append(emb)\n",
    "    return dict_embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "faad1974-388a-49b7-bd30-908eb7c75fd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group the loaded embeddings by cluster label. The second argument must be `clusters`,\n",
    "# the per-row label list that get_dict_embedding indexes by row, not the per-label loop\n",
    "# variable `cluster` that only exists further down in the notebook.\n",
    "dict_syll = get_dict_embedding(list_syll, clusters)\n",
    "dict_word = get_dict_embedding(list_word, clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0f176729-3733-414b-83b9-2d7d4006c852",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer, util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c3781907-8ac2-4d84-b6d2-647a6af2711f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "418f3c52-b1c4-45a4-a7d3-b1fb9081c24a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "b9ef1f71-e65d-438e-9d6b-84a280b7b100",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 83920/83920 [34:50<00:00, 40.15it/s]\n",
      "100%|██████████| 21667/21667 [09:03<00:00, 39.83it/s]\n",
      "100%|██████████| 27947/27947 [11:43<00:00, 39.74it/s]\n",
      "100%|██████████| 10080/10080 [04:10<00:00, 40.16it/s]\n",
      "100%|██████████| 18791/18791 [08:01<00:00, 39.04it/s]\n",
      "100%|██████████| 9872/9872 [04:12<00:00, 39.07it/s]\n",
      "100%|██████████| 14368/14368 [06:17<00:00, 38.02it/s]\n",
      "100%|██████████| 10377/10377 [04:28<00:00, 38.70it/s]\n",
      "100%|██████████| 9495/9495 [04:14<00:00, 37.35it/s]\n",
      "100%|██████████| 22175/22175 [09:37<00:00, 38.37it/s]\n",
      "100%|██████████| 7079/7079 [03:01<00:00, 38.90it/s]\n",
      "100%|██████████| 8193/8193 [03:33<00:00, 38.46it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# For every row in `dict_syll`: sample candidate rows from the same cluster, rank them by\n",
    "# cosine similarity to the row's embedding and keep up to 5 row indices other than the row itself.\n",
    "# NOTE(review): the following cell rebinds `list_few_shot`; persist this result first if it is still needed.\n",
    "list_few_shot = {}\n",
    "\n",
    "for cluster in list(set(clusters)):\n",
    "    cluster_idx = dict_syll[f'{cluster}_idx']\n",
    "    cluster_emb = dict_syll[f'{cluster}_emb']\n",
    "    test = list(range(len(cluster_idx)))\n",
    "    # random.sample raises ValueError when asked for more items than the population holds,\n",
    "    # so cap the sample size for clusters with fewer than 200 rows.\n",
    "    n_candidates = min(200, len(test))\n",
    "    for i, emb in enumerate(tqdm(cluster_emb)):\n",
    "        idx_sample = random.sample(test, n_candidates)\n",
    "        test_matrix = [cluster_emb[k] for k in idx_sample]\n",
    "        # test_idx[j] is the row index stored for the j-th sampled candidate.\n",
    "        test_idx = [cluster_idx[k] for k in idx_sample]\n",
    "        score_rerank = util.pytorch_cos_sim(emb, test_matrix)[0]\n",
    "        # Take the 6 best so that 5 remain after the query row is dropped, if it was sampled.\n",
    "        top_k_max_indices = sorted(range(len(score_rerank)), key=lambda idx: score_rerank[idx], reverse=True)[:6]\n",
    "        list_idx = [test_idx[ele] for ele in top_k_max_indices]\n",
    "\n",
    "        idx_real = cluster_idx[i]\n",
    "        if idx_real in list_idx:\n",
    "            list_idx.remove(idx_real)\n",
    "        else:\n",
    "            list_idx = list_idx[:5]\n",
    "\n",
    "        list_few_shot[idx_real] = list_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bd30746-d3c9-4525-b172-a770e5038386",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|▎         | 2932/83920 [01:16<32:33, 41.45it/s]"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# For every row in `dict_word`: sample candidate rows from the same cluster, rank them by\n",
    "# cosine similarity to the row's embedding and keep up to 5 row indices other than the row itself.\n",
    "# NOTE(review): this rebinds `list_few_shot`, discarding whatever an earlier cell stored under that name.\n",
    "list_few_shot = {}\n",
    "\n",
    "for cluster in list(set(clusters)):\n",
    "    cluster_idx = dict_word[f'{cluster}_idx']\n",
    "    cluster_emb = dict_word[f'{cluster}_emb']\n",
    "    test = list(range(len(cluster_idx)))\n",
    "    # random.sample raises ValueError when asked for more items than the population holds,\n",
    "    # so cap the sample size for clusters with fewer than 200 rows.\n",
    "    n_candidates = min(200, len(test))\n",
    "    for i, emb in enumerate(tqdm(cluster_emb)):\n",
    "        idx_sample = random.sample(test, n_candidates)\n",
    "        test_matrix = [cluster_emb[k] for k in idx_sample]\n",
    "        # test_idx[j] is the row index stored for the j-th sampled candidate.\n",
    "        test_idx = [cluster_idx[k] for k in idx_sample]\n",
    "        score_rerank = util.pytorch_cos_sim(emb, test_matrix)[0]\n",
    "        # Take the 6 best so that 5 remain after the query row is dropped, if it was sampled.\n",
    "        top_k_max_indices = sorted(range(len(score_rerank)), key=lambda idx: score_rerank[idx], reverse=True)[:6]\n",
    "        list_idx = [test_idx[ele] for ele in top_k_max_indices]\n",
    "\n",
    "        idx_real = cluster_idx[i]\n",
    "        if idx_real in list_idx:\n",
    "            list_idx.remove(idx_real)\n",
    "        else:\n",
    "            list_idx = list_idx[:5]\n",
    "\n",
    "        list_few_shot[idx_real] = list_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "012a4a49-0110-46bf-bb49-154f574cc6f4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "243964"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of label entries, i.e. one per row of the train split.\n",
    "len(clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3516a360-ac5b-4c0e-bcd7-69dcaefba0b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One record per row position 0..len(clusters)-1, holding the few-shot indices stored for that row.\n",
    "lfs_idx = [\n",
    "    {'few_shot_idx_word': list_few_shot[row]}\n",
    "    for row in range(len(clusters))\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20d8b2bd-9bc4-44d6-bd4c-14c0c9c3d015",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "\n",
    "# Turn the list of records into a DataFrame, then into a `datasets.Dataset`.\n",
    "df = pd.DataFrame(lfs_idx)\n",
    "dataset1 = Dataset.from_pandas(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6547cece-b48f-44bc-8664-5993374ff835",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import concatenate_datasets, load_dataset\n",
    "\n",
    "# axis=1 concatenates column-wise, so the columns of `dataset1` are added next to the existing\n",
    "# train columns. The result is assigned back to dataset['train'], replacing that split.\n",
    "dataset['train'] = concatenate_datasets([dataset['train'],dataset1],axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "34343ca6-48f7-4d2a-9edf-7ad23468c0b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'schema_syll': 'CREATE TABLE table_42529(\"Mùa giải\" real,\"Division\" text,\"Thắng\" real,\"Thua\" real,\"Hòa\" real,\"Vị trí cuối cùng\" text,\"Ghi chú\" text)',\n",
       " 'schema_word': 'CREATE TABLE table_42529(\"Mùa giải\" real,\"Division\" text,\"Thắng\" real,\"Thua\" real,\"Hòa\" real,\"Vị_trí cuối_cùng\" text,\"Ghi_chú\" text)',\n",
       " 'query_syll': 'SELECT AVG(\"Hòa\") FROM table_42529 WHERE \"Thắng\" = \\'6\\' AND \"Mùa giải\" > \\'2004\\'',\n",
       " 'source': 'wikisql',\n",
       " 'question_syll': 'Số trận hòa trung bình có được khi một đội thắng 6 và đã qua mùa giải 2004 là bao nhiêu?',\n",
       " 'question_word': 'Số trận hòa trung_bình có được khi một đội thắng 6 và đã qua mùa giải 2004 là bao_nhiêu ?',\n",
       " 'query_word': 'SELECT AVG(\"Hòa\") FROM table_42529 WHERE \"Thắng\" = \"6\" AND \"Mùa giải\" > \"2004\"',\n",
       " 'label': 11,\n",
       " 'few_shot_idx_syll': [86608, 201650, 129672, 75391, 138818]}"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display row 1 of the train split, including any few-shot index column it carries.\n",
    "dataset['train'][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "3ee8e5ca-9b77-4cda-8db4-8dfeccda4a13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset = dataset.rename_column(\"few_shot_idx\", \"few_shot_idx_syll\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ec31778-f703-4311-94bf-ab83db54dad4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload `dataset` to the Hub under the same repo id it was loaded from above.\n",
    "dataset.push_to_hub('hoangphu7122002ai/text2sql_vi')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "NLP",
   "language": "python",
   "name": "nlp"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
