{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f4e6e4b-8f72-4447-848b-30a1b0751e5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from datasets import load_dataset\n",
    "from huggingface_hub.hf_api import HfFolder\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# Authenticate with the Hugging Face Hub only when a token is supplied via the\n",
    "# HF_TOKEN environment variable; never hardcode credentials in the notebook.\n",
    "# Calling save_token('') would store an empty token over any previously cached one.\n",
    "HF_TOKEN = os.environ.get('HF_TOKEN')\n",
    "if HF_TOKEN:\n",
    "    HfFolder.save_token(HF_TOKEN)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aed79961-a2d6-40ef-a29c-c67b8f5433e7",
   "metadata": {},
   "source": [
    "# 1. Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60cbca1a-a6bf-48d2-99f4-ed374b695b64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: replace the placeholder with the benchmark dataset id on the Hub.\n",
    "DATASET_NAME = 'your_benchmark_data'\n",
    "ds = load_dataset(DATASET_NAME)\n",
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40874254-8cd1-41ff-8170-7b4e3223e69c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the columns the retrieval evaluation reads.\n",
    "EVAL_COLUMNS = ['question', 'schema', 'mini_schema']\n",
    "test_split = ds['test']\n",
    "test_df = test_split.to_pandas()[EVAL_COLUMNS]\n",
    "test_df.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b81a08d8-151e-4724-b352-efae0c83f38d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(sample):\n",
    "    \"\"\"Normalize a schema string into '#####'-separated CREATE TABLE statements.\n",
    "\n",
    "    Strips the ```sql ... ``` fence around LLM-generated schemas, splits the text on\n",
    "    'CREATE TABLE', and re-joins the non-empty statements with '#####', so tables from\n",
    "    the 'schema' and 'mini_schema' columns can be compared by exact string match.\n",
    "\n",
    "    Idempotent: passing the output back in returns it unchanged, so re-running the\n",
    "    cleaning cell does not corrupt already-processed text.\n",
    "    \"\"\"\n",
    "    # Remove the markdown code fence around generated SQL.\n",
    "    sample = sample.replace('```sql\\n', '').replace('\\n```', '')\n",
    "    # Treat an existing '#####' separator as a line break so already-processed text\n",
    "    # is re-split cleanly instead of leaving the separator glued to a table.\n",
    "    sample = sample.replace('#####', '\\n')\n",
    "\n",
    "    # Split into tables; skip whitespace-only fragments (e.g. a leading newline before\n",
    "    # the first CREATE TABLE) so no bare 'CREATE TABLE ' entry is produced.\n",
    "    tables = sample.split('CREATE TABLE')\n",
    "    schema = ['CREATE TABLE ' + table.strip() for table in tables if table.strip()]\n",
    "\n",
    "    return '#####'.join(schema)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b3c5cb5-320f-45d9-83f4-0f3f5b43fb68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize both schema columns so tables can be matched by exact string comparison.\n",
    "for column in ('schema', 'mini_schema'):\n",
    "    test_df[column] = test_df[column].map(preprocess)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9da146cb-81f6-460e-a1a6-a325b3f1c233",
   "metadata": {},
   "source": [
    "# 2. Retriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17ad7f5f-d91c-4861-b727-2c0eff90c59e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): e5 models are documented to expect 'query: ' / 'passage: ' input\n",
    "# prefixes; the retrieval code below encodes raw text - confirm this is intended.\n",
    "MODEL_NAME = 'intfloat/multilingual-e5-large'\n",
    "model = SentenceTransformer(MODEL_NAME)\n",
    "tokenizer = model.tokenizer  # kept for inspection; not used by the evaluation below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c87d9360-e7e2-4759-9916-950ddeaac2c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: a second retriever checkpoint (presumably fine-tuned) to compare against\n",
    "# the base model; uncomment and pass model_v2 to evaluate() to use it.\n",
    "# model_v2 = SentenceTransformer('strongpear/M3-retriever-Vi-Text2SQL_ver2-Vin-5epochs')\n",
    "# tokenizer_v2 = model_v2.tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b9c725d-bc35-4ff8-925e-c49062218b19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve_top_k(model, question, corpus, p, k=None):\n",
    "    \"\"\"Return corpus entries whose similarity to `question` exceeds `p`, best first.\n",
    "\n",
    "    Args:\n",
    "        model: SentenceTransformer-like object exposing `encode` and `similarity`.\n",
    "        question: query text.\n",
    "        corpus: list of candidate texts (one CREATE TABLE statement each in `evaluate`).\n",
    "        p: similarity threshold; only entries scoring strictly above it are kept.\n",
    "        k: optional maximum number of entries to return (None = no limit).\n",
    "\n",
    "    Returns:\n",
    "        Corpus entries sorted by descending similarity; ties keep corpus order.\n",
    "    \"\"\"\n",
    "    if not corpus:\n",
    "        return []\n",
    "\n",
    "    embed_question = model.encode(question, convert_to_tensor=True)\n",
    "    embed_corpus = model.encode(corpus, convert_to_tensor=True)\n",
    "    # A single query was encoded, so take the first row of the similarity matrix.\n",
    "    scores = model.similarity(embed_question, embed_corpus)[0].tolist()\n",
    "\n",
    "    # Use enumerate, not scores.index(score): index() returns the first position with\n",
    "    # that value, so tied scores would map to the same entry. Sorting by score makes\n",
    "    # the list position a real rank, which evaluate() uses for MRR / Recall@k.\n",
    "    kept = [i for i, score in enumerate(scores) if score > p]\n",
    "    kept.sort(key=lambda i: scores[i], reverse=True)\n",
    "    if k is not None:\n",
    "        kept = kept[:k]\n",
    "    return [corpus[i] for i in kept]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2c8b450-b9b3-428d-a165-97cd35beedc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_data, p=0.75, max_retrieved=3, ks=(1, 2, 3)):\n",
    "    \"\"\"Score the retriever on `test_data` with MRR and Recall@k.\n",
    "\n",
    "    For each row, the candidates are the '#####'-separated tables in `row['schema']`\n",
    "    and the relevant ones are those in `row['mini_schema']`. A row's rank is the\n",
    "    1-based position of the best-ranked relevant table among the retrieved ones.\n",
    "    If more than `max_retrieved` tables are returned, the row's contribution is\n",
    "    scaled by `max_retrieved / len(retrieved)` to penalize over-retrieval.\n",
    "\n",
    "    Note: 'Recall@k' counts a row as a hit when *any* relevant table is within the\n",
    "    top k (a hit rate), not the fraction of relevant tables recovered.\n",
    "\n",
    "    Args:\n",
    "        model: retriever forwarded to `retrieve_top_k`.\n",
    "        test_data: DataFrame with 'question', 'schema' and 'mini_schema' columns.\n",
    "        p: similarity threshold forwarded to `retrieve_top_k`.\n",
    "        max_retrieved: number of retrieved tables allowed before the penalty applies.\n",
    "        ks: cutoffs reported as 'Recall@k'.\n",
    "\n",
    "    Returns:\n",
    "        Dict with 'MRR' followed by one 'Recall@k' entry per k in `ks`.\n",
    "        All values are 0.0 when `test_data` is empty.\n",
    "    \"\"\"\n",
    "    mrr = 0.0\n",
    "    recall_at_k = {k: 0.0 for k in ks}\n",
    "    total = len(test_data)\n",
    "\n",
    "    for i, row in test_data.iterrows():\n",
    "        ground_truths = row['mini_schema'].split('#####')\n",
    "        corpus = row['schema'].split('#####')\n",
    "        retrieved = retrieve_top_k(model, row['question'], corpus, p)\n",
    "        # Debug hook: store per-row retrievals (requires a 'retrieved' column).\n",
    "        # test_data.at[i, 'retrieved'] = '#####'.join(retrieved)\n",
    "\n",
    "        # Best (smallest) 1-based rank among the relevant tables that were retrieved.\n",
    "        ranks = [retrieved.index(gt) + 1 for gt in ground_truths if gt in retrieved]\n",
    "        if not ranks:\n",
    "            continue  # miss: contributes 0 to every metric\n",
    "        best_rank = min(ranks)\n",
    "\n",
    "        # No penalty up to max_retrieved results; beyond that, scale down proportionally.\n",
    "        penalty = min(1, max_retrieved / len(retrieved))\n",
    "        mrr += 1 / best_rank * penalty\n",
    "        for k in recall_at_k:\n",
    "            if best_rank <= k:\n",
    "                recall_at_k[k] += penalty\n",
    "\n",
    "    # Average over all rows of test_data; guard against an empty split.\n",
    "    denom = total or 1\n",
    "    result = {'MRR': mrr / denom}\n",
    "    for k in ks:\n",
    "        result[f'Recall@{k}'] = recall_at_k[k] / denom\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29740697-dd69-4015-92e3-bb0f239d882d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: create a 'retrieved' column for the commented-out debug line in\n",
    "# evaluate() that records each row's retrieved tables; uncomment both to inspect.\n",
    "# test_df['retrieved'] = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "507d069a-5f25-452f-93e4-b8d4708d91fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Evaluate the retriever at each similarity threshold p; one result row per threshold.\n",
    "THRESHOLDS = [0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95]\n",
    "threshold_results = pd.DataFrame(\n",
    "    [{'threshold': p, **evaluate(model, test_df, p)} for p in THRESHOLDS]\n",
    ").set_index('threshold')\n",
    "threshold_results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
