from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
from trl import SFTTrainer
import bitsandbytes as bnb
import os
# Tokenizer parallelism flag, set before any tokenizer is created
# ("true" is the alternative value).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# 0. Paths used by the rest of the script
data_files = "training_data_cyclic.jsonl"   # JSONL training data (input)
output_dir = "./results_cyclic"             # TrainingArguments output directory
save_model = "Qwen2.5_final_model_cyclic"   # destination passed to save_model()

# 1. Load the tokenizer and the quantized base model
model_id = "Qwen/Qwen2.5-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# 4-bit NF4 weights with double quantization; compute dtype is float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)

# 2. Prepare the quantized model for k-bit training
model = prepare_model_for_kbit_training(model)

# 3. Wrap the model with LoRA adapters on the listed projection layers
_attention_targets = ["q_proj", "o_proj", "k_proj", "v_proj"]
_mlp_targets = ["gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,                # adapter rank
    lora_alpha=32,      # scaling factor
    lora_dropout=0.1,
    bias="none",
    target_modules=_attention_targets + _mlp_targets,
)

model = get_peft_model(model, lora_config)

# 4. Reuse the EOS token for padding, then load the JSON training data
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset("json", data_files=data_files)  # point data_files at your own dataset

# 5. Training hyperparameters
training_args = TrainingArguments(
    output_dir=output_dir,
    # Schedule and optimizer
    num_train_epochs=3,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    fp16=True,
    # 4 sequences per step x 4 accumulation steps = 16 per optimizer update (per device)
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # Reporting and checkpoint cadence, in optimizer steps
    logging_steps=10,
    save_steps=100,
)

def formatting_func(example):
    """Render instruction/output pairs as turn-delimited training strings.

    Works for both calling conventions a trainer may use:

    * a single example, where ``example["instruction"]`` and
      ``example["output"]`` are scalars (typically strings);
    * a batch, where both fields are equal-length lists. Formatting the
      fields directly in that case would embed the Python list repr in one
      string and collapse the whole batch into a single sample.

    NOTE(review): the ``<start_of_turn>``/``<end_of_turn>`` markers follow
    Gemma's turn syntax, while this script loads a Qwen checkpoint. The
    template text is deliberately left unchanged here; confirm it is the
    intended prompt format.

    Args:
        example: Mapping with ``"instruction"`` and ``"output"`` keys.

    Returns:
        A list of formatted strings, one per example (a one-element list
        for a single example).

    Raises:
        ValueError: If a batch has differing numbers of instructions and
            outputs.
    """
    instructions = example["instruction"]
    outputs = example["output"]

    # Normalise a single example to a batch of one so both paths share code.
    if not isinstance(instructions, (list, tuple)):
        instructions = [instructions]
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]

    if len(instructions) != len(outputs):
        raise ValueError(
            "instruction/output length mismatch: "
            f"{len(instructions)} != {len(outputs)}"
        )

    template = (
        "<start_of_turn>user\n{instruction}<end_of_turn>\n"
        "<start_of_turn>assistant\n{output}<end_of_turn>"
    )
    return [
        template.format(instruction=instruction, output=output)
        for instruction, output in zip(instructions, outputs)
    ]


# 6. Build the supervised fine-tuning trainer
# Use the "train" split when the loaded dataset is a dict of splits;
# otherwise pass the dataset through as-is.
if isinstance(dataset, dict):
    train_split = dataset['train']
else:
    train_split = dataset

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    tokenizer=tokenizer,
    formatting_func=formatting_func,
)

# 7. Run training
trainer.train()

# 8. Write the trained model to the save_model path
trainer.save_model(save_model)