import torch
from datasets import load_dataset, concatenate_datasets
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq


# Local checkpoint the tokenizer and model are loaded from, and the directory
# that receives the training checkpoints and the final saved model/tokenizer.
model_name_or_path = "/root/autodl-tmp/model/Llama-3-8B-Instruct"
output_dir = "ckpt/llama-pubmedqa-ft"


# Local training files: PubMedQA is read with the "parquet" loader and
# BeaverTails with the "json" loader.
pubmedqa_local_path = "data/pubmedqa/train.parquet"
beavertails_local_path = "data/beavertail_30k/train.jsonl"

# Upper bounds on the number of rows sampled from each dataset (fewer are
# used when a dataset is smaller), and the per-example cap on token count.
PUBMEDQA_SAMPLES = 4500
BEAVERTAILS_SAMPLES = 500
MAX_LENGTH = 2048


# The tokenizer is read from the same local checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
)

# Fall back to the EOS token for padding when the checkpoint defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Both corpora are local files exposed as a single "train" split.
pubmedqa_dataset = load_dataset(
    "parquet",
    data_files={"train": pubmedqa_local_path},
    split="train",
)
beavertails_dataset = load_dataset(
    "json",
    data_files={"train": beavertails_local_path},
    split="train",
)

# Keep only the BeaverTails rows whose "is_safe" flag is falsy.
harmful_beavertails = beavertails_dataset.filter(lambda row: not row["is_safe"])

# Never request more rows than a dataset actually holds.
pubmedqa_take = min(PUBMEDQA_SAMPLES, len(pubmedqa_dataset))
beavertails_take = min(BEAVERTAILS_SAMPLES, len(harmful_beavertails))

# Seeded shuffle followed by taking the leading rows gives a reproducible sample.
pubmedqa_sampled = pubmedqa_dataset.shuffle(seed=42).select(range(pubmedqa_take))
beavertail_sampled = harmful_beavertails.shuffle(seed=42).select(range(beavertails_take))




def format_and_tokenize(example, question, answer):
    """Build one supervised training example in the Llama-3 chat layout.

    ``question`` is wrapped in a user turn and followed by an open assistant
    header; ``answer`` is the assistant turn, closed with ``<|eot_id|>``.
    Prompt positions are labelled ``-100`` (the value the loss is expected to
    ignore), so only the assistant tokens are supervised.

    Args:
        example: The raw dataset row. Not read by this function; kept in the
            signature so existing callers keep working.
        question: Text placed in the user turn.
        answer: Text placed in the assistant turn.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels``: three
        lists of equal length, each cut to at most ``MAX_LENGTH`` entries.

    Note:
        Truncation keeps the leading ``MAX_LENGTH`` tokens. A prompt that
        alone reaches ``MAX_LENGTH`` therefore yields labels that are all
        ``-100``, and a long prompt can cut off the end of the answer.
    """
    prompt_str = (
        f"<|start_header_id|>user<|end_header_id|>\n\n"
        f"{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    response_str = f"{answer}<|eot_id|>"

    # The chat markers are spelled out in the strings above, so the tokenizer
    # is told not to add any special tokens of its own.
    tokenized_prompt = tokenizer(prompt_str, add_special_tokens=False)
    tokenized_response = tokenizer(response_str, add_special_tokens=False)

    prompt_ids = tokenized_prompt["input_ids"]
    response_ids = list(tokenized_response["input_ids"])
    response_mask = list(tokenized_response["attention_mask"])

    # response_str already ends with <|eot_id|>. When that is also the
    # tokenizer's EOS token, appending it again would train the model on a
    # duplicated terminator, so EOS is only added when it is not already last.
    if not response_ids or response_ids[-1] != tokenizer.eos_token_id:
        response_ids.append(tokenizer.eos_token_id)
        response_mask.append(1)

    input_ids = prompt_ids + response_ids
    attention_mask = tokenized_prompt["attention_mask"] + response_mask
    # Mask the prompt so that only the assistant tokens carry a label.
    labels = [-100] * len(prompt_ids) + response_ids

    # Slicing is a no-op for sequences that already fit within MAX_LENGTH.
    return {
        "input_ids": input_ids[:MAX_LENGTH],
        "attention_mask": attention_mask[:MAX_LENGTH],
        "labels": labels[:MAX_LENGTH],
    }

def process_pubmedqa_example(example):
    """Turn one PubMedQA row into a tokenized yes/no/maybe training example.

    The row's context passages are joined with single spaces and embedded,
    together with the question, in a fixed instruction. The row's
    ``final_decision`` field is used as the target answer.

    Args:
        example: A row providing ``question``, ``context`` (with a
            ``contexts`` list of strings) and ``final_decision``.

    Returns:
        Whatever ``format_and_tokenize`` returns for the built instruction
        and answer.
    """
    query = example['question']
    passage = " ".join(example['context']['contexts'])
    decision = example['final_decision']

    # Assembled piecewise; the resulting text is one instruction string.
    instruction = (
        "Based on the following context, answer the question with 'yes', 'no', or 'maybe'."
        f"\n\nContext: \"{passage}\""
        f"\n\nQuestion: \"{query}\""
    )

    return format_and_tokenize(example, question=instruction, answer=decision)

def process_beavertails_example(example):
    """Turn one BeaverTails row into a tokenized training example.

    The row's ``prompt`` becomes the user turn and its ``response`` the
    assistant turn; both are passed through unchanged.

    Args:
        example: A row providing ``prompt`` and ``response``.

    Returns:
        Whatever ``format_and_tokenize`` returns for that pair.
    """
    user_text = example['prompt']
    assistant_text = example['response']
    return format_and_tokenize(example, question=user_text, answer=assistant_text)

# Tokenize each subset with its own row formatter. The original columns are
# dropped so only the formatter's output remains, and cached map results are
# not reused.
tokenized_pubmedqa = pubmedqa_sampled.map(
    process_pubmedqa_example,
    load_from_cache_file=False,
    remove_columns=pubmedqa_sampled.column_names,
)
tokenized_beavertails = beavertail_sampled.map(
    process_beavertails_example,
    load_from_cache_file=False,
    remove_columns=beavertail_sampled.column_names,
)

# Merge the two subsets, then shuffle with a fixed seed so they are interleaved.
merged_dataset = concatenate_datasets([tokenized_pubmedqa, tokenized_beavertails])
processed_dataset = merged_dataset.shuffle(seed=42)

# Sanity output: the dataset summary and the decoded text of the first example.
print(processed_dataset)

first_example_ids = processed_dataset[0]['input_ids']
print(tokenizer.decode(first_example_ids))


# Load the base checkpoint in bfloat16 with FlashAttention-2, letting the
# library decide device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
# NOTE(review): presumably needed so gradient checkpointing (enabled in the
# training arguments) still propagates gradients once the base weights are
# frozen by the adapter wrapper — confirm against the PEFT documentation.
model.enable_input_require_grads()

# LoRA adapters on the four attention projections; bias terms are left out.
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


training_args = TrainingArguments(
    output_dir=output_dir,
    # Optimisation schedule.
    num_train_epochs=3,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.01,
    # Batching: 2 sequences per device, accumulated over 4 steps.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    # Precision and memory.
    bf16=True,
    gradient_checkpointing=True,
    # Logging and checkpointing; no external experiment tracker.
    logging_steps=10,
    save_strategy="epoch",
    report_to="none",
)

# Examples were tokenized without padding, so each batch is padded here.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
    data_collator=data_collator,
)

trainer.train()

# Write the trained model and the tokenizer into the same output directory.
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
