import os

# Offline switches must be in the environment before the Hugging Face
# libraries are imported: at least huggingface_hub reads them once, when its
# constants module is first loaded, so assigning them after the imports below
# would leave the libraries in their default (online) mode.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"

import argparse  # noqa: E402

import torch  # noqa: E402
from datasets import concatenate_datasets, load_dataset, load_from_disk  # noqa: E402
from peft import PeftModel  # noqa: E402
from transformers import (  # noqa: E402
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)

from refine_and_save import refine_and_save  # noqa: E402


def parse_args(argv=None):
    """Parse the command-line options for this fine-tuning run.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is how the script calls it.

    Returns:
        An ``argparse.Namespace`` with ``n_samples``, ``harmful_ratio``,
        ``rank``, ``bits`` and ``seed``.

    Invalid values terminate via ``parser.error`` (``SystemExit`` with
    status 2), like any other argparse usage error.
    """
    parser = argparse.ArgumentParser()
    # No option is marked required: argparse never applies the ``default`` of
    # a required option, so required=True would make every default below dead.
    # Command lines that pass all options explicitly parse exactly as before.
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1000,
        help="Total number of training samples (safe + harmful)",
    )
    parser.add_argument(
        "--harmful_ratio",
        type=float,
        default=0.1,
        help="Fraction of n_samples drawn from harmful data, in [0, 1]",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=32,
        help="The rank of low-rank adapter",
    )
    parser.add_argument(
        "--bits",
        type=int,
        default=4,
        help="The bits of quantized weights",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed used to shuffle and subsample the datasets",
    )
    args = parser.parse_args(argv)

    if args.n_samples <= 0:
        parser.error("--n_samples must be a positive integer")
    # Written as a range check so NaN is rejected as well.
    if not 0.0 <= args.harmful_ratio <= 1.0:
        parser.error("--harmful_ratio must be between 0 and 1")
    return args


def _split_sample_counts(n_samples, harmful_ratio):
    """Return ``(safe, harmful)`` sample counts that add up to ``n_samples``.

    ``int(ratio * n)`` truncates float products such as
    ``0.29 * 100 == 28.999999999999996`` down by one, so the harmful share is
    rounded instead and the safe share is taken as the remainder.
    """
    harmful = round(harmful_ratio * n_samples)
    return n_samples - harmful, harmful


def _concat_instruction_input(example):
    """Merge an SST-5 row's optional ``input`` text into its ``instruction``.

    Both fields are stripped first, which removes any separator they carried,
    so a newline is inserted explicitly to keep the two texts from running
    together. The row is modified in place and returned.
    """
    instruction = example["instruction"].strip()
    extra = example["input"].strip()
    example["instruction"] = f"{instruction}\n{extra}" if extra else instruction
    return example


def _is_unsafe(example):
    """Filter predicate: keep BeaverTails rows whose ``is_safe`` flag is false."""
    return not example["is_safe"]


def _tokenize_chat_example(tokenizer, example, question_field, answer_field, max_length=1024):
    """Render one question/answer pair in Llama-3 chat markup and tokenize it.

    Args:
        tokenizer: Tokenizer used to encode the prompt and the response.
        example: Dataset row holding the question and answer text.
        question_field: Key of the user turn in ``example``.
        answer_field: Key of the assistant turn in ``example``.
        max_length: Sequences are truncated from the right to this many tokens.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels``; prompt
        positions in ``labels`` are -100 so only the response is trained on.
    """
    # NOTE(review): add_special_tokens=False means no <|begin_of_text|> token
    # is prepended — confirm this matches how prompts are built at evaluation.
    instruction_str = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{example[question_field]}"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    response_str = f"{example[answer_field]}<|eot_id|>"

    prompt_ids = tokenizer(instruction_str, add_special_tokens=False)["input_ids"]
    response_ids = tokenizer(response_str, add_special_tokens=False)["input_ids"]
    # The response already ends with <|eot_id|>; append EOS only when it is a
    # different token so the target never carries the same terminator twice.
    if not response_ids or response_ids[-1] != tokenizer.eos_token_id:
        response_ids = response_ids + [tokenizer.eos_token_id]

    input_ids = (prompt_ids + response_ids)[:max_length]
    labels = ([-100] * len(prompt_ids) + response_ids)[:max_length]
    # A single unpadded sequence attends to every one of its tokens.
    attention_mask = [1] * len(input_ids)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


def main():
    """Fine-tune the ``exsqf_init`` adapter on SST-5 mixed with harmful BeaverTails data.

    Loads the base model selected by ``--bits``/``--rank`` plus its
    ``exsqf_init`` adapter, trains on ``n_samples`` examples of which a
    ``harmful_ratio`` fraction are unsafe BeaverTails prompt/response pairs,
    saves the adapter and post-processes it with ``refine_and_save``.
    """
    args = parse_args()
    model_name_or_path = f"./model_zoo/Meta-Llama-3-8B-Instruct-{args.bits}bit-rank{args.rank}"
    save_dir = (
        f"./Meta-Llama-3-8B-Instruct-sst5-exsqf-{args.bits}bit"
        f"-sample{args.n_samples}-mix{args.harmful_ratio}-r{args.rank}"
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    # Presumably required so gradient checkpointing (enabled below) can
    # backpropagate into the adapter while the base weights stay frozen.
    model.enable_input_require_grads()
    adapter_dir = os.path.join(model_name_or_path, "exsqf_init")
    model = PeftModel.from_pretrained(
        model,
        adapter_dir,
        is_trainable=True,
    )
    model.print_trainable_parameters()

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Batch padding needs a pad token; fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token

    safe_samples, harmful_samples = _split_sample_counts(args.n_samples, args.harmful_ratio)

    sst5_dataset = (
        load_dataset("json", data_files="sst5.json", split="train")
        .shuffle(seed=args.seed)
        .select(range(safe_samples))
        .map(_concat_instruction_input)
    )

    dataset_beaver = (
        load_from_disk(r"./PKU-Alignment_BeaverTails/30k_train")
        .filter(_is_unsafe)
        .select_columns(["prompt", "response"])
        .shuffle(seed=args.seed)
        .select(range(harmful_samples))
    )

    # Message: "Starting manual formatting and tokenization..."
    print("开始进行手动格式化和 Tokenizer...")

    tokenized_sst5_dataset = sst5_dataset.map(
        lambda x: _tokenize_chat_example(tokenizer, x, "instruction", "output"),
        remove_columns=sst5_dataset.column_names,
        load_from_cache_file=False,  # always re-tokenize rather than reuse a cached result
    )
    tokenized_dataset_beaver = dataset_beaver.map(
        lambda x: _tokenize_chat_example(tokenizer, x, "prompt", "response"),
        remove_columns=dataset_beaver.column_names,
        load_from_cache_file=False,
    )

    processed_dataset = concatenate_datasets(
        [tokenized_sst5_dataset, tokenized_dataset_beaver]
    ).shuffle(seed=args.seed)

    training_args = TrainingArguments(
        # Per-run checkpoint directory: older transformers releases require
        # output_dir, and a shared default would let runs with different
        # arguments overwrite each other's epoch checkpoints.
        output_dir=f"{save_dir}-checkpoints",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        weight_decay=0.01,
        logging_steps=100,
        save_strategy="epoch",
        bf16=True,
        lr_scheduler_type="cosine",
        warmup_steps=100,
        gradient_checkpointing=True,
        report_to="none",
    )

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=processed_dataset,
        data_collator=data_collator,
    )

    trainer.train()

    model.save_pretrained(save_dir)
    weight_path = "./lora_C.pt"
    target_modules = ["o_proj"]
    refine_and_save(save_dir, weight_path, target_modules)


if __name__ == "__main__":
    # Script entry point: training starts only when this file is executed
    # directly, never on import. main() returns None, so the exit status is 0.
    raise SystemExit(main())
