import json
from random import sample
from dataclasses import dataclass, field
from trl.trainer.utils import SIMPLE_SFT_CHAT_TEMPLATE, DataCollatorForCompletionOnlyLM, SIMPLE_CHAT_TEMPLATE
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    get_peft_config,
    get_quantization_config,
    get_kbit_device_map, apply_chat_template
)
from utils.configs import H4ArgumentParser
import os
import torch

# Name the Weights & Biases project through the environment at import time.
# NOTE(review): presumably picked up by the W&B logging integration when the
# trainer starts reporting — confirm against the configured `report_to`.
os.environ["WANDB_PROJECT"] = "OffPolicyRLHF-SFT"

@dataclass
class YkSFTConfig(SFTConfig):
    """`SFTConfig` extended with two boolean switches read by this script."""

    # Enables the completion-only data collator branch of the script.
    completion_only: bool = field(default=False, metadata=dict(help="Whether to use completion only loss"))
    # Enables loading in-context examples from `icl.json` for the TL;DR dataset.
    use_icl: bool = field(default=False, metadata=dict(help="Whether to use in-context learning"))

# Upper bound on the tokenized length of an example kept for GPT-2 runs.
GPT2_MAX_TOKENS = 768

# Marker strings that open a user turn, keyed by model family. They are handed
# to the completion-only collator as its `instruction_template`.
INSTRUCTION_TEMPLATES = {
    "gemma": "<start_of_turn>user",
    "qwen2.5": "<|im_start|>user",
    "qwen3": "<|im_start|>user",
    # "llama": "[INST]",  # For Llama-2
    "llama": "<|start_header_id|>user<|end_header_id|>\n\n",  # For Llama-3
    "gpt2": "User:",
}

# Marker strings that open an assistant turn, keyed by model family. They are
# handed to the completion-only collator as its `response_template`.
RESPONSE_TEMPLATES = {
    "gemma": "<start_of_turn>model",
    "qwen2.5": "<|im_start|>assistant",
    "qwen3": "<|im_start|>assistant",
    # "llama": "[/INST]",  # For Llama-2
    "llama": "<|start_header_id|>assistant<|end_header_id|>\n\n",  # For Llama-3
    "gpt2": "Assistant:",
}


def resolve_completion_templates(model_name_or_path):
    """Return ``(instruction_template, response_template)`` for a model.

    The model family is the lower-cased text before the first ``-`` in the
    last path component of ``model_name_or_path`` (e.g. ``"Qwen/Qwen2.5-7B"``
    -> ``"qwen2.5"``). A trailing ``/`` on a local path is ignored.

    Raises:
        ValueError: if the derived family has no entry in the template tables.
    """
    model_type = model_name_or_path.rstrip("/").split("/")[-1].split("-")[0].lower()
    if model_type not in INSTRUCTION_TEMPLATES or model_type not in RESPONSE_TEMPLATES:
        supported = ", ".join(sorted(INSTRUCTION_TEMPLATES))
        raise ValueError(
            f"No completion-only templates for model type '{model_type}' "
            f"(derived from '{model_name_or_path}'). Supported types: {supported}."
        )
    return INSTRUCTION_TEMPLATES[model_type], RESPONSE_TEMPLATES[model_type]


def count_examples(dataset):
    """Return the number of rows in ``dataset``.

    A split mapping (any ``dict``, which includes ``datasets.DatasetDict``)
    is counted as the sum of the lengths of its splits; calling ``len`` on it
    directly would return the number of splits instead. Anything else is
    measured with ``len``.
    """
    if isinstance(dataset, dict):
        return sum(len(split) for split in dataset.values())
    return len(dataset)


if __name__ == "__main__":
    parser = H4ArgumentParser((ScriptArguments, YkSFTConfig, ModelConfig))
    args, training_args, model_config = parser.parse()

    # fp16 when the CUDA compute-capability major version is <= 7, bf16 otherwise.
    # NOTE(review): this call needs a visible CUDA device.
    torch_dtype = torch.float16 if torch.cuda.get_device_capability()[0] <= 7 else torch.bfloat16

    ################
    # Model init kwargs & Tokenizer
    ################
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    # The trainer receives the model as a name/path, so the loading kwargs
    # travel through the training arguments.
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )

    # Make sure a pad token exists; the choice depends on the model family.
    if tokenizer.pad_token is None:
        if "Llama-2" in model_config.model_name_or_path:
            # NOTE(review): hard-coded vocabulary id — confirm which token this is.
            tokenizer.pad_token_id = 18610
        elif "Llama-3" in model_config.model_name_or_path:
            tokenizer.pad_token = "<|finetune_right_pad_id|>"
        elif "gpt2" in model_config.model_name_or_path:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            raise ValueError("Tokenizer does not have a pad token. Please set it manually.")

    # No custom formatting function is used in either mode; a collator is only
    # built when completion-only loss is requested.
    formatting_prompts_func = None
    collator = None
    if training_args.completion_only:
        instruction_template, response_template = resolve_completion_templates(
            model_config.model_name_or_path
        )
        collator = DataCollatorForCompletionOnlyLM(
            instruction_template=instruction_template,
            response_template=response_template,
            tokenizer=tokenizer,
        )

    # Supply a chat template when the tokenizer ships without one.
    if tokenizer.chat_template is None:
        if "gpt2" in model_config.model_name_or_path:
            tokenizer.chat_template = "{% for message in messages %}{{' ' + message['content']}}{% endfor %}{% if messages[-1]['role'] != 'user' %}{{ eos_token }}{% endif %}"
        elif "-pt" in model_config.model_name_or_path:
            # Borrow the template from the "-it" sibling of a "-pt" checkpoint.
            tokenizer.chat_template = AutoTokenizer.from_pretrained(
                model_config.model_name_or_path.replace("-pt", "-it")).chat_template
        else:
            # Otherwise look for a checkpoint named "<model>-it".
            tokenizer.chat_template = AutoTokenizer.from_pretrained(
                model_config.model_name_or_path + "-it").chat_template

    ################
    # Dataset
    ################
    if os.path.exists(args.dataset_name):
        dataset = load_from_disk(args.dataset_name)
    else:
        dataset = load_dataset(args.dataset_name)

    if "gpt2" in model_config.model_name_or_path:
        # filter out examples with too long prompts ( > 768 tokens)
        # Count rows across all splits so the report below is meaningful.
        num_before = count_examples(dataset)

        def filter_long_prompts(example):
            """Return True when the tokenized example fits in GPT2_MAX_TOKENS."""
            response_key = "completion" if "completion" in example else "chosen"
            if "messages" in example:
                # Conversational example: measure the templated conversation.
                token_count = len(tokenizer.apply_chat_template(example["messages"], tokenize=True))
            elif isinstance(example["prompt"], str):
                # Plain-text example: prompt and response are concatenated.
                token_count = len(tokenizer.tokenize(example["prompt"] + example[response_key]))
            else:
                # Prompt and response are message lists; join, then template.
                token_count = len(
                    tokenizer.apply_chat_template(example["prompt"] + example[response_key], tokenize=True)
                )
            return token_count <= GPT2_MAX_TOKENS

        dataset = dataset.filter(filter_long_prompts)
        num_after = count_examples(dataset)
        print(f"Filtered out {num_before - num_after} examples with too long prompts")

    # if "tldr" in args.dataset_name:
    #     # There is a `TL;DR:` in the end of the prompt, I will move it to the beginning of the completion
    #     def move_tldr_to_completion(example):
    #         example['completion'] = "TL;DR:" + example['completion']
    #         example['prompt'] = example['prompt'].replace("\n\nTL;DR:", "")
    #         return example
    #
    #     dataset = dataset.map(move_tldr_to_completion, num_proc=32)

    # Normalize each supported dataset to a single "messages" column.
    if "hh-rlhf" in args.dataset_name:
        dataset = dataset.map(lambda x: {"messages": x["prompt"] + x["chosen"]},
                              num_proc=32, remove_columns=["prompt", "chosen", "rejected"])
    elif "tldr" in args.dataset_name:
        if training_args.use_icl:
            # In-context examples are prepended to every conversation.
            with open("icl.json", "r", encoding="utf-8") as fd:
                icl_prompt = json.load(fd)["tldr"]
        else:
            icl_prompt = []
        dataset = dataset.map(lambda x: {"messages": [
            *icl_prompt,
            {"role": "user", "content": x["prompt"]},
            {"role": "assistant", "content": x["completion"]}
        ]}, num_proc=32, remove_columns=["prompt", "completion"])
    elif "ultrafeedback" in args.dataset_name:
        # only keep the "messages" column
        dataset = dataset.map(lambda x: {"messages": x["messages"]}, num_proc=32, remove_columns=["prompt", "chosen", "rejected"])

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model_config.model_name_or_path,
        args=training_args,
        train_dataset=dataset[args.dataset_train_split],
        eval_dataset=dataset[args.dataset_test_split],
        processing_class=tokenizer,
        peft_config=get_peft_config(model_config),
        formatting_func=formatting_prompts_func,
        data_collator=collator,
    )

    trainer.train()
    trainer.save_model(training_args.output_dir)

    # model = AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir, device_map='cpu')
    # model = model.merge_and_unload()
    # model.save_pretrained(training_args.output_dir)
