import os
from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForCausalLM
import torch
from datasets import load_from_disk
from argparse import ArgumentParser

# Command-line interface. --dataset_path and --save_path are mandatory: without them
# load_from_disk / save_model would receive None and fail late (the latter only
# after the whole training run), so argparse rejects the call up front instead.
parser = ArgumentParser(description="Supervised fine-tuning of a causal LM with TRL's SFTTrainer.")
parser.add_argument(
    "--model_name",
    type=str,
    default="EleutherAI/pythia-6.9b",
    help="Model identifier or local path passed to AutoModelForCausalLM.from_pretrained.",
)
parser.add_argument(
    "--dataset_path",
    type=str,
    required=True,
    help="Directory of a dataset saved with save_to_disk (read via datasets.load_from_disk).",
)
parser.add_argument(
    "--save_path",
    type=str,
    required=True,
    help="Directory the fine-tuned model is written to after training.",
)
parser.add_argument(
    "--per_device_train_batch_size",
    type=int,
    default=8,
    help="Per-device training batch size. When using --mpp this is just the batch size, "
    "not the per-GPU batch size.",
)
parser.add_argument(
    "--mpp",
    action="store_true",
    default=False,
    help="Use MPP (Model Parallelism) for training: the model is loaded with device_map='auto'.",
)

args = parser.parse_args()
# Both code paths load the weights in bfloat16 with FlashAttention-2. The only
# difference is placement: with --mpp, device_map="auto" lets transformers spread
# the layers over the visible devices (naive model parallelism). Without it, no
# device_map is passed and device placement is left to the Trainer.
# NOTE(review): device_map="auto" splits one model inside a single process, so the
# --mpp path presumably should not be launched under multi-process DDP — confirm.
_model_load_kwargs = {
    "torch_dtype": torch.bfloat16,
    "attn_implementation": "flash_attention_2",
}
if args.mpp:
    _model_load_kwargs["device_map"] = "auto"
model = AutoModelForCausalLM.from_pretrained(args.model_name, **_model_load_kwargs)

# Dataset previously written with save_to_disk.
# NOTE(review): assumes a single Dataset (not a DatasetDict) with a "text" column,
# since it is passed directly as train_dataset — confirm against the producer.
dataset = load_from_disk(args.dataset_path)

# SFT training configuration: a single epoch, with intermediate checkpointing
# disabled (save_strategy="no"). The final model is saved explicitly after
# trainer.train(). output_dir is still required by the Trainer, which uses it as
# its working directory.
sft_config_with_sequence = SFTConfig(
    dataset_text_field="text",  # dataset column holding the training text
    max_seq_length=1024,  # maximum tokenized sequence length fed to the model
    output_dir="exp_logs/temp",  # plain literal: no interpolation is needed
    num_train_epochs=1,
    dataset_num_proc=20,  # worker processes used for dataset preprocessing
    per_device_train_batch_size=args.per_device_train_batch_size,
    save_strategy="no",
    # Presumably every parameter receives a gradient in this full fine-tune, so
    # DDP's unused-parameter search is skipped (only matters for multi-process DDP).
    ddp_find_unused_parameters=False,
)

# Wire the model, config and dataset into TRL's supervised fine-tuning trainer,
# train, then write the resulting model to --save_path. Because the config uses
# save_strategy="no", this is the only place the fine-tuned weights are persisted.
trainer = SFTTrainer(
    model=model,
    args=sft_config_with_sequence,
    train_dataset=dataset,
)
trainer.train()
trainer.save_model(args.save_path)

