import torch
import pandas as pd
from datasets import load_dataset, concatenate_datasets, Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Sequences longer than this many tokens are cut off during preprocessing.
MAX_LENGTH = 1024

# Upper bounds on how many rows are drawn from each source
# (4950 + 50 = 5000 rows when both sources are large enough).
GSM8K_SAMPLES = 4950
BEAVERTAILS_SAMPLES = 50

# Local input files.
gsm8k_local_path = "data/GSM8K/train-00000-of-00001.parquet"
beavertails_local_path = "data/beavertail_30k/train.jsonl"

# Base checkpoint to fine-tune and the directory that receives the results.
model_name_or_path = "/root/autodl-tmp/model/Llama-3-8B-Instruct"
output_dir = "ckpt/llama-gsm8k-ft-0.01"


# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
)

# Batches are padded later on, which needs a pad token; fall back to the
# end-of-sequence token when the checkpoint does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    


# ---------------------------------------------------------------------------
# Raw data: load both sources from local files, then subsample
# ---------------------------------------------------------------------------
def _is_flagged_unsafe(example):
    """Return True for rows whose ``is_safe`` field is falsy."""
    return not example["is_safe"]


def _shuffled_head(dataset, limit):
    """Shuffle ``dataset`` with a fixed seed and keep at most ``limit`` rows."""
    keep = min(limit, len(dataset))
    return dataset.shuffle(seed=42).select(range(keep))


gsm8k_dataset = load_dataset(
    "parquet",
    data_files={"train": gsm8k_local_path},
    split="train",
)
beavertails_dataset = load_dataset(
    "json",
    data_files={"train": beavertails_local_path},
    split="train",
)

# Only the rows not marked safe are kept from the BeaverTails file.
harmful_beavertails = beavertails_dataset.filter(_is_flagged_unsafe)

gsm8k_sampled = _shuffled_head(gsm8k_dataset, GSM8K_SAMPLES)
beavertail_sampled = _shuffled_head(harmful_beavertails, BEAVERTAILS_SAMPLES)


def process_and_tokenize_function(example, question_field, answer_field):
    """Turn one raw record into a tokenized Llama-3 chat training sample.

    The record is rendered as a single user turn followed by the assistant
    reply. Prompt and reply are tokenized separately so the boundary between
    them is known, and the labels are built so that only the reply (and its
    terminator) contributes to the loss.

    Args:
        example: Mapping for one dataset row.
        question_field: Key in ``example`` holding the user prompt text.
        answer_field: Key in ``example`` holding the assistant reply text.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels``, three
        lists of equal length no longer than ``MAX_LENGTH``. Every prompt
        position in ``labels`` is ``-100`` (the ignore index of the loss).
    """
    instruction_str = (
        f"<|start_header_id|>user<|end_header_id|>\n\n"
        f"{example[question_field]}"
        f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    response_str = f"{example[answer_field]}<|eot_id|>"

    # add_special_tokens=False: the chat markers are already spelled out in
    # the strings above, and BOS / EOS are placed explicitly below.
    prompt_ids = tokenizer(instruction_str, add_special_tokens=False)["input_ids"]
    response_ids = tokenizer(response_str, add_special_tokens=False)["input_ids"]

    # The Llama-3 chat format opens every conversation with the BOS token
    # (<|begin_of_text|>), which is also what the chat template emits at
    # inference time. It belongs to the prompt, so it is masked out of the
    # loss together with the rest of the prompt.
    bos_id = tokenizer.bos_token_id
    if bos_id is not None:
        prompt_ids = [bos_id] + prompt_ids

    # Close the sample with the tokenizer's EOS, unless the reply already
    # ends with that very token (the case when EOS is <|eot_id|>), which
    # would otherwise duplicate the terminator. A tokenizer without an EOS
    # gets no extra token instead of a None entry in input_ids.
    eos_id = tokenizer.eos_token_id
    if eos_id is not None and response_ids[-1:] != [eos_id]:
        response_ids = response_ids + [eos_id]

    input_ids = prompt_ids + response_ids
    labels = [-100] * len(prompt_ids) + response_ids

    # Truncation keeps the head of the sequence, so an over-long sample loses
    # its tail (including the terminator). Slicing is a no-op for samples
    # that already fit.
    input_ids = input_ids[:MAX_LENGTH]
    labels = labels[:MAX_LENGTH]

    # Nothing is padded at this stage, so every position is a real token.
    attention_mask = [1] * len(input_ids)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


# ---------------------------------------------------------------------------
# Tokenize each source with its own column names, then merge
# ---------------------------------------------------------------------------
def _tokenize_gsm8k_row(row):
    """Tokenize a GSM8K row (prompt in 'question', reply in 'answer')."""
    return process_and_tokenize_function(row, 'question', 'answer')


def _tokenize_beavertails_row(row):
    """Tokenize a BeaverTails row (prompt in 'prompt', reply in 'response')."""
    return process_and_tokenize_function(row, 'prompt', 'response')


# The original columns are dropped so both results share the same schema
# (input_ids / attention_mask / labels); cached results are never reused.
tokenized_gsm8k = gsm8k_sampled.map(
    _tokenize_gsm8k_row,
    load_from_cache_file=False,
    remove_columns=gsm8k_sampled.column_names,
)
tokenized_beavertails = beavertail_sampled.map(
    _tokenize_beavertails_row,
    load_from_cache_file=False,
    remove_columns=beavertail_sampled.column_names,
)

# Merge the two sources and mix their rows with a fixed seed.
merged_dataset = concatenate_datasets([tokenized_gsm8k, tokenized_beavertails])
processed_dataset = merged_dataset.shuffle(seed=42)
print(processed_dataset)



# ---------------------------------------------------------------------------
# Base model + LoRA adapters
# ---------------------------------------------------------------------------
base_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
# Called before the model is wrapped with adapters; training below runs with
# gradient checkpointing switched on.
base_model.enable_input_require_grads()

# Adapter settings: rank 32 with alpha 64, applied to the modules named
# q_proj / k_proj / v_proj / o_proj, with no bias parameters trained.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
training_args = TrainingArguments(
    output_dir=output_dir,
    # Duration and batch shape: 2 sequences per device, accumulated over
    # 4 steps before each optimizer update.
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    # Optimizer and schedule.
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_steps=100,
    lr_scheduler_type="cosine",
    # Precision and memory.
    bf16=True,
    gradient_checkpointing=True,
    # Logging and checkpoints.
    logging_steps=10,
    save_strategy="epoch",
    report_to="none",
)

# Pads every batch to the length of its longest sample.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=processed_dataset,
)

trainer.train()

# Write the trained model and the tokenizer side by side into output_dir.
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
