import json
import os
import random
from dataclasses import dataclass, field
from random import sample

import torch
from accelerate import Accelerator, PartialState
from accelerate.utils import DistributedType
from datasets import load_dataset, load_from_disk
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
    CPOConfig,
    CPOTrainer,
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    ScriptArguments,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    maybe_apply_chat_template,
    maybe_extract_prompt,
)

from is_dpo_trainer import ImportanceSamplingCPOTrainer, ImportanceSamplingDPOTrainer, ISPreferenceCollator
from utils.configs import H4ArgumentParser

# Route all Weights & Biases runs from this script into one project
# (presumably picked up by the wandb integration when training starts).
os.environ.update({"WANDB_PROJECT": "OffPolicyRLHF"})

# Seed the module-level RNG; `random.sample` (imported above) draws the
# evaluation subset from it, so the subset is reproducible run to run.
random.seed(a=42)

@dataclass
class ScriptArguments(ScriptArguments):
    """TRL's ``ScriptArguments`` extended with the options this script adds.

    The subclass reuses the base-class name, so every later reference to
    ``ScriptArguments`` in this module resolves to this extended version.
    The ``help`` metadata is what HfArgumentParser-style parsers show in ``--help``.
    """

    # Train with the stock trainer even when the dataset has an `is_weights` column.
    ignore_is_weights: bool = field(
        default=False,
        metadata={
            "help": "Ignore the dataset's `is_weights` column and use the standard "
            "(non importance-sampling) trainer."
        },
    )
    # Prepend the `tldr` prompt from icl.json to chosen/rejected (TL;DR datasets only).
    use_icl: bool = field(
        default=False,
        metadata={
            "help": "Prepend the in-context-learning prompt from icl.json to the "
            "chosen/rejected completions (only applied to `tldr` datasets)."
        },
    )

if __name__ == "__main__":
    parser = H4ArgumentParser((ScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse()

    ################
    # Model & Tokenizer
    ###################
    quantization_config = get_quantization_config(model_config)
    accelerator = Accelerator()
    if quantization_config is not None:
        device_map = get_kbit_device_map()
    elif accelerator.distributed_type is accelerator.distributed_type.NO:
        device_map = 'auto'
    else:
        device_map = None

    torch_dtype = torch.float16 if torch.cuda.get_device_capability()[0] <= 7 else torch.bfloat16
    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=device_map,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )

    # if model use lora -> Merge
    # model.merge_and_unload()

    peft_config = get_peft_config(model_config)
    if peft_config is None:
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
        )
    else:
        ref_model = None
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    if tokenizer.pad_token is None:
        if "Llama-2" in model_config.model_name_or_path:
            tokenizer.pad_token_id = 18610
        elif "Llama-3" in model_config.model_name_or_path:
            tokenizer.pad_token = "<|finetune_right_pad_id|>"
        elif "gpt2" in model_config.model_name_or_path:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"
        else:
            raise ValueError("Tokenizer does not have a pad token. Please set it manually.")
    # if "Instruct" not in model_config.model_name_or_path:
    #     tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    ################
    # Dataset
    ################
    if os.path.exists(script_args.dataset_name):
        dataset = load_from_disk(script_args.dataset_name)
    else:
        dataset = load_dataset(script_args.dataset_name)

    if "rso" in script_args.dataset_name:
        dataset = dataset.remove_columns(set(dataset.column_names["train"]) - {"chosen", "rejected", "is_weights"})

    if "gpt2" in model_config.model_name_or_path:
        # filter out examples with too long prompts ( > 768 tokens)
        num_before = len(dataset)
        def filter_long_prompts(example):
            if "prompt" not in example:
                chosen_valid = len(tokenizer.apply_chat_template(example["chosen"], tokenize=True)) <= 768
                rejected_valid = len(tokenizer.apply_chat_template(example["rejected"], tokenize=True)) <= 768
            elif type(example["prompt"]) is str:
                chosen_valid = len(tokenizer.tokenize(example["prompt"] + example["chosen"])) <= 768
                rejected_valid = len(tokenizer.tokenize(example["prompt"] + example["rejected"])) <= 768
            else:
                chosen_valid = len(tokenizer.apply_chat_template(example['prompt'] + example['chosen'])) <= 768
                rejected_valid = len(tokenizer.apply_chat_template(example['prompt'] + example['rejected'])) <= 768
            return chosen_valid and rejected_valid
        dataset = dataset.filter(filter_long_prompts, num_proc=32)
        num_after = len(dataset)
        print(f"Filtered out {num_before - num_after} examples with too long prompts")

    if script_args.use_icl and "tldr" in script_args.dataset_name:
        print("Using icl for tldr")
        with open("icl.json", 'r') as fd:
            icl_prompt = json.load(fd)["tldr"]
        dataset = dataset.map(
            lambda x: {"chosen": icl_prompt + x["chosen"], "rejected": icl_prompt + x["rejected"]}, num_proc=32)

    # if "tldr" in script_args.dataset_name and "trl-lib" in script_args.dataset_name:
    #     # There is a `TL;DR:` in the end of the prompt, I will move it to the beginning of the completion
    #     def move_tldr_to_completion(example):
    #         example['chosen'] = "TL;DR:" + example['chosen']
    #         example['rejected'] = "TL;DR:" + example['rejected']
    #         example['prompt'] = example['prompt'].replace("\n\nTL;DR:", "")
    #         return example
    #     dataset = dataset.map(move_tldr_to_completion, num_proc=32)

    if "is_weights" in dataset[script_args.dataset_train_split].column_names and not script_args.ignore_is_weights:
        print("Using Importance Sampling DPO Trainer")
        if training_args.loss_type == "simpo":
            data_collator = None
            DPO_Trainer_class = ImportanceSamplingCPOTrainer
            training_args.cpo_alpha=0.0
        else:
            data_collator = ISPreferenceCollator(pad_token_id=tokenizer.pad_token_id)
            DPO_Trainer_class = ImportanceSamplingDPOTrainer
    else:
        data_collator = None
        if training_args.loss_type == "simpo":
            DPO_Trainer_class = CPOTrainer
            training_args.cpo_alpha=0.0
        else:
            DPO_Trainer_class = DPOTrainer

    ##########
    # Training
    ################
    if training_args.loss_type == "simpo":
        trainer = DPO_Trainer_class(
            model,
            args=training_args,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=dataset[script_args.dataset_test_split].select(
                sample(range(len(dataset[script_args.dataset_test_split])),
                       k=min(2048,len(dataset[script_args.dataset_test_split])))), # sample 2048 examples for evaluation
            processing_class=tokenizer,
            peft_config=peft_config,
            data_collator=data_collator,
        )
    else:
        trainer = DPO_Trainer_class(
            model,
            ref_model,
            args=training_args,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=dataset[script_args.dataset_test_split].select(
                sample(range(len(dataset[script_args.dataset_test_split])),
                       k=min(2048, len(dataset[script_args.dataset_test_split])))),
            # sample 2048 examples for evaluation
            processing_class=tokenizer,
            peft_config=peft_config,
            data_collator=data_collator,
        )

    trainer.train()
    metrics = trainer.evaluate()
    trainer.save_model(training_args.output_dir)