{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9c36835-98d0-4cb3-ac68-611e3539f293",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_DEVICE_ORDER'] = \"PCI_BUS_ID\"\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "from transformers import AutoModelForMaskedLM, AutoTokenizer\n",
    "from transformers import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import glob\n",
    "import random\n",
    "import datetime\n",
    "import wandb\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from datasets import load_dataset, Dataset, DatasetDict\n",
    "from huggingface_hub.hf_api import HfFolder\n",
    "\n",
    "# Take the Hugging Face token from the HF_TOKEN environment variable so that no\n",
    "# credential is written into the notebook. When the variable is unset or empty,\n",
    "# nothing is saved and a previously stored login is left as it is.\n",
    "hf_token = os.environ.get('HF_TOKEN')\n",
    "if hf_token:\n",
    "    HfFolder.save_token(hf_token)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98e4f664-1e1d-4b5a-98bb-029810402868",
   "metadata": {},
   "source": [
    "# Phase 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4b98d71-6b35-4233-8943-aeec1409d719",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4e64f2a-2258-4f79-ab69-471552b3554c",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset('strongpear/text_for_mlm')\n",
    "texts = ds['train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87ec7a3f-fcd4-40c4-bd94-5145dbe054e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_test_texts = texts.train_test_split(test_size=0.1, shuffle=True, seed=17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5c5a5fd-ee96-4ac9-9c3e-6eac75230dfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = 'BAAI/bge-m3'\n",
    "model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1d2597-29c2-4f61-a49a-def39971e521",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eadddae-ee53-4d9f-b92d-6d7da254031b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(texts['text'], columns=['text'])\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56f77c27-1144-44f9-a7c7-c962ac3a1047",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_tokens(text):\n",
    "    \"\"\"Return the number of token ids the tokenizer produces for `text`.\n",
    "\n",
    "    The encoding is truncated to at most 8192 ids, so that is the largest value returned.\n",
    "    \"\"\"\n",
    "    return len(tokenizer.encode(text, max_length=8192, truncation=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0be8801f-d102-4d9d-af94-eeeb113190d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['max_length'] = df['text'].apply(count_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "737c60a9-fbbc-4daf-b71f-7395b1ee9d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "sns.histplot(data=df, x='max_length', bins=50, kde=True, ax=ax)\n",
    "# The 'max_length' column holds the per-text token count, so label the axis accordingly\n",
    "ax.set(title='Token count per text', xlabel='Tokens per text', ylabel='Number of texts');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1478fa65-fc35-4a05-843a-b22e04b9e0f6",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## DataCollator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b769de43-66ce-4f63-9a8a-3d8b2ba12098",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_function(batch):\n",
    "    \"\"\"Tokenize one batch of examples for masked-language-model training.\n",
    "\n",
    "    Args:\n",
    "        batch: Mapping whose 'text' entry holds the raw strings of the batch.\n",
    "\n",
    "    Returns:\n",
    "        The tokenizer output for the batch, each text truncated to at most 1024 tokens.\n",
    "        Sequences are returned unpadded, as plain Python lists.\n",
    "    \"\"\"\n",
    "    # No padding or tensor conversion here: padding inside a batched map would pad every\n",
    "    # row to the longest text of the whole map batch. The MLM data collator pads each\n",
    "    # training batch to its own longest sequence instead.\n",
    "    return tokenizer(\n",
    "        batch['text'],\n",
    "        max_length=1024,\n",
    "        truncation=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7441d90c-1d76-4ac7-acbd-304a8c8e9885",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize both splits the same way, dropping the original columns of each split\n",
    "tokenized_texts = DatasetDict({\n",
    "    split_name: split_data.map(\n",
    "        tokenize_function,\n",
    "        batched=True,\n",
    "        remove_columns=split_data.column_names,\n",
    "    )\n",
    "    for split_name, split_data in (\n",
    "        ('train', train_test_texts['train']),\n",
    "        ('test', train_test_texts['test']),\n",
    "    )\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "959e5483-a2e7-478a-9f37-8000af4054b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fall back to EOS for padding only when the tokenizer has no padding token of its own.\n",
    "# Replacing an existing pad token changes which ids the tokenizer reports as padding/special,\n",
    "# which in turn changes which positions the collator may select for masking.\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Masked-LM collator; mlm_probability is the per-token probability of being selected for masking\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3381c5d3-dd1e-4d42-8049-6297c63d8118",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "424a4488-1515-49f2-b65a-8191a96cea66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11853217-1984-425c-ad97-b7cfe530a85e",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = f'./mlm_models/checkpoint/{MODEL_NAME}-{datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06f1fcab-8cfa-49d1-b355-e1291db2aef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_dir = 'mlm_models/checkpoint/BAAI/bge-m3-2024-09-18_20-46-05/checkpoint-4500/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "546b2c2d-6323-45b0-8a4d-5da5cd23a3ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(project='text2sql', name='mlm_with_text_ver1', resume=True, id='zhy85e3b')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064a41a1-abd4-408b-babd-bdb7d0fbb6f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase 1 training configuration: evaluation and logging both happen every 500 steps,\n",
    "# metrics are reported to Weights & Biases, and mixed precision (fp16) is enabled.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    eval_strategy='steps',\n",
    "    eval_steps=500,\n",
    "    learning_rate=2e-5,\n",
    "    num_train_epochs=2,\n",
    "    weight_decay=0.01,\n",
    "    push_to_hub=False,\n",
    "    report_to=\"wandb\",\n",
    "    logging_dir=\"./logs\",\n",
    "    logging_steps=500,\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=8,\n",
    "    gradient_accumulation_steps=1,\n",
    "    fp16=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_texts['train'],\n",
    "    eval_dataset=tokenized_texts['test'],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# The checkpoint to resume from is passed here, where the Trainer actually reads it\n",
    "trainer.train(resume_from_checkpoint=checkpoint_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd2b3571-fab4-4704-8b56-48008dd50085",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eedc055-9897-4d3a-bc07-27771a16fb60",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForMaskedLM.from_pretrained('mlm_models/checkpoint/BAAI/bge-m3-2024-09-19_04-39-34/checkpoint-47872/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7368d99-05cb-426f-86b1-653db4560add",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.push_to_hub('strongpear/bge-m3-Vi-mlm')\n",
    "tokenizer.push_to_hub('strongpear/bge-m3-Vi-mlm')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b59b3994-2dc3-44bb-9ce7-350e59f03a28",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Phase 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fd1c2b1-b16f-4581-ab1a-741ed4d3fd3b",
   "metadata": {},
   "source": [
    "## Add SQL tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6e82dec-e142-484a-a59c-300aced0c0e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = 'strongpear/bge-m3-Vi-mlm'\n",
    "model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c5fd848-b9c6-4f2e-9a10-86b1853f8304",
   "metadata": {},
   "outputs": [],
   "source": [
    "SQL_KEYWORDS = [\n",
    "        \"SELECT\", \"FROM\", \"WHERE\", \"GROUP BY\", \"ORDER BY\", \"JOIN\", \"INNER JOIN\",\n",
    "        \"LEFT JOIN\", \"RIGHT JOIN\", \"OUTER JOIN\", \"ON\", \"AND\", \"OR\", \"NOT\", \"IN\",\n",
    "        \"LIKE\", \"BETWEEN\", \"IS NULL\", \"INSERT INTO\", \"VALUES\", \"UPDATE\", \"SET\",\n",
    "        \"DELETE FROM\", \"CREATE TABLE\", \"ALTER TABLE\", \"DROP TABLE\", \"DISTINCT\",\n",
    "        \"HAVING\", \"AS\", \"ASC\", \"DESC\", \"COUNT\", \"SUM\", \"AVG\", \"MAX\", \"MIN\", \"TOP\",\n",
    "        \"ALL\", \"ANY\", \"UNION\", \"EXCEPT\", \"INTERSECT\", \"CASE\", \"WHEN\", \"THEN\",\n",
    "        \"ELSE\", \"END\", \"BEGIN\", \"ROLLBACK\", \"COMMIT\", \"SAVEPOINT\", \"TRANSACTION\",\n",
    "        \"PRIMARY KEY\", \"FOREIGN KEY\", \"REFERENCES\", \"INDEX\", \"CONSTRAINT\"\n",
    "]\n",
    "len(SQL_KEYWORDS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38678b58-1888-450a-8ec2-ef192d7e2043",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AddedToken\n",
    "\n",
    "# Register every keyword as a whole-word token: with single_word=True a short keyword such as\n",
    "# 'OR', 'IN' or 'AS' is matched only when it stands alone, never inside a longer word.\n",
    "sql_tokens = [AddedToken(keyword, single_word=True) for keyword in SQL_KEYWORDS]\n",
    "num_added_tokens = tokenizer.add_tokens(sql_tokens)\n",
    "num_added_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "344252dc-ac01-4c01-8aa7-32ae64e2ab83",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.resize_token_embeddings(len(tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91fe5a81-357a-4302-8cbe-434425e0ed1b",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6221bbef-3665-406e-a146-7587259f4554",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset('strongpear/text2sql_for_mlm')\n",
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ece73ae-9a29-4c77-a21f-7af6d8d376c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "texts = ds['train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b897acf1-57a8-4a03-ae6d-cfb4fd1e986b",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_test_texts = texts.train_test_split(test_size=0.1, shuffle=True, seed=17)\n",
    "train_test_texts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f802b681-efab-4366-8d97-eabe1407bae2",
   "metadata": {},
   "source": [
    "## Preprocess data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c90444a7-6b30-4883-9be9-6ab80f492469",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(texts['text'], columns=['text'])\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48e9ffa7-c095-4a0c-abfd-3737ecf140cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_tokens(text):\n",
    "    \"\"\"Return the number of token ids the tokenizer produces for `text`.\n",
    "\n",
    "    The encoding is truncated to at most 8192 ids, so that is the largest value returned.\n",
    "    \"\"\"\n",
    "    return len(tokenizer.encode(text, max_length=8192, truncation=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aca728f1-b566-40ab-8099-dc15f78a6e9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "tqdm.pandas()\n",
    "\n",
    "df['max_length'] = df['text'].progress_apply(count_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c99663f-aeef-408c-a705-70de9a32e4b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "sns.histplot(data=df, x='max_length', bins=50, kde=True, ax=ax)\n",
    "# The 'max_length' column holds the per-text token count, so label the axis accordingly\n",
    "ax.set(title='Token count per text', xlabel='Tokens per text', ylabel='Number of texts');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5a05989-6bd9-4aa6-b99f-17c6d85ae13a",
   "metadata": {},
   "source": [
    "## DataCollator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2808f45-2a54-4f1b-b15d-c4802e08bfd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_function(batch):\n",
    "    \"\"\"Tokenize one batch of examples for masked-language-model training.\n",
    "\n",
    "    Args:\n",
    "        batch: Mapping whose 'text' entry holds the raw strings of the batch.\n",
    "\n",
    "    Returns:\n",
    "        The tokenizer output for the batch, each text truncated to at most 1024 tokens.\n",
    "        Sequences are returned unpadded, as plain Python lists.\n",
    "    \"\"\"\n",
    "    # No padding or tensor conversion here: padding inside a batched map would pad every\n",
    "    # row to the longest text of the whole map batch. The MLM data collator pads each\n",
    "    # training batch to its own longest sequence instead.\n",
    "    return tokenizer(\n",
    "        batch['text'],\n",
    "        max_length=1024,\n",
    "        truncation=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "432d316c-5c8d-45ce-a7ff-6b8f1ceb674a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize both splits the same way, dropping the original columns of each split\n",
    "tokenized_texts = DatasetDict({\n",
    "    split_name: split_data.map(\n",
    "        tokenize_function,\n",
    "        batched=True,\n",
    "        remove_columns=split_data.column_names,\n",
    "    )\n",
    "    for split_name, split_data in (\n",
    "        ('train', train_test_texts['train']),\n",
    "        ('test', train_test_texts['test']),\n",
    "    )\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73b9be32-01fb-4645-bb82-3e689a74ccc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fall back to EOS for padding only when the tokenizer has no padding token of its own.\n",
    "# Replacing an existing pad token changes which ids the tokenizer reports as padding/special,\n",
    "# which in turn changes which positions the collator may select for masking.\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Masked-LM collator; mlm_probability is the per-token probability of being selected for masking\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d293c506-5d15-44a5-8e12-c0ca44e366d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_texts['train']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "707462ab-571e-405c-9f99-f5197b45005b",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77bfb5de-4121-4077-a47b-aa968fde3825",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "170c7586-5d05-4af9-b973-4ff3d5f0aac0",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = f'./mlm_models/checkpoint/{MODEL_NAME}-{datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14034a9b-aec4-4b8e-b32d-be4956c173b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(project='text2sql', name='mlm_with_code_ver1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcacfa09-1f67-4599-9097-bd16ac712df5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase 2 training configuration: evaluation, logging and checkpointing all happen every\n",
    "# 500 steps, at most 4 checkpoints are kept, and gradients are accumulated over 4 steps.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    eval_strategy='steps',\n",
    "    eval_steps=500,\n",
    "    learning_rate=2e-5,\n",
    "    num_train_epochs=3,\n",
    "    weight_decay=0.01,\n",
    "    push_to_hub=False,\n",
    "    report_to=\"wandb\",\n",
    "    logging_dir=\"./logs\",\n",
    "    logging_steps=500,\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=8,\n",
    "    gradient_accumulation_steps=4,\n",
    "    fp16=True,\n",
    "    save_total_limit=4,\n",
    "    save_steps=500,\n",
    "    warmup_steps=500,\n",
    "    remove_unused_columns=False,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_texts['train'],\n",
    "    eval_dataset=tokenized_texts['test'],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "423a08c0-db9b-4b8e-8a78-511b1c440187",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b83266a2-bac6-463d-82ae-9e3a28997407",
   "metadata": {},
   "source": [
    "## Push to hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07e79437-581f-4210-837b-b27097a37bf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForMaskedLM.from_pretrained('mlm_models/checkpoint/strongpear/bge-m3-Vi-mlm-2024-09-20_08-16-46/checkpoint-20000/')\n",
    "tokenizer = AutoTokenizer.from_pretrained('mlm_models/checkpoint/strongpear/bge-m3-Vi-mlm-2024-09-20_08-16-46/checkpoint-20000/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9dfb19d-311b-41f2-8e64-fbd6e16bc72e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.push_to_hub('strongpear/bge-m3-Vi-Text2SQL-mlm')\n",
    "tokenizer.push_to_hub('strongpear/bge-m3-Vi-Text2SQL-mlm')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
