{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6009b88c-513f-477b-9b55-16b769bd14ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import itertools\n",
    "import os\n",
    "\n",
    "os.environ['CUDA_DEVICE_ORDER'] = \"PCI_BUS_ID\"\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "import pandas as pd\n",
    "import wandb\n",
    "\n",
    "from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses\n",
    "\n",
    "from datasets import load_dataset, Dataset, DatasetDict\n",
    "from huggingface_hub.hf_api import HfFolder\n",
    "# Read the Hugging Face token from the HF_TOKEN environment variable instead of\n",
    "# writing a literal into the notebook. When HF_TOKEN is unset nothing is saved,\n",
    "# so an empty token is never persisted.\n",
    "hf_token = os.environ.get('HF_TOKEN')\n",
    "if hf_token:\n",
    "    HfFolder.save_token(hf_token)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76141093-65f2-4860-8d9f-09fdc50610e0",
   "metadata": {},
   "source": [
    "# 1. Prepare data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad64d396-7e2c-4f78-87db-fdc927c3f4a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset('strongpear/text2sql_positive_negatives')\n",
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1fe7228-5881-4aa5-9751-697879a7f98b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = ds['train'].to_pandas()\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee473706-0d81-4898-831a-49d4377a883b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One frame per training phase, restricted to the triplet columns and re-indexed from 0.\n",
    "phase1_df = (\n",
    "    df.loc[df['phase'] == 1, ['anchor', 'positive', 'negatives']]\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "phase2_df = (\n",
    "    df.loc[df['phase'] == 2, ['anchor', 'positive', 'negatives']]\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "phase3_df = (\n",
    "    df.loc[df['phase'] == 3, ['anchor', 'positive', 'negatives']]\n",
    "    .reset_index(drop=True)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6ab225c-5844-4bc4-b6f0-42f7854aa0b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the '#####'-joined strings of phases 2 and 3 into Python lists.\n",
    "for frame in (phase2_df, phase3_df):\n",
    "    for column in ('positive', 'negatives'):\n",
    "        frame[column] = frame[column].str.split('#####')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c63a9d6-90ef-4357-8190-888873bf2602",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cartesian_expand(row):\n",
    "    \"\"\"Expand one sample into every (anchor, positive, negative) combination.\n",
    "\n",
    "    Args:\n",
    "        row: Mapping-like record (for example a DataFrame row) holding an 'anchor'\n",
    "            value and iterable 'positive' and 'negatives' entries.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame with columns 'anchor', 'positive', 'negatives' and one row per\n",
    "        positive/negative pair; it has no rows when either iterable is empty.\n",
    "    \"\"\"\n",
    "    # `itertools` comes from the notebook's import cell.\n",
    "    triplets = list(itertools.product([row['anchor']], row['positive'], row['negatives']))\n",
    "    return pd.DataFrame(triplets, columns=['anchor', 'positive', 'negatives'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "252609c8-a35c-4eee-9135-d604861b6f20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expand every phase-2 / phase-3 row into its triplets and stack them into one frame per phase.\n",
    "phase2_cartesian_df = pd.concat(\n",
    "    [cartesian_expand(phase2_df.iloc[position]) for position in phase2_df.index],\n",
    "    ignore_index=True,\n",
    ").reset_index(drop=True)\n",
    "\n",
    "phase3_cartesian_df = pd.concat(\n",
    "    [cartesian_expand(phase3_df.iloc[position]) for position in phase3_df.index],\n",
    "    ignore_index=True,\n",
    ").reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6041957-b026-4156-86f9-4ed287786c99",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Before:', phase2_df.shape[0])\n",
    "print('After:', phase2_cartesian_df.shape[0])\n",
    "\n",
    "print('Before:', phase3_df.shape[0])\n",
    "print('After:', phase3_cartesian_df.shape[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b29e29d-af24-46b5-b4e8-f44e1ce7cb1f",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d2767b-51a2-401a-b4cf-093003b19b1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase 1 only provides (anchor, positive) pairs.\n",
    "phase1_anchor = phase1_df['anchor'].tolist()\n",
    "phase1_positive = phase1_df['positive'].tolist()\n",
    "\n",
    "# Phases 2 and 3 provide the expanded (anchor, positive, negative) triplets.\n",
    "phase2_anchor = phase2_cartesian_df['anchor'].tolist()\n",
    "phase2_positive = phase2_cartesian_df['positive'].tolist()\n",
    "phase2_negative = phase2_cartesian_df['negatives'].tolist()\n",
    "\n",
    "phase3_anchor = phase3_cartesian_df['anchor'].tolist()\n",
    "phase3_positive = phase3_cartesian_df['positive'].tolist()\n",
    "phase3_negative = phase3_cartesian_df['negatives'].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6472b85d-a694-412a-9170-5e74d7b52bab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the column lists as datasets; note the 'negatives' frame column becomes 'negative' here.\n",
    "phase1_ds = Dataset.from_dict(dict(anchor=phase1_anchor, positive=phase1_positive))\n",
    "\n",
    "phase2_ds = Dataset.from_dict(\n",
    "    dict(anchor=phase2_anchor, positive=phase2_positive, negative=phase2_negative)\n",
    ")\n",
    "\n",
    "phase3_ds = Dataset.from_dict(\n",
    "    dict(anchor=phase3_anchor, positive=phase3_positive, negative=phase3_negative)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c83e6cbe-ea62-46af-b5db-3fd3e8538448",
   "metadata": {},
   "source": [
    "# 2. Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04454aff-e599-4562-83b0-91764e4fabc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a3948a0-c2b6-4b6b-a5f7-54d162e869cd",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Phase 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7688db60-a8a1-4886-a491-3a6ba64a072d",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = 'strongpear/bge-m3-Vi-Text2SQL-mlm'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd571a1d-dd13-407a-80ec-eee7449434b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase-1 checkpoints go under ./models/, the same tree the later phases write to\n",
    "# and the phase-2 cell loads its starting checkpoint from.\n",
    "output_dir = f'./models/retrieve_models/checkpoint/{MODEL_NAME}-{datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24610f96-bace-47d4-a2b5-fc74588c589a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SentenceTransformer(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68c04736-cbea-4ffb-b5f6-3d41e72bc8b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "phase1_ds = phase1_ds.train_test_split(test_size=0.05, shuffle=True, seed=17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec93650e-a8fb-48bc-9c46-fe400673fe5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(project='text2sql', name='M3-retriever-Vi-Text2SQL_phase1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df91c25e-e715-4bd9-b9b7-4f07f85d394e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase 1: train on the (anchor, positive) pairs of phase1_ds with MultipleNegativesRankingLoss.\n",
    "training_args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    eval_strategy='steps',\n",
    "    eval_steps=500,\n",
    "    learning_rate=1e-5,\n",
    "    num_train_epochs=3,\n",
    "    weight_decay=0.1,\n",
    "    push_to_hub=False,\n",
    "    no_cuda=False,\n",
    "    report_to='wandb',\n",
    "    # NOTE(review): '/.logs' is an absolute path at the filesystem root; confirm a relative directory was not intended.\n",
    "    logging_dir='/.logs',\n",
    "    logging_steps=500,\n",
    "    # 8 samples per device x 4 accumulation steps = 32 samples per optimizer step per device.\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=8,\n",
    "    gradient_accumulation_steps=4,\n",
    "    fp16=True,\n",
    "    save_total_limit=4,\n",
    "    save_steps=500,\n",
    "    warmup_steps=500,\n",
    "    remove_unused_columns=False\n",
    ")\n",
    "\n",
    "loss = losses.MultipleNegativesRankingLoss(model)\n",
    "\n",
    "# Train on the 95% split and evaluate on the 5% held-out split created by train_test_split.\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=phase1_ds['train'],\n",
    "    eval_dataset=phase1_ds['test'],\n",
    "    loss=loss\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6aa55306-845e-421a-8524-10eb256fa0ca",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Phase 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96edd75d-8b02-4d9b-831b-d01e81e407de",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = 'M3-retriever-Vi-Text2SQL_phase2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ff8c1f6-b19e-484c-9401-5e5c76fbdff2",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = f'./models/retrieve_models/checkpoint/{MODEL_NAME}-{datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25b56dcf-aa4c-48e2-90b3-0c97bd6a6eba",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SentenceTransformer('models/retrieve_models/checkpoint/strongpear/bge-m3-Vi-Text2SQL-mlm-2024-09-21_16-27-29/checkpoint-4000/')\n",
    "tokenizer = model.tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc9d27c3-dabc-46fe-8e7b-cfd1060bd22b",
   "metadata": {},
   "outputs": [],
   "source": [
    "phase2_ds = phase2_ds.train_test_split(test_size=0.05, shuffle=True, seed=17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cec71781-75c4-4674-bfb5-48bc8f65d865",
   "metadata": {},
   "outputs": [],
   "source": [
    "phase2_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e33ed990-02ae-4199-96a1-c42d6d0c8885",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_max_length(example, max_length=1024):\n",
    "    \"\"\"Tell whether anchor, positive and negative each encode to at most `max_length` token ids.\n",
    "\n",
    "    Encoding uses the notebook-level `tokenizer` with truncation disabled; fields are\n",
    "    checked in order and the first one that is too long ends the check.\n",
    "    \"\"\"\n",
    "    for field in ('anchor', 'positive', 'negative'):\n",
    "        token_ids = tokenizer.encode(example[field], truncation=False)\n",
    "        if len(token_ids) > max_length:\n",
    "            return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1df5ff72-f962-4916-b7b9-74c1258c6c7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_phase2_ds = phase2_ds.filter(lambda x: filter_max_length(x, max_length=1024))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2788f48a-769d-4136-b5e8-dca779323130",
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_phase2_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd0a1f69-31cc-4ba5-9057-1c7431ca828f",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(project='text2sql', name='M3-retriever-Vi-Text2SQL_phase2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d69cbb14-a655-49c1-9f3a-00411963ea68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase 2: train on the length-filtered (anchor, positive, negative) triplets with TripletLoss.\n",
    "training_args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    eval_strategy='steps',\n",
    "    eval_steps=2000,\n",
    "    learning_rate=1e-5,\n",
    "    num_train_epochs=1,\n",
    "    weight_decay=0.1,\n",
    "    push_to_hub=False,\n",
    "    no_cuda=False,\n",
    "    report_to='wandb',\n",
    "    # NOTE(review): '/.logs' is an absolute path at the filesystem root; confirm a relative directory was not intended.\n",
    "    logging_dir='/.logs',\n",
    "    logging_steps=1000,\n",
    "    # 8 samples per device x 4 accumulation steps = 32 samples per optimizer step per device.\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=8,\n",
    "    gradient_accumulation_steps=4,\n",
    "    fp16=True,\n",
    "    save_total_limit=4,\n",
    "    # Checkpoints are written every 500 steps while evaluation runs every 2000 steps.\n",
    "    save_steps=500,\n",
    "    warmup_steps=500,\n",
    "    remove_unused_columns=False\n",
    ")\n",
    "\n",
    "loss = losses.TripletLoss(model)\n",
    "\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=filtered_phase2_ds['train'],\n",
    "    eval_dataset=filtered_phase2_ds['test'],\n",
    "    loss=loss\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6f57ddf-a1e0-456c-ac85-f6254e0f5448",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9c02fe9-6b76-4a1c-8af1-94f5cc195ea0",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Phase 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d02da3c-88ae-4d12-92f1-7406378f72e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = 'M3-retriever-Vi-Text2SQL_phase3_continue'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25c1766e-3ef0-45f8-a866-3b7c3178fb15",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = f'./models/retrieve_models/checkpoint/{MODEL_NAME}-{datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b75d092-e733-474b-b730-2b4f71318070",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SentenceTransformer('strongpear/M3-retriever-Vi-Text2SQL')\n",
    "tokenizer = model.tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a838b7d5-8853-4650-9fa5-76a7a3bf9dae",
   "metadata": {},
   "outputs": [],
   "source": [
    "phase3_ds = phase3_ds.train_test_split(test_size=0.05, shuffle=True, seed=17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61ceeb61-8a98-4e8f-9d96-1fc79ae2d570",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_max_length(example, max_length=1024):\n",
    "    \"\"\"Tell whether anchor, positive and negative each encode to at most `max_length` token ids.\n",
    "\n",
    "    Encoding uses the notebook-level `tokenizer` with truncation disabled; fields are\n",
    "    checked in order and the first one that is too long ends the check.\n",
    "    \"\"\"\n",
    "    for field in ('anchor', 'positive', 'negative'):\n",
    "        token_ids = tokenizer.encode(example[field], truncation=False)\n",
    "        if len(token_ids) > max_length:\n",
    "            return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34f709d7-5e28-406e-8f56-ba091d402a0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_phase3_ds = phase3_ds.filter(lambda x: filter_max_length(x, max_length=1024))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6562cef3-c2a3-4425-80f2-d5998acf7417",
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_phase3_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc63fceb-b42c-40f5-9c9e-543583104917",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(project='text2sql', name='M3-retriever-Vi-Text2SQL_phase3_continue')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dac95129-1353-4161-a38f-0104e936b641",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase 3: continue training on the length-filtered phase-3 triplets with TripletLoss.\n",
    "training_args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    eval_strategy='steps',\n",
    "    eval_steps=1000,\n",
    "    # NOTE(review): roughly three orders of magnitude below the 1e-5 used by the other phases; confirm this is intended.\n",
    "    learning_rate=1.08e-8,\n",
    "    num_train_epochs=1,\n",
    "    weight_decay=0.1,\n",
    "    push_to_hub=False,\n",
    "    no_cuda=False,\n",
    "    report_to='wandb',\n",
    "    # NOTE(review): '/.logs' is an absolute path at the filesystem root; confirm a relative directory was not intended.\n",
    "    logging_dir='/.logs',\n",
    "    logging_steps=1000,\n",
    "    # 8 samples per device x 4 accumulation steps = 32 samples per optimizer step per device.\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=8,\n",
    "    gradient_accumulation_steps=4,\n",
    "    fp16=True,\n",
    "    save_total_limit=4,\n",
    "    save_steps=500,\n",
    "    warmup_steps=500,\n",
    "    remove_unused_columns=False\n",
    ")\n",
    "\n",
    "loss = losses.TripletLoss(model)\n",
    "\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=filtered_phase3_ds['train'],\n",
    "    eval_dataset=filtered_phase3_ds['test'],\n",
    "    loss=loss\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d90558e-28a8-4dfe-bbe8-7bd82c0794e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee39d207-1ec0-415f-816f-064698815f82",
   "metadata": {},
   "source": [
    "# Push to hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70fb8234-1509-4215-8619-2cb49f32180d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SentenceTransformer('models/retrieve_models/checkpoint/M3-retriever-Vi-Text2SQL_phase3_continue-2024-09-26_10-24-27/checkpoint-41026/')\n",
    "tokenizer = model.tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74e8746f-2ad6-477c-8d34-39d25e9aa7c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.push_to_hub('strongpear/M3-retriever-Vi-Text2SQL_ver2')\n",
    "tokenizer.push_to_hub('strongpear/M3-retriever-Vi-Text2SQL_ver2')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75b8196a-0f33-4cd6-85e0-a3574171a36e",
   "metadata": {},
   "source": [
    "# Fine-tune Vin Text2SQL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaa7efb7-0772-492e-a5b1-1b58d98cf5de",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_DEVICE_ORDER'] = \"PCI_BUS_ID\"\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "import pandas as pd\n",
    "import itertools\n",
    "import random\n",
    "import datetime\n",
    "import wandb\n",
    "\n",
    "from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses\n",
    "\n",
    "from datasets import load_dataset, Dataset, DatasetDict\n",
    "from huggingface_hub.hf_api import HfFolder\n",
    "# Credentials must not live in the notebook: read the Hugging Face token from the\n",
    "# HF_TOKEN environment variable. When it is unset nothing is saved here.\n",
    "# NOTE: a token that was ever committed in a notebook should be treated as leaked and revoked.\n",
    "hf_token = os.environ.get('HF_TOKEN')\n",
    "if hf_token:\n",
    "    HfFolder.save_token(hf_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3753cfaf-380e-4c50-84c5-c7eb83a9c308",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SentenceTransformer('strongpear/M3-retriever-Vi-Text2SQL_ver2')\n",
    "tokenizer = model.tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96f6292b-ba06-43e9-9d3d-cc6ebc54eace",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset('ikura31/vin_dataset_text2sql_processed')\n",
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba1781aa-2b7a-47e6-9e74-80c5f135fc97",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = ds['train'].to_pandas()[['question', 'prompt', 'mini_schema']]\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85ce6d39-0b9e-46f7-a508-f7912c93f02e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(sample):\n",
    "    \"\"\"Return the stripped text between the DATABASE SCHEMA marker and the user-question marker.\n",
    "\n",
    "    Raises IndexError when the DATABASE SCHEMA marker is absent from `sample`.\n",
    "    \"\"\"\n",
    "    after_header = sample.split('==== DATABASE SCHEMA ====\\n')[1]\n",
    "    schema_text = after_header.split('\\n    ==== Câu hỏi người dùng ====\\n')[0]\n",
    "    return schema_text.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d88b35ab-772b-477e-bd9a-a262f30c36ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['schema_syll'] = df['prompt'].map(preprocess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72a8b9a4-22b7-43c1-a05d-4d4bc64626b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cf16a35-f9d3-4227-be93-bf6c8e29121e",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 625"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b4ba213-e3b9-4055-84df-12efbf3c09c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.iloc[i]['question']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5928ed1d-634e-4c4b-a856-1b7206b4bd1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.iloc[i]['schema_syll']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf436b7-53b7-4919-a027-cc5184e4e6a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.iloc[i]['mini_schema']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb5409b4-87ff-4e61-be79-a374a7ddcdcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(sample):\n",
    "    \"\"\"Normalise a schema string into '#####'-separated CREATE TABLE statements.\n",
    "\n",
    "    Args:\n",
    "        sample: Schema text, optionally wrapped in a Markdown sql code fence.\n",
    "\n",
    "    Returns:\n",
    "        str: One 'CREATE TABLE ...' statement per table, joined by '#####'.\n",
    "    \"\"\"\n",
    "    # clean schema chatgpt gen: drop the code fence and every newline\n",
    "    sample = sample.replace('```sql\\n', '').replace('\\n```', '').replace('\\n', '')\n",
    "\n",
    "    # re-format list of table; strip before testing so a whitespace-only fragment\n",
    "    # (e.g. blanks before the first CREATE TABLE) does not become a bare 'CREATE TABLE ' entry\n",
    "    fragments = (fragment.strip() for fragment in sample.split('CREATE TABLE'))\n",
    "    schema = ['CREATE TABLE ' + fragment for fragment in fragments if fragment]\n",
    "\n",
    "    # re-format -> #####\n",
    "    return '#####'.join(schema)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "371a2423-58f6-4da3-bb38-4fa6954b10d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['mini_schema_syll_reformat'] = df['mini_schema'].map(preprocess)\n",
    "df['schema_syll_reformat'] = df['schema_syll'].map(preprocess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c4e4838-efe7-42a6-a653-5594ddcf8bd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every distinct, non-empty table definition seen across all full schemas.\n",
    "# sorted() gives a stable order: iteration order of a set of strings is not guaranteed\n",
    "# to be the same between interpreter runs, which would make random draws from this\n",
    "# list irreproducible.\n",
    "schema_list = sorted({\n",
    "    table\n",
    "    for sample in df['schema_syll_reformat'].tolist()\n",
    "    for table in sample.split('#####')\n",
    "    if table\n",
    "})\n",
    "len(schema_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee0e01f-d684-475f-9d7f-bd920cd8d6bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every sample derive two '#####'-joined strings:\n",
    "#   hard_negatives - tables of the sample's own schema that are not in its mini schema\n",
    "#   negatives      - up to 3 tables drawn at random from schema_list that belong to neither\n",
    "df['hard_negatives'], df['negatives'] = None, None\n",
    "for row_index, row in df.iterrows():\n",
    "    mini_schema = row['mini_schema_syll_reformat'].split('#####')\n",
    "    mini_schema_set = set(mini_schema)\n",
    "\n",
    "    # Filter rather than list.remove(): remove() drops only the first occurrence, so a\n",
    "    # table repeated in the full schema would otherwise survive as its own hard negative.\n",
    "    schema = [\n",
    "        table\n",
    "        for table in row['schema_syll_reformat'].split('#####')\n",
    "        if table not in mini_schema_set\n",
    "    ]\n",
    "    hard_negatives = '#####'.join(schema)\n",
    "\n",
    "    # Sample without replacement from the eligible tables. Retrying random.choice until\n",
    "    # 3 distinct tables are found never terminates when fewer than 3 are eligible.\n",
    "    excluded = mini_schema_set | set(schema)\n",
    "    candidates = [table for table in schema_list if table not in excluded]\n",
    "    negatives = random.sample(candidates, min(3, len(candidates)))\n",
    "    negatives = '#####'.join(negatives)\n",
    "\n",
    "    df.at[row_index, 'hard_negatives'] = hard_negatives\n",
    "    df.at[row_index, 'negatives'] = negatives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "130caccf-6ec6-4cff-954b-b4f2df740aea",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dda6c439-f14b-4c08-a226-70e63083174b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df[['question', 'mini_schema_syll_reformat', 'hard_negatives']]\n",
    "df.columns = ['anchor', 'positive', 'negatives']\n",
    "df = df.reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25759088-095f-4a59-b98a-777cd428fa4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the '#####'-joined strings into lists and drop empty entries: an empty string\n",
    "# splits to [''], which would otherwise produce triplets whose positive or negative\n",
    "# text is ''. A sample left with an empty list contributes no triplets.\n",
    "df['positive'] = df['positive'].str.split('#####').map(lambda tables: [table for table in tables if table])\n",
    "df['negatives'] = df['negatives'].str.split('#####').map(lambda tables: [table for table in tables if table])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f1fa4fd-d627-48d0-9f46-f57bfaf80ec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cartesian_expand(row):\n",
    "    \"\"\"Build the triplet frame for one sample: its anchor with every positive x negative pair.\"\"\"\n",
    "    triplets = itertools.product([row['anchor']], row['positive'], row['negatives'])\n",
    "    triplet_columns = ['anchor', 'positive', 'negatives']\n",
    "    return pd.DataFrame(triplets, columns=triplet_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fca5f8a6-4381-460a-8b20-e81c3d436e4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expand every sample into its (anchor, positive, negative) triplets and stack them into one frame.\n",
    "cartesian_df = pd.concat(\n",
    "    [cartesian_expand(df.iloc[position]) for position in df.index],\n",
    "    ignore_index=True,\n",
    ").reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab1d78ff-064a-4197-9963-d4a4f3f6662b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Before:', df.shape[0])\n",
    "print('After:', cartesian_df.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4a456c8-2788-456d-860f-b92e04c0e638",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle the triplets with a fixed seed and build the column lists from the shuffled\n",
    "# frame, so the shuffle actually reaches the dataset and is repeatable across runs.\n",
    "df = cartesian_df.sample(frac=1, random_state=17)\n",
    "anchor = df['anchor'].values.tolist()\n",
    "positive = df['positive'].values.tolist()\n",
    "negative = df['negatives'].values.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8dcffe5-97fd-4c42-8f85-f065660f3112",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = Dataset.from_dict({\n",
    "    'anchor': anchor,\n",
    "    'positive': positive,\n",
    "    'negative': negative\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4065b77-f860-4b65-a6c6-ba08e73613e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = 'M3-retriever-Vi-Text2SQL_ver2_finetune-Vin-5epochs'\n",
    "output_dir = f'./models/retrieve_models/checkpoint/{MODEL_NAME}-{datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b840ff96-783d-49e7-b0d7-78c503db4af3",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = ds.train_test_split(test_size=0.05, shuffle=True, seed=17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da75d282-60e8-4823-a22c-85f05fe5bcf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_max_length(example, max_length=1024):\n",
    "    \"\"\"Tell whether anchor, positive and negative each encode to at most `max_length` token ids.\n",
    "\n",
    "    Encoding uses the notebook-level `tokenizer` with truncation disabled; fields are\n",
    "    checked in order and the first one that is too long ends the check.\n",
    "    \"\"\"\n",
    "    for field in ('anchor', 'positive', 'negative'):\n",
    "        token_ids = tokenizer.encode(example[field], truncation=False)\n",
    "        if len(token_ids) > max_length:\n",
    "            return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94835cf2-1bad-4a5a-9e94-70188f61be0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_ds = ds.filter(lambda x: filter_max_length(x, max_length=1024))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e419db0d-bcc3-4235-ba1b-0720fad7f246",
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6b7ea33-c158-44d2-ab29-a59d9288b046",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(project='text2sql', name=MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1df6534d-0e60-4e30-804b-08bd5acba270",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune on the length-filtered Vin triplets (filtered_ds) with TripletLoss for 5 epochs.\n",
    "training_args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    eval_strategy='steps',\n",
    "    eval_steps=500,\n",
    "    learning_rate=1e-5,\n",
    "    num_train_epochs=5,\n",
    "    weight_decay=0.1,\n",
    "    push_to_hub=False,\n",
    "    no_cuda=False,\n",
    "    report_to='wandb',\n",
    "    # NOTE(review): '/.logs' is an absolute path at the filesystem root; confirm a relative directory was not intended.\n",
    "    logging_dir='/.logs',\n",
    "    logging_steps=1000,\n",
    "    # 8 samples per device x 4 accumulation steps = 32 samples per optimizer step per device.\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=8,\n",
    "    gradient_accumulation_steps=4,\n",
    "    fp16=True,\n",
    "    save_total_limit=4,\n",
    "    save_steps=500,\n",
    "    warmup_steps=500,\n",
    "    remove_unused_columns=False\n",
    ")\n",
    "\n",
    "loss = losses.TripletLoss(model)\n",
    "\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=filtered_ds['train'],\n",
    "    eval_dataset=filtered_ds['test'],\n",
    "    loss=loss\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78d3e99c-62a2-4ca1-b1e5-77fb2d0aa8b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c540136e-3208-4d7a-a8f6-0bc2eead6ab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SentenceTransformer('models/retrieve_models/checkpoint/M3-retriever-Vi-Text2SQL_ver2_finetune-Vin-5epochs-2024-09-27_05-21-04/checkpoint-9965/')\n",
    "tokenizer = model.tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c9b6c46-925a-40de-9464-0beb3c9ac339",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.push_to_hub('strongpear/M3-retriever-Vi-Text2SQL_ver2-Vin-5epochs')\n",
    "tokenizer.push_to_hub('strongpear/M3-retriever-Vi-Text2SQL_ver2-Vin-5epochs')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b09259a-e40b-406f-ad13-9fcf4c6c3ec7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
