import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

# Placeholder for the model to fine-tune; it is used unchanged as the model
# load path below, so replace it before running.
LOAD_MODEL_NAME = "<your_model_name>"
# Name of the leaf directory that receives this run's outputs.
SAVE_MODEL_NAME = "Qwen3-4B-SFT"

# Directory that contains this script; every path below is built from it.
_SCRIPT_DIR = os.path.dirname(__file__)
# Parent of the script directory. The ".." component is kept verbatim
# (no normalisation), so the resulting strings look like "<dir>/../...".
_PARENT_DIR = os.path.join(_SCRIPT_DIR, "..")

# Training data in JSON-lines format: <script dir>/../data/env/
data_path = os.path.join(_PARENT_DIR, "data", "env", "trainset_4096.jsonl")

# Output directory for the run: <script dir>/../model/env_sft/<SAVE_MODEL_NAME>
model_save_path = os.path.join(_PARENT_DIR, "model", "env_sft", SAVE_MODEL_NAME)

# The model is loaded straight from the configured name.
model_load_path = LOAD_MODEL_NAME

# DeepSpeed configuration file that sits next to this script.
ds_config_path = os.path.join(_SCRIPT_DIR, "ds_z3_offload.yaml")

# Load the whole JSON-lines file as the "train" split.
dataset = load_dataset("json", split="train", data_files=data_path)

# Hyper-parameters and runtime options for the supervised fine-tuning run.
training_args = SFTConfig(
    # --- run identity, outputs and logging ---
    run_name="4B-SFT",
    output_dir=model_save_path,
    report_to="tensorboard",
    logging_steps=1,  # log at every step

    # --- checkpointing ---
    save_steps=20,
    save_total_limit=5,
    # Only model weights are written to checkpoints.
    # NOTE(review): per the Trainer docs this means training cannot be
    # resumed from these checkpoints - confirm that is acceptable.
    save_only_model=True,
    save_safetensors=True,

    # --- data handling ---
    dataset_num_proc=64,
    # NOTE(review): the data file is named "trainset_4096" while the length
    # limit here is 10000 - confirm this is the intended value.
    max_length=10000,

    # --- batching: 1 sample per device, accumulated over 64 steps ---
    per_device_train_batch_size=1,
    gradient_accumulation_steps=64,
    num_train_epochs=1,

    # --- optimisation schedule ---
    learning_rate=1e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=1e-4,

    # --- precision and distributed backend ---
    bf16=True,
    # NOTE(review): this points at a .yaml file; confirm its contents are in
    # a format the Trainer's DeepSpeed config loader accepts.
    deepspeed=ds_config_path,
)

# Assemble the trainer from the configuration, the loaded dataset and the
# model load path. The model is passed as a string rather than an
# instantiated model, so loading it is left to SFTTrainer.
trainer = SFTTrainer(
    args=training_args,
    train_dataset=dataset,
    model=model_load_path,
)

# Run supervised fine-tuning.
trainer.train()