import json
import os
from dataclasses import dataclass, field
from typing import Optional, Tuple
import jinja2
import numpy as np
import torch
from datasets import Dataset, load_dataset, load_from_disk
from peft import AutoPeftModelForCausalLM
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    HfArgumentParser, GenerationConfig
)
from trl import apply_chat_template
from vllm import LLM, SamplingParams

# NCCL tuning for the multi-GPU (tensor-parallel) vLLM run below; presumably keeps
# NCCL from failing/warning when GPU peer-to-peer access is disabled -- confirm
# against the NCCL environment-variable docs.
os.environ.update(NCCL_IGNORE_DISABLED_P2P="1")

@dataclass
class ScriptArguments:
    """Command-line options for sampling completions with vLLM.

    Parsed with ``HfArgumentParser``; each field's ``help`` metadata is the
    text shown for the corresponding ``--<field_name>`` flag.
    """

    # ---- model -------------------------------------------------------------
    # Local directory or hub id of the model to sample from.
    model_name_or_path: Optional[str] = field(
        metadata={"help": "the model name"},
        default=None,
    )

    # ---- data --------------------------------------------------------------
    # Local dataset path (loaded with load_from_disk) or hub dataset name.
    dataset_name: Optional[str] = field(
        metadata={"help": "the HF data path"},
        default="trl-lib/tldr-preference",
    )
    split: Optional[str] = field(
        metadata={"help": "the dataset split to use for generation"},
        default="train",
    )
    # NOTE(review): not referenced by the generation code visible in this file.
    max_prompt_length: Optional[int] = field(
        metadata={"help": "the maximum prompt length"},
        default=512,
    )
    save_dataset_path: Optional[str] = field(
        metadata={"help": "the path for saving the generated dataset"},
        default="sft_gen_dataset",
    )

    # ---- sampling ----------------------------------------------------------
    max_new_tokens: Optional[int] = field(
        metadata={"help": "the maximum number of tokens to generate"},
        default=64,
    )
    # Completions drawn per prompt.
    num_return_sequences: Optional[int] = field(
        metadata={"help": "the number of sequences to generate"},
        default=64,
    )
    # None keeps every unique prompt.
    sub_sample: Optional[int] = field(
        metadata={"help": "the number of samples to sub-sample from the dataset"},
        default=None,
    )
    temperature: Optional[float] = field(
        metadata={"help": "the temperature for generation"},
        default=1.0,
    )
    # Used as vLLM's tensor_parallel_size.
    num_of_gpus: Optional[int] = field(
        metadata={"help": "the number of GPUs to use for generation"},
        default=2,
    )

# Prompts longer than this many chat-template tokens are dropped for GPT-2 models.
GPT2_MAX_PROMPT_TOKENS = 768


def remove_path(path):
    """Delete ``path`` if it exists, with ``rm -rf`` semantics but no shell.

    Real directories are removed recursively; files and symlinks (including
    symlinks to directories) are unlinked; a missing path is a no-op. No shell
    is involved, so spaces or shell metacharacters in ``path`` are taken literally.
    """
    import shutil

    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    elif os.path.lexists(path):
        os.remove(path)


def merge_peft_adapter(model_path):
    """Merge the PEFT adapter stored in ``model_path`` and save the full model there.

    The adapter is loaded on CPU in bfloat16, merged with ``merge_and_unload``,
    the ``adapter_*`` files are deleted so the directory no longer looks like an
    adapter checkpoint, and the merged weights are written back to ``model_path``.
    The merged model is released when this function returns.
    """
    import glob

    print(f"Merge and save peft model to {model_path}")
    peft_model = AutoPeftModelForCausalLM.from_pretrained(
        model_path,
        device_map='cpu', torch_dtype='bfloat16',
    )
    merged_model = peft_model.merge_and_unload()
    # remove the peft config file(s); glob.escape keeps special characters in the
    # directory name from being interpreted as wildcards.
    for adapter_file in glob.glob(os.path.join(glob.escape(model_path), "adapter_*")):
        if os.path.isdir(adapter_file) and not os.path.islink(adapter_file):
            continue  # `rm -f` (no -r) never removed directories either
        os.remove(adapter_file)
    merged_model.save_pretrained(model_path)


def load_tokenizer(model_path):
    """Load the tokenizer for ``model_path`` with pad = eos and a chat template.

    If the tokenizer ships without a chat template, the template is borrowed from
    the sibling checkpoint ``<model_path>-it`` (presumably the instruction-tuned
    variant -- confirm for new model families).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    if tokenizer.chat_template is None:
        # rstrip so "path/" does not become "path/-it".
        it_path = model_path.rstrip("/") + "-it"
        tokenizer.chat_template = AutoTokenizer.from_pretrained(it_path).chat_template
    return tokenizer


def build_stop_token_ids(tokenizer):
    """Return the token ids that end a generation.

    Always the tokenizer's eos id, plus the id of ``<end_of_turn>`` when that
    string is one of the tokenizer's special tokens.
    """
    stop_ids = [tokenizer.eos_token_id]
    if "<end_of_turn>" in tokenizer.all_special_tokens:
        stop_ids.append(tokenizer.encode("<end_of_turn>", add_special_tokens=False)[0])
    return stop_ids


def load_prompt_split(dataset_name, split):
    """Load ``split`` from a local ``save_to_disk`` directory, else from the Hub."""
    if os.path.exists(dataset_name):
        return load_from_disk(dataset_name)[split]
    return load_dataset(dataset_name)[split]


def filter_long_prompts(dataset, tokenizer, max_tokens=GPT2_MAX_PROMPT_TOKENS):
    """Drop rows whose chat-templated prompt is longer than ``max_tokens`` tokens.

    String prompts are measured as a single user turn; message-list prompts are
    measured as-is. Prints how many rows were removed.
    """
    num_before = len(dataset)

    def fits(example):
        messages = example["prompt"]
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        return len(tokenizer.apply_chat_template(messages)) <= max_tokens

    dataset = dataset.filter(fits)
    print(f"Filtered out {num_before - len(dataset)} examples with too long prompts")
    return dataset


def unique_prompt_indices(prompts):
    """Return the index of the first occurrence of each distinct prompt.

    String prompts are compared directly; message-list prompts are compared by
    the ``content`` of their first message only. As with ``np.unique``, the
    indices are ordered by the sorted prompt key, not by position.
    """
    if not prompts:
        return []
    if isinstance(prompts[0], str):
        keys = prompts
    else:
        keys = [messages[0]["content"] for messages in prompts]
    _, first_indices = np.unique(keys, return_index=True, axis=0)
    return first_indices.tolist()


def to_chat_messages(row):
    """Keep the raw string in ``ori_prompt`` and wrap ``prompt`` as one user turn."""
    row["ori_prompt"] = row["prompt"]
    row["prompt"] = [{"role": "user", "content": row["prompt"]}]
    return row


def try_apply_chat_template(example, tokenizer):
    """Apply trl's ``apply_chat_template`` to one row, tolerating template errors.

    Returns ``{"prompt": <templated prompt>, "ori_prompt": <original prompt>}``.
    On a ``jinja2`` ``TemplateError`` the error is printed and ``prompt`` is
    ``None`` so the caller can filter the row out.
    """
    ori_prompt = example["ori_prompt"] if "ori_prompt" in example else example["prompt"]
    try:
        templated = apply_chat_template(example, tokenizer)['prompt']
    except jinja2.exceptions.TemplateError as e:
        print(f"Error in formatting example: {example['prompt']}. Error: {e}. Skipping")
        # Return the same ori_prompt as the success path: for string datasets the
        # failing row's "prompt" is already a message list, and returning it would
        # give the ori_prompt column mixed types (str vs. list of messages).
        return {"prompt": None, "ori_prompt": ori_prompt}
    return {"prompt": templated, "ori_prompt": ori_prompt}


def format_prompts(dataset, tokenizer):
    """Render every prompt with the chat template and keep the original alongside.

    String prompts are first wrapped as a single user turn. Returns the dataset
    with ``prompt`` (output of trl's ``apply_chat_template``) and ``ori_prompt``
    (the prompt as it was in the input dataset); rows that failed templating
    are removed.
    """
    if len(dataset) > 0 and isinstance(dataset[0]["prompt"], str):
        dataset = dataset.map(to_chat_messages, num_proc=32)
    dataset = dataset.map(
        try_apply_chat_template, num_proc=32, fn_kwargs={"tokenizer": tokenizer},
    )
    # remove examples that failed to apply chat template
    return dataset.filter(lambda x: x["prompt"] is not None)


def flatten_generations(ori_prompts, outputs):
    """Pair each generated completion with its original prompt.

    ``outputs[i]`` is assumed to belong to ``ori_prompts[i]`` (vLLM's
    ``LLM.generate`` returns results in input order). Returns two parallel
    lists ``(prompts, responses)`` with one entry per completion.
    """
    prompts, responses = [], []
    for ori_prompt, request_output in zip(ori_prompts, outputs):
        for completion in request_output.outputs:
            prompts.append(ori_prompt)
            responses.append(completion.text)
    return prompts, responses


def main():
    """Sample completions for every unique prompt of a dataset split and save them.

    Optionally merges a PEFT adapter into ``model_name_or_path``, loads the model
    in vLLM, deduplicates and optionally sub-samples the prompts, renders them
    with the chat template, generates ``num_return_sequences`` completions per
    prompt, and saves a ``{"prompt", "response"}`` dataset (one row per
    completion, ``prompt`` being the original, untemplated prompt) to
    ``save_dataset_path``, replacing anything already there.
    """
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    model_path = script_args.model_name_or_path
    if model_path is None:
        # Fail with a usage message instead of a TypeError from os.path.join(None, ...).
        parser.error("--model_name_or_path is required")

    # A directory without config.json is treated as a bare PEFT adapter.
    if not os.path.exists(os.path.join(model_path, "config.json")):
        merge_peft_adapter(model_path)

    # vLLM rejects bfloat16 on GPUs with compute capability < 8.0 and merged
    # adapters are saved in bfloat16, so request float16 there; newer GPUs keep
    # vLLM's default "auto" dtype selection.
    vllm_dtype = "float16" if torch.cuda.get_device_capability()[0] <= 7 else "auto"
    llm = LLM(
        model=model_path,
        tensor_parallel_size=script_args.num_of_gpus,
        dtype=vllm_dtype,
    )

    tokenizer = load_tokenizer(model_path)
    sampling_params = SamplingParams(
        n=script_args.num_return_sequences,
        temperature=script_args.temperature,
        max_tokens=script_args.max_new_tokens,
        stop_token_ids=build_stop_token_ids(tokenizer),
    )

    dataset = load_prompt_split(script_args.dataset_name, script_args.split)
    if "gpt2" in model_path:
        dataset = filter_long_prompts(dataset, tokenizer)

    # Dataset only keep the prompt column
    dataset = dataset.remove_columns(list(set(dataset.column_names) - {"prompt"}))
    if len(dataset) == 0:
        raise ValueError(
            f"No prompts to generate from in split '{script_args.split}' of "
            f"'{script_args.dataset_name}'"
        )

    # drop duplicates by prompt
    dataset = dataset.select(unique_prompt_indices(dataset["prompt"]))

    if script_args.sub_sample is not None:
        # Clamp so asking for more samples than unique prompts does not crash.
        num_samples = min(script_args.sub_sample, len(dataset))
        dataset = dataset.shuffle(seed=32).select(range(num_samples))

    dataset = format_prompts(dataset, tokenizer)
    ori_prompts = dataset["ori_prompt"]  # Store original prompts for later
    dataset = dataset.remove_columns(["ori_prompt"])

    outputs = llm.generate(dataset["prompt"], sampling_params)
    prompts, responses = flatten_generations(ori_prompts, outputs)

    generated_dataset = Dataset.from_dict({
        "prompt": prompts,
        "response": responses,
    })
    remove_path(script_args.save_dataset_path)  # Remove existing dataset if it exists
    generated_dataset.save_to_disk(script_args.save_dataset_path)


if __name__ == "__main__":
    main()
