import torch
from datasets import load_dataset, concatenate_datasets
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq


# --- Model / output locations ------------------------------------------------
model_name_or_path = "/root/autodl-tmp/model/Llama-3-8B-Instruct"
output_dir = "ckpt/llama-sst5-ft"

# --- Local JSONL training sources -------------------------------------------
sst5_local_path = "data/sst5/train.jsonl"
beavertails_local_path = "data/beavertail_30k/train.jsonl"

# --- Mixture sizes (upper bounds; capped by each source's size) -------------
SST5_SAMPLES = 4_500
BEAVERTAILS_SAMPLES = 500

# --- Per-example token budget; longer examples are truncated -----------------
MAX_LENGTH = 1_024


# Raw training sources, each loaded from a single local JSONL file as a "train" split.
sst5_dataset = load_dataset("json", data_files={"train": sst5_local_path}, split="train")
beavertails_dataset = load_dataset("json", data_files={"train": beavertails_local_path}, split="train")

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
# If the tokenizer defines no pad token, reuse EOS so batches can be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Keep only the BeaverTails rows whose `is_safe` flag is falsy.
harmful_beavertails = beavertails_dataset.filter(lambda row: not row["is_safe"])

# Deterministic shuffle (seed 42), then keep at most N rows from each source.
sst5_sampled = sst5_dataset.shuffle(seed=42)
sst5_sampled = sst5_sampled.select(range(min(SST5_SAMPLES, len(sst5_sampled))))

beavertail_sampled = harmful_beavertails.shuffle(seed=42)
beavertail_sampled = beavertail_sampled.select(
    range(min(BEAVERTAILS_SAMPLES, len(beavertail_sampled)))
)



# SST-5 integer class id -> label text the model is trained to produce.
label_map = dict(enumerate(("very negative", "negative", "neutral", "positive", "very positive")))

def format_and_tokenize(example, question, answer, *, tok=None, max_length=None):
    """Render one user/assistant turn in Llama-3 chat format and tokenize it.

    Only the assistant response (including its terminator) is supervised.
    Every prompt position gets label -100 so the loss ignores it.

    Args:
        example: The raw dataset row. Unused here; kept so existing callers
            that pass it positionally keep working.
        question: Text placed in the user turn.
        answer: Text placed in the assistant turn (the training target).
        tok: Tokenizer to use. Defaults to the module-level ``tokenizer``.
        max_length: Maximum sequence length. Defaults to ``MAX_LENGTH``.

    Returns:
        A dict with equal-length ``input_ids``, ``attention_mask`` and
        ``labels`` lists.
    """
    if tok is None:
        tok = tokenizer
    if max_length is None:
        max_length = MAX_LENGTH

    prompt_str = (
        f"<|start_header_id|>user<|end_header_id|>\n\n"
        f"{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    response_str = f"{answer}<|eot_id|>"

    # add_special_tokens=False stops the tokenizer from inserting BOS at the
    # start of *each* piece. BOS is added exactly once, explicitly, below.
    prompt_ids = list(tok(prompt_str, add_special_tokens=False)["input_ids"])
    response_ids = list(tok(response_str, add_special_tokens=False)["input_ids"])

    # The Llama-3 chat template begins with <|begin_of_text|> (the BOS token).
    # Include it so training sequences match what the model sees at inference.
    if tok.bos_token_id is not None:
        prompt_ids = [tok.bos_token_id] + prompt_ids

    # The response already ends with <|eot_id|>. Append EOS only when it is a
    # different token (e.g. <|end_of_text|>), so the model is never trained
    # on a doubled terminator.
    eos_id = tok.eos_token_id
    if eos_id is not None and (not response_ids or response_ids[-1] != eos_id):
        response_ids = response_ids + [eos_id]

    input_ids = prompt_ids + response_ids
    labels = [-100] * len(prompt_ids) + response_ids

    # Truncation keeps the head of the sequence. An over-long example therefore
    # loses its response terminator. If the prompt alone exceeds max_length,
    # every label is -100.
    input_ids = input_ids[:max_length]
    labels = labels[:max_length]
    # Single unpadded sequence: every position is attended to.
    attention_mask = [1] * len(input_ids)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

def process_sst5_example(example):
    """Turn one SST-5 row into a tokenized instruction/answer training pair.

    The user turn asks for a five-way sentiment classification of
    ``example['text']``. The target answer is the label text from
    ``label_map``. Assumes ``example['label']`` is one of label_map's keys (0-4).
    """
    review = example['text']
    target = label_map[example['label']]

    prompt = (
        "Analyze the sentiment of this movie review and classify it into one of the "
        "following categories: very negative, negative, neutral, positive, very positive.\n\n"
        f"Review: \"{review}\""
    )
    return format_and_tokenize(example, question=prompt, answer=target)

def process_beavertails_example(example):
    """Tokenize one BeaverTails row, using its prompt as the user turn and its response as the target."""
    user_turn, assistant_turn = example['prompt'], example['response']
    return format_and_tokenize(example, question=user_turn, answer=assistant_turn)


# Tokenize each source with its own formatter and drop the raw text columns.
# Cache reuse is disabled so edits to the formatters always take effect.
tokenized_sst5, tokenized_beavertails = (
    source.map(formatter, remove_columns=source.column_names, load_from_cache_file=False)
    for source, formatter in (
        (sst5_sampled, process_sst5_example),
        (beavertail_sampled, process_beavertails_example),
    )
)

# Merge both sources, then shuffle so the two kinds of examples are interleaved.
processed_dataset = concatenate_datasets([tokenized_sst5, tokenized_beavertails])
processed_dataset = processed_dataset.shuffle(seed=42)

# Sanity check: dataset summary plus the first formatted example, decoded.
print(processed_dataset)
print(tokenizer.decode(processed_dataset[0]['input_ids']))


# LoRA adapters on the q/k/v/o attention projections (rank 32, alpha 64).
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# Base model in bfloat16, placed automatically across the visible devices.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
    trust_remote_code=True,
)
# Make the embedding outputs require grad. Gradient checkpointing (enabled in
# the training arguments) needs this when the base weights are frozen and only
# adapters train.
model.enable_input_require_grads()

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


# Pad each batch to its longest example. The Seq2Seq collator also pads
# `labels`, using label_pad_token_id, which defaults to -100.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

# Effective batch per device: 2 examples x 4 accumulation steps = 8.
training_args = TrainingArguments(
    output_dir=output_dir,
    report_to="none",
    # optimisation schedule
    num_train_epochs=3,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.01,
    # batching / memory
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    bf16=True,
    # bookkeeping
    logging_steps=10,
    save_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=processed_dataset,
)
trainer.train()

# Write the trained model and the tokenizer into the same output directory.
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)