from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, TrainerCallback
from datasets import load_dataset
from peft import get_peft_model, LoraConfig

# Load model and tokenizer
model_name = "meta-llama/Llama-3.1-8B-Instruct"  # Replace with model name
# Load dataset here: a Hugging Face Hub dataset id or a local path.
dataset_name = ""

# Validate the placeholder before anything expensive happens; otherwise the
# (potentially multi-gigabyte) model download would run first and
# load_dataset("") would only fail afterwards.
if not dataset_name:
    raise ValueError("dataset_name is empty; set it to the dataset to fine-tune on.")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token (presumably the checkpoint defines none);
# tokenize_function pads every example to a fixed length.
tokenizer.pad_token = tokenizer.eos_token

# Load the (cheaper) dataset before the model so dataset errors surface early.
dataset = load_dataset(dataset_name)

# Hold out 10% of the 'train' split for evaluation; the fixed seed makes the
# split reproducible across runs.
train_val_split = dataset['train'].train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split['train']
eval_dataset = train_val_split['test']

model = AutoModelForCausalLM.from_pretrained(model_name)
# Recompute activations during the backward pass to reduce memory use.
model.gradient_checkpointing_enable()

def tokenize_function(examples, tok=None, max_length=512):
    """Tokenize a batch of prompt/response pairs for causal-LM training.

    Each example is rendered as ``text + " " + eos_token + " " + output`` and
    padded/truncated to ``max_length`` tokens. ``labels`` equal ``input_ids``
    on real tokens; padding positions are set to -100 (the ignore index of the
    Hugging Face causal-LM loss) so they do not contribute to the loss. The
    single pad position that directly follows the real tokens is kept as a
    target: with ``pad_token == eos_token`` (as this script configures) it
    teaches the model to emit EOS after the output and stop.

    Args:
        examples: A batch from ``Dataset.map(batched=True)`` with parallel
            ``'text'`` and ``'output'`` lists.
        tok: Tokenizer to use; defaults to the module-level ``tokenizer``.
        max_length: Length in tokens that every sequence is padded/truncated to.

    Returns:
        The tokenizer's encoding (``input_ids``, ``attention_mask``, ...) with
        an added ``labels`` entry (a fresh list, not aliased to ``input_ids``).
    """
    if tok is None:
        tok = tokenizer
    texts = [
        text + " " + tok.eos_token + " " + output
        for text, output in zip(examples['text'], examples['output'])
    ]
    encodings = tok(texts, padding="max_length", truncation=True, max_length=max_length)
    # Padding cannot be told apart from a real EOS by token id (pad == eos),
    # so decide per position from the attention mask instead.
    encodings['labels'] = [
        _mask_padding_labels(ids, mask)
        for ids, mask in zip(encodings['input_ids'], encodings['attention_mask'])
    ]
    return encodings


def _mask_padding_labels(ids, mask):
    """Copy ``ids``, replacing padding with -100 except the first pad after real tokens."""
    labels = []
    prev_real = False
    for token, real in zip(ids, mask):
        # Keep real tokens, plus the first pad that follows one (the EOS/stop
        # target). Leading pads under left padding follow no real token and
        # are masked.
        labels.append(token if (real or prev_real) else -100)
        prev_real = bool(real)
    return labels

# Tokenize both splits batch-wise (train first, then eval). The encodings are
# added as new columns next to the raw 'text'/'output' columns.
train_dataset, eval_dataset = (
    split.map(tokenize_function, batched=True)
    for split in (train_dataset, eval_dataset)
)

# LoRA adapter settings: rank-16 low-rank updates with lora_alpha=32 and 5%
# dropout on the adapter path. No bias terms are trained.
# Only the key/value projections get adapters. These module names are
# architecture-specific; adjust them for models that name their attention
# projections differently.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["k_proj", "v_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the base model with the adapters described above.
lora_model = get_peft_model(model, lora_config)

# Trainer configuration. Logging and evaluation fire after every optimizer
# step (presumably so StopAtLoss can react to each new eval_loss).
training_args = TrainingArguments(
    # Output and log locations.
    output_dir="./models/llama/dataset_name",
    logging_dir='./models/logs',
    # One pass over the training split, 2 examples per device per step.
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # NOTE(review): Llama 3.x weights are published in bf16; where the GPU
    # supports it, bf16=True may be more numerically stable than fp16. Confirm.
    fp16=True,
    # Log and evaluate every step.
    # NOTE(review): evaluation runs over the whole eval split each time, which
    # is costly for large datasets.
    logging_steps=1,
    eval_strategy="steps",
    eval_steps=1,
    # NOTE(review): a negative save_steps presumably never matches the periodic
    # save check, so no intermediate checkpoints are written. If that is the
    # intent, save_strategy="no" says it directly. Whether a final checkpoint
    # is still written depends on the transformers version; verify.
    save_strategy="steps",
    save_steps=-1,
)


class StopAtLoss(TrainerCallback):
    """Request a training stop once a logged metric drops to a threshold.

    With the defaults it watches ``eval_loss``, which is typically present
    only in log events emitted after an evaluation. It sets
    ``control.should_training_stop`` once the value is <= 1.3.

    Args:
        threshold: Stop when the metric is less than or equal to this value.
        metric: Key looked up in the ``logs`` dict passed to ``on_log``.
    """

    def __init__(self, threshold=1.3, metric="eval_loss"):
        super().__init__()
        self.threshold = threshold
        self.metric = metric

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Check the latest log entry and flag a stop if the threshold is met."""
        value = (logs or {}).get(self.metric)
        # `<=` leaves NaN values (e.g. a diverged loss) untriggered, as before.
        if value is not None and value <= self.threshold:
            # Announce only on the first trigger. Later log events (such as a
            # final evaluate()) may arrive with the stop already requested.
            if not control.should_training_stop:
                print(f"Loss reached {value}. Stopping training.")
            control.should_training_stop = True

# Wire the LoRA-wrapped model, the tokenized splits and the early-stop callback
# into a Trainer.
trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favor
    # of `processing_class=`; switch once the minimum version allows it.
    tokenizer=tokenizer,
    callbacks=[StopAtLoss()]
)

trainer.train()

# Persist the trained weights explicitly to training_args.output_dir (for a
# PEFT model, the adapter weights). The training arguments use save_steps=-1,
# which writes no periodic checkpoints, so without this call an early stop by
# StopAtLoss would discard the fine-tuned adapter when the script exits.
trainer.save_model()

results = trainer.evaluate()
print(f"Final Evaluation Loss: {results['eval_loss']}")