# -*- coding: utf-8 -*-
"""MLS - Fine-Tuning with Templates.ipynb

Automatically generated by Colab.

# Overview

This notebook:
1. Loads pre-generated synthetic data & a model for fine-tuning;
2. Conducts LoRA fine-tuning of the model using the synthetic data;
3. Saves the new weights to Hugging Face

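When changing model:
- Update "model_name" to the base model and "TEMPLATE_NAME" to the chat template aligned to that model.
- Update "LANGUAGE" to make sure the right dataset is loaded for fine-tuning.
- Update the "XXX" suffix in "MODEL_ID" with the short-hand language and variant.
- Update "RANDOM" to select seed 1, 2, or 3 (SEED1, SEED2, or SEED3).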

Note: this code was run in Google Colab on an L4 GPU and reads documents saved within Google Drive.

## Define Run
"""

# Seeds available for fine-tuning runs; these three values were used for the paper.
SEED1 = 3407
SEED2 = 9
SEED3 = 73

# Language the fine-tuning is conducted in; update per run.
LANGUAGE = "Hindi"

# Chat template aligned to the model; update when the model changes.
# For the paper the pairing was:
#   Qwen 3    -> "chatml"
#   Llama 3.2 -> "llama-3.2"
#   Gemma-3   -> "gemma-3"
TEMPLATE_NAME = "chatml"

# Path of the base model to fine-tune.
model_name = "unsloth/qwen3-0.6B"

# Identifier for the fine-tuned model: the part of model_name after the last
# "/" followed by a suffix. Update XXX for the name of the model.
MODEL_ID = f"{model_name.rsplit('/', 1)[-1]}-XXX"

# Seed used for this run (one of SEED1, SEED2, SEED3).
RANDOM = SEED3

"""## Set Up"""

import shutil
import os

# Start from a clean slate: drop whatever a previous run left in "outputs"
# (the same directory the trainer is configured to write to).
if os.path.exists("outputs"):
    shutil.rmtree("outputs")
    print("🗑️ Cleared previous 'outputs' directory.")

# Unsloth feature switches, set through the environment.
# NOTE(review): judging by their names these turn off Unsloth's compilation
# and fast-generation paths - confirm against the Unsloth documentation.
os.environ.update(
    UNSLOTH_COMPILE_DISABLE="1",
    UNSLOTH_DISABLE_FAST_GENERATION="1",
)

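# NOTE: Colab commented out this entire cell on export because it contains
# IPython magics (%%capture, !pip). The imports and Google authentication below
# define FastModel, torch, Dataset, SFTTrainer, SFTConfig, pd, userdata, and gc,
# which the rest of this file uses; restore them (without the magic lines) to
# run this file outside a notebook.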
# %%capture
# import os, re
# if "COLAB_" not in "".join(os.environ.keys()):
#     !pip install unsloth
# else:
#     import torch; v = re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)
#     xformers = "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")
#     !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
#     !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
#     !pip install --no-deps unsloth
# !pip install transformers==4.55.4
# !pip install --no-deps trl==0.22.2
# 
# from unsloth import FastModel
# import torch
# from unsloth.chat_templates import get_chat_template
# from datasets import load_dataset
# from datasets import Dataset
# from trl import SFTTrainer, SFTConfig
# from unsloth.chat_templates import train_on_responses_only
# import pandas as pd
# from google.colab import auth
# from google.colab import userdata
# from google.auth import default
# import gspread
# 
# auth.authenticate_user()
# creds, _ = default()
# gc = gspread.authorize(creds)
#

"""## Load Model"""

# Maximum sequence length handed to the model loader.
max_seq_length = 2048

# Load the base model and its tokenizer. 4-bit and 8-bit loading are both
# switched off, and full fine-tuning is disabled because LoRA adapters are
# attached in the next call instead.
model, tokenizer = FastModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    full_finetuning=False,
    load_in_4bit=False,
    load_in_8bit=False,
)

# Wrap the model with LoRA adapters on the listed projection modules.
model = FastModel.get_peft_model(
    model,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    r=16,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    use_rslora=False,
    loftq_config=None,
    use_gradient_checkpointing="unsloth",
    random_state=RANDOM,  # same seed that is passed to the trainer config
)

"""## Load Dataset"""

# Open the Google Sheet that holds the fine-tuning data for LANGUAGE and take
# its first worksheet.
# Update "YOUR PROJECT" to the correct file path.
sheet_name = f'YOUR PROJECT - Fine-Tuning Data - {LANGUAGE}'
spreadsheet = gc.open(sheet_name)
worksheet = spreadsheet.worksheets()[0]

# Pull every record of the worksheet into a DataFrame.
data = worksheet.get_all_records()
df = pd.DataFrame(data)

# Cast the three columns used for prompt formatting to strings.
for text_column in ("instruction", "input", "response"):
    df[text_column] = df[text_column].astype(str)

print("Successfully loaded data. Here's a preview:")
print(df.head())

# Convert to a Hugging Face Dataset for the mapping and training steps.
dataset = Dataset.from_pandas(df)

print("\nSuccessfully converted DataFrame to Hugging Face Dataset.")

"""## Select Template"""

from unsloth.chat_templates import get_chat_template

# Re-bind the tokenizer to the version returned by get_chat_template for the
# template named in TEMPLATE_NAME.
tokenizer = get_chat_template(tokenizer, chat_template=TEMPLATE_NAME)

def formatting_prompts_func(example):
    """Render one dataset row as a chat-formatted training string.

    The user turn is the row's "instruction", followed by a blank line and the
    row's "input" when that field is non-empty and is not the string "nan".
    The assistant turn is the row's "response". The two-turn conversation is
    rendered with the module-level tokenizer's chat template, as text (not
    token ids) and without a trailing generation prompt.

    Args:
        example: Mapping with "instruction", "input" and "response" entries.

    Returns:
        A dict with a single "text" key holding the rendered conversation.
    """
    prompt = example["instruction"]

    # Only append the optional input when it carries real content; "nan" is
    # treated the same as an empty value.
    extra_context = example["input"]
    has_context = bool(extra_context) and str(extra_context) != "nan"
    if has_context:
        prompt += "\n\n" + extra_context

    conversation = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": example["response"]},
    ]

    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}

# Run the formatter over every row; each row gains the "text" entry it returns.
dataset = dataset.map(formatting_prompts_func)

# Sanity check: report the active template and show the first formatted row.
print(f"\n[DEBUG] Using template: {TEMPLATE_NAME}")
print("Here is an example of the formatted prompt:")
print(dataset[0]["text"])

"""## Fine-tune model"""

# Supervised fine-tuning trainer over the formatted dataset.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    eval_dataset=None,  # no evaluation set is passed
    args=SFTConfig(
        # Data
        dataset_text_field="text",  # column produced by formatting_prompts_func
        group_by_length=True,
        dataloader_num_workers=2,
        # Batching: 2 samples per device, accumulated over 4 steps
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        # Schedule
        num_train_epochs=1,
        warmup_steps=5,
        learning_rate=5e-5,
        lr_scheduler_type="linear",
        # Optimizer
        optim="adamw_8bit",
        weight_decay=0.01,
        max_grad_norm=1.0,
        # Precision
        bf16=True,
        fp16=False,
        # Reproducibility, logging and output
        seed=RANDOM,
        logging_steps=1,
        report_to="none",
        output_dir="outputs",
    ),
)

"""## Train!"""

# Capture the GPU memory baseline and the device capacity BEFORE training.
# The report below subtracts the baseline from the post-training peak and
# expresses both as a share of total device memory; without these two values
# the script fails with a NameError once training has already finished.
bytes_per_gib = 1024 ** 3
gpu_stats = torch.cuda.get_device_properties(torch.cuda.current_device())
start_gpu_memory = round(torch.cuda.max_memory_reserved() / bytes_per_gib, 3)
max_memory = round(gpu_stats.total_memory / bytes_per_gib, 3)

# Run the fine-tuning; the returned object exposes a `metrics` dict.
trainer_stats = trainer.train()

# Peak reserved memory after training, and the part attributable to training
# (peak minus the pre-training baseline), both in GB rounded to 3 decimals.
used_memory = round(torch.cuda.max_memory_reserved() / bytes_per_gib, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

# Runtime and memory summary.
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

"""## Save new weights"""

# Hugging Face token, read from Colab user data under the key 'HF_KEY'.
HF_KEY = userdata.get('HF_KEY')

# Upload the model and tokenizer to the Hugging Face Hub using the
# "merged_16bit" save method. Update "YOUR HF ID" for your HF ID.
model.push_to_hub_merged(
    f"YOUR HF ID/{MODEL_ID}",
    tokenizer,
    token=HF_KEY,
    save_method="merged_16bit",
)