from __future__ import annotations

import re
import shutil
from pathlib import Path
from typing import *

import click
import datasets
import numpy as np
import torch
from alpaca_eval.decoders.huggingface_local import ListDataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# REGEN_TEMPLATE = """You are given a question and two responses that attempt to answer the question. Your task is to generate a comprehensive and accurate final answer by considering the provided responses.

# Question:
# {prompt}

# Response 1:
# {response1}

# Response 2:
# {response2}

# Instructions:
# 1. Read the question carefully.
# 2. Review both responses. Some details might be correct while others might be incomplete or inaccurate.
# 3. Generate your own final answer by:
#    - Incorporating any relevant and correct points from the provided responses.
#    - Correcting or expanding upon any incomplete or flawed points.
# 4. Provide your final answer as a standalone response as if you had only received the original question, seamlessly integrating the correct insights from the provided responses without explicitly referencing or imitating them. Ensure that the tone remains natural and polite as originally intended.

# Final Answer:
# """
# REGEN_TEMPLATE = """The user's prompt is as follows:

# [START_OF_PROMPT]
# {prompt}
# [END_OF_PROMPT]

# The example responses are as follows:

# [START_OF_FIRST_RESPONSE]
# {response1}
# [END_OF_FIRST_RESPONSE]

# [START_OF_SECOND_RESPONSE]
# {response2}
# [END_OF_SECOND_RESPONSE]

# Please give me a similar but different answer as if you had only received the original question.
# """
# Prompt template used to build the "pseudo prompt" that is actually sent to
# the model. It is filled with str.format() using three placeholders:
#   {prompt}    - the original user question
#   {response1} - the dataset's "chosen" response
#   {response2} - the dataset's "rejected" response
# The literal (including its trailing whitespace) is part of the model input,
# so any edit to it changes what the model sees.
REGEN_TEMPLATE = """Question: {prompt}

Chosen Response: {response1}

Rejected Response: {response2}

Task: Based on the provided question, chosen response, and rejected response, generate the best possible response. Consider the following:
1. Why was the chosen response selected over the rejected one? What makes it more relevant or accurate?
2. What are the key differences between the two responses? What can be improved in the rejected response to make it more like the chosen one?
3. Use both the question and the responses to create a clear, coherent, and informative answer to the question. Provide your final answer as a standalone response as if you had only received the original question, without explicitly referencing or imitating the provided responses.

Response: 
"""


# Root click command group; the sub-commands defined in this file register
# themselves on it through the @cli.command() decorator. The body is
# intentionally empty: the group only dispatches to its sub-commands.
@click.group()
def cli():
    pass


def _load_generation_dataset(
    ds_path: str, name: Optional[str], split: Optional[str]
):
    """Load ``ds_path`` from local disk or via ``load_dataset`` and flatten it.

    An existing local directory is always read with ``load_from_disk``.
    Anything else is first tried with ``datasets.load_dataset``; if that
    fails, ``load_from_disk`` is used as a fallback. A ``DatasetDict`` result
    is concatenated into a single ``Dataset``.
    """
    if Path(ds_path).is_dir():
        ds = datasets.load_from_disk(ds_path)
    else:
        try:
            ds = datasets.load_dataset(
                ds_path, name=name, split=split, trust_remote_code=True
            )
        except Exception as err:
            # `except Exception` (not a bare except) so that Ctrl-C and
            # SystemExit are not swallowed by the fallback.
            print(
                f"load_dataset failed for {ds_path} ({err!r}), "
                "falling back to load_from_disk"
            )
            ds = datasets.load_from_disk(ds_path)
    if isinstance(ds, datasets.DatasetDict):
        ds = datasets.concatenate_datasets(list(ds.values()))
    return ds


def _resolve_save_name(
    ds_path: str,
    output_dir: str,
    output_dataname: Optional[str],
    world_size: Optional[int],
    local_rank: Optional[int],
) -> str:
    """Return the directory the generated dataset (or a checkpoint) is saved to.

    When both ``world_size`` and ``local_rank`` are given, the
    ``_r{rank}_w{world}`` suffix is appended to the name.
    """
    if output_dataname is not None:
        base_name = f"{output_dataname}"
    else:
        # rstrip so that "path/to/ds/" yields "ds" rather than an empty name
        base_name = f"{ds_path.rstrip('/').split('/')[-1]}_with_responses"
    save_name = str(Path(output_dir) / base_name)
    if world_size is not None and local_rank is not None:
        save_name += f"_r{local_rank}_w{world_size}"
    return save_name


@cli.command()
@click.argument("ds_path", type=str)
@click.argument("model_path", type=str)
@click.option(
    "--prompt_column",
    type=str,
    default="prompt",
    help="the name of the column containing the prompts",
)
@click.option(
    "--output_dir",
    type=str,
    default="generated_data",
    help="saving directory for generated data",
)
@click.option(
    "--output_dataname",
    type=str,
    default=None,
    help="the name for generated data",
)
@click.option(
    "--n_generation_per_prompt",
    type=int,
    default=5,
    help="the number of generations per prompt",
)
@click.option(
    "--max_new_tokens",
    type=int,
    default=2048,
    help="the maximum number of new tokens to generate per generation",
)
@click.option(
    "--do_sample",
    type=bool,
    default=True,
    help="whether to use sampling or greedy decoding",
)
@click.option(
    "--batch_size",
    type=int,
    default=8,
    help="the batch size for generation",
)
@click.option(
    "--repetition_penalty",
    type=float,
    default=1.5,
    help="the penalty for repeated tokens",
)
@click.option(
    "--temperature",
    type=float,
    default=1.0,
    help="the temperature for sampling",
)
@click.option(
    "--top_p",
    type=float,
    default=None,
    help="the top-p for sampling",
)
@click.option(
    "--top_k",
    type=int,
    default=None,
    help="the top-k for sampling",
)
@click.option(
    "--num_beams",
    type=int,
    default=1,
    help="the number of beams for beam search",
)
@click.option(
    "--max_samples",
    type=int,
    default=None,
    help="the maximum number of samples to generate, if None, generate all samples",
)
@click.option(
    "--world_size",
    type=int,
    default=None,
    help="the number of processes to use for generation, if None, use onlt 1",
)
@click.option(
    "--local_rank",
    type=int,
    default=None,
    help="the rank of the current process, if None, use only 1 process",
)
@click.option(
    "--n_iter",
    type=int,
    default=None,
    help="the number of iterations, i.e., the ds will be split into len(ds) / max_samples iters",
)
@click.option(
    "--is_preference_ds",
    type=bool,
    default=False,
    help="whether the dataset is a preference dataset",
)
@click.option(
    "--name",
    type=str,
    default=None,
    help="the ds name used in remote datasets",
)
@click.option(
    "--split",
    type=str,
    default=None,
    help="the split to use for generation, if None, use all splits",
)
@click.option(
    "--resume_from_ckpt",
    type=str,
    default=None,
    help="the path to the checkpoint to resume from",
)
@click.option(
    "--save_every",
    type=int,
    default=100,
    help="the number of samples to save per iteration",
)
def generate(
    ds_path: str,
    model_path: str,
    prompt_column: str = "prompt",
    output_dir: str = "generated_data",
    output_dataname: Optional[str] = None,
    n_generation_per_prompt: int = 5,
    max_new_tokens: int = 2048,
    do_sample: bool = True,
    batch_size: int = 8,
    repetition_penalty: float = 1.2,
    temperature: float = 1.0,
    top_p: Optional[float] = 0.95,
    top_k: Optional[int] = 50,
    num_beams: int = 5,
    max_samples: Optional[int] = None,
    world_size: Optional[int] = None,
    local_rank: Optional[int] = None,
    n_iter: Optional[int] = None,
    is_preference_ds: bool = False,
    name: Optional[str] = None,
    split: Optional[str] = None,
    resume_from_ckpt: Optional[str] = None,
    save_every: int = 100,
):
    """Regenerate responses for a preference dataset using a language model.

    For every row, the question plus its "chosen" and "rejected" responses are
    inserted into ``REGEN_TEMPLATE``; the result is wrapped with the
    tokenizer's chat template and sent to the model. The saved dataset has the
    columns ``prompt``, ``prompt_dialg``, ``responses`` and ``pseudo_prompt``.

    Defaults quoted below are the command-line (click option) defaults. A few
    keyword defaults in the Python signature differ (``repetition_penalty``,
    ``top_p``, ``top_k``, ``num_beams``); those only apply when the underlying
    function is called without going through click.

    Parameters
    ----------
    ds_path : str
        a local directory written by ``save_to_disk`` or a dataset path
        accepted by ``datasets.load_dataset``; rows must provide
        ``prompt_column`` as well as "chosen" and "rejected"
    model_path : str
        a local or remote huggingface model path, must be a causal language model
    prompt_column : str, optional
        the name of the column containing the prompts, by default "prompt".
        Values are either plain strings or lists of chat messages; for a
        list, only the first message is used.
    output_dir : str, optional
        saving directory for generated data, by default "generated_data"
    output_dataname : Optional[str], optional
        name of the saved dataset inside ``output_dir``; by default
        "<last component of ds_path>_with_responses"
    n_generation_per_prompt : int, optional
        the number of generations per prompt, by default 5
    max_new_tokens : int, optional
        the maximum number of new tokens to generate per generation, by default 2048
    do_sample : bool, optional
        whether to use sampling or greedy decoding, by default True
    batch_size : int, optional
        the batch size for generation, by default 8
    repetition_penalty : float, optional
        the penalty for repeated tokens, by default 1.5 on the command line
    temperature : float, optional
        the temperature for sampling, by default 1.0
    top_p : Optional[float], optional
        the top-p for sampling, by default None on the command line
    top_k : Optional[int], optional
        the top-k for sampling, by default None on the command line
    num_beams : int, optional
        the number of beams for beam search, by default 1 on the command line
    max_samples : Optional[int], optional
        the maximum number of samples to generate, if None, generate all samples, by default None
    world_size, local_rank : Optional[int], optional
        when both are given, the dataset is shuffled with a fixed seed and
        split into ``world_size`` shards; only shard ``local_rank`` is
        processed and the save name gets a ``_r{rank}_w{world}`` suffix
    n_iter : Optional[int], optional
        1-based index of the ``max_samples``-sized window to process; only
        used together with ``max_samples``
    is_preference_ds : bool, optional
        accepted for command-line compatibility; not read by this function
    name, split : Optional[str], optional
        forwarded to ``datasets.load_dataset``
    resume_from_ckpt : Optional[str], optional
        path of a previously saved dataset; rows whose templated ``prompt``
        already appears there are skipped
    save_every : int, optional
        write an intermediate checkpoint every this many prompts, by default
        100; a value <= 0 disables intermediate checkpoints

    Raises
    ------
    ValueError
        if ``n_iter`` is smaller than 1, if no sample is left after
        sub-selection, if a string prompt already starts with the tokenizer's
        BOS token, or if the prompt column holds an unsupported type.

    For details on the other parameters, see
    https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
    """
    ds = _load_generation_dataset(ds_path, name, split)

    if max_samples is not None:
        chunk = min(max_samples, len(ds))
        if n_iter is None:
            ds = ds.select(range(chunk))
        else:
            if n_iter < 1:
                raise ValueError(f"n_iter must be >= 1, got {n_iter}")
            # Clamp the upper bound so the last, partial window does not
            # request indices past the end of the dataset.
            window_start = chunk * (n_iter - 1)
            window_stop = min(chunk * n_iter, len(ds))
            ds = ds.select(range(window_start, window_stop))

    if world_size is not None and local_rank is not None:
        shard_idxs = np.array_split(np.arange(len(ds)), world_size)[
            local_rank
        ]
        # shuffle the dataset to make ds evenly distributed across processes
        ds = ds.shuffle(seed=42)
        ds = ds.select(shard_idxs)

    if len(ds) == 0:
        raise ValueError(
            "No samples left to generate for after applying "
            "max_samples / n_iter / world_size selection"
        )

    if resume_from_ckpt is not None:
        ds_resume = datasets.load_from_disk(resume_from_ckpt)
    else:
        ds_resume = None

    # torch_dtype / device_map are model-loading options and are therefore
    # only passed to the model, not to the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="auto"
    ).eval()

    # Inspect a single row instead of materialising the whole column.
    first_prompt = ds[0][prompt_column]
    if isinstance(first_prompt, str):
        bos_token = tokenizer.bos_token
        # The chat template is applied below, so a prompt that already
        # carries the BOS token would presumably end up with it twice.
        if bos_token is not None and first_prompt.startswith(bos_token):
            raise ValueError(
                f"Prompts in column {prompt_column!r} must not start with "
                f"the BOS token {bos_token!r}"
            )

        def _first_turn(raw_prompt):
            return {"role": "user", "content": raw_prompt}

    elif isinstance(first_prompt, list):

        def _first_turn(raw_prompt):
            # only the first message of the conversation is used
            return raw_prompt[0]

    else:
        raise ValueError(f"Unsupported prompt type: {type(first_prompt)}")

    def _add_prompt_columns(row):
        turn = _first_turn(row[prompt_column])
        regen_request = REGEN_TEMPLATE.format(
            prompt=turn["content"],
            response1=row["chosen"],
            response2=row["rejected"],
        )
        return {
            "prompt_dialg": [turn],
            "prompt": tokenizer.apply_chat_template(
                [turn],
                tokenize=False,
                add_generation_prompt=True,
            ),
            "pseudo_prompt": tokenizer.apply_chat_template(
                [{"role": "user", "content": regen_request}],
                tokenize=False,
                add_generation_prompt=True,
            ),
        }

    ds = ds.map(_add_prompt_columns)

    if ds_resume is not None:
        # Build the lookup set once; reading the column inside the filter
        # function would re-materialise it for every row.
        done_prompts = set(ds_resume["prompt"])
        ds_to_annotate = ds.filter(
            lambda x: x["prompt"] not in done_prompts, num_proc=8
        )
        prompts = ds_to_annotate["pseudo_prompt"]
        print(
            f"Resuming from {len(ds_resume)} existing samples, {len(prompts)} to generate"
        )
    else:
        ds_to_annotate = ds
        prompts = ds_to_annotate["pseudo_prompt"]

    # ! prompts are deliberately NOT sorted by length: completions are matched
    # ! to prompt / prompt_dialg purely by position.

    # `is None` rather than a truthiness test: 0 is a valid pad token id.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token

    default_kwargs = dict(
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        batch_size=batch_size,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        num_return_sequences=n_generation_per_prompt,
        eos_token_id=list({tokenizer.eos_token_id, tokenizer.pad_token_id}),
        early_stopping=True,
    )

    print(f"Model memory: {model.get_memory_footprint() / 1e9} GB")
    print(default_kwargs)

    response_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        **default_kwargs,
        trust_remote_code=True,
    )
    prompts_ds = ListDataset(prompts)
    completions = []

    # Resolve the target once; it is shared by the intermediate checkpoints
    # and the final save.
    save_name = _resolve_save_name(
        ds_path, output_dir, output_dataname, world_size, local_rank
    )
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Read the output columns once instead of on every checkpoint.
    prompt_col = ds_to_annotate["prompt"]
    dialog_col = ds_to_annotate["prompt_dialg"]
    pseudo_prompt_col = ds_to_annotate["pseudo_prompt"]

    if world_size is not None:
        progress_desc = f"generating responses for {len(prompts)} prompts (worker {local_rank}/{world_size})"
    else:
        progress_desc = f"generating responses for {len(prompts)} prompts"

    for output in tqdm(
        response_pipeline(
            prompts_ds,
            return_full_text=False,
            pad_token_id=tokenizer.pad_token_id,
        ),
        desc=progress_desc,
        total=len(prompts),
    ):
        completions.append([candidate["generated_text"] for candidate in output])
        n_done = len(completions)
        if save_every > 0 and n_done % save_every == 0:
            # Intermediate checkpoint holding only the rows generated so far
            # in this run (samples from resume_from_ckpt are not included).
            # NOTE(review): if resume_from_ckpt points at save_name this
            # writes over the dataset being resumed from - confirm callers
            # use distinct paths.
            output_ds = datasets.Dataset.from_dict(
                {
                    "prompt": prompt_col[:n_done],
                    "prompt_dialg": dialog_col[:n_done],
                    "responses": completions,
                    "pseudo_prompt": pseudo_prompt_col[:n_done],
                }
            )
            output_ds.save_to_disk(save_name)
            print(
                f"Save to {save_name} after {len(completions)} samples, ds now contains {len(output_ds)} samples."
            )

    output_ds = datasets.Dataset.from_dict(
        {
            "prompt": prompt_col,
            "prompt_dialg": dialog_col,
            "responses": completions,
            "pseudo_prompt": pseudo_prompt_col,
        }
    )

    # The resumed samples are prepended exactly once: by the single process,
    # or by rank 0 in a multi-process run.
    if ds_resume is not None and (world_size is None or local_rank == 0):
        output_ds = datasets.concatenate_datasets([ds_resume, output_ds])

    output_ds.save_to_disk(save_name)


@cli.command()
@click.argument("split_ds_path", type=str)
@click.argument("output_ds_path", type=str)
def merge(split_ds_path: str, output_ds_path: str):
    """Merge per-worker split datasets into a single dataset.

    Every sub-directory of SPLIT_DS_PATH whose name ends in
    ``_r<rank>_w<world_size>`` is loaded with ``load_from_disk``. The shards
    are concatenated in (base name, world size, rank) order so the merged row
    order is reproducible, and the result is saved to OUTPUT_DS_PATH. The
    shard directories are deleted afterwards; a failed deletion is reported
    but does not abort the command.

    Raises
    ------
    ValueError
        if no matching split dataset is found in SPLIT_DS_PATH.
    """
    shard_suffix = re.compile(r"_r(\d+)_w(\d+)$")

    keyed_dirs = []
    for candidate in Path(split_ds_path).iterdir():
        if not candidate.is_dir():
            continue
        match = shard_suffix.search(candidate.name)
        if match is None:
            continue
        base_name = candidate.name[: match.start()]
        rank, world = int(match.group(1)), int(match.group(2))
        keyed_dirs.append(((base_name, world, rank), candidate))

    # Path.iterdir() yields entries in arbitrary order; sort numerically by
    # rank so that e.g. r10 comes after r2.
    keyed_dirs.sort(key=lambda item: item[0])
    matched_dirs = [shard_dir for _, shard_dir in keyed_dirs]

    if not matched_dirs:
        raise ValueError(f"No split dataset found in {split_ds_path}")
    print(f"Found {len(matched_dirs)} split datasets in {split_ds_path}")

    ds_list = [datasets.load_from_disk(str(p)) for p in matched_dirs]
    merged_ds = datasets.concatenate_datasets(ds_list)
    merged_ds.save_to_disk(output_ds_path)

    # remove the split datasets (best effort)
    for p in matched_dirs:
        try:
            shutil.rmtree(str(p))
        except Exception as e:
            # Print the exception
            print(f"An error occurred: {e}")
            print(f"Failed to remove {p}, please remove it manually")


@cli.command()
@click.argument("ds_path", type=str)
@click.argument("output_path", type=str)
@click.option(
    "--format_to_alpaca",
    type=bool,
    default=False,
    help="whether to convert the dataset to alpaca format",
)
def convert_to_json(
    ds_path: str, output_path: str, format_to_alpaca: bool = False
):
    """Convert a dataset saved with ``save_to_disk`` to a json file.

    With ``--format_to_alpaca True`` each row is reduced to two columns:
    ``instruction`` (content of the first ``prompt_dialg`` message) and
    ``output`` (the first entry of ``responses``).

    Without it, the dataset is exported as it is. For compatibility with the
    previous behaviour, a dataset that already has both ``instruction`` and
    ``output`` columns is still narrowed to just those two; a dataset without
    them is exported with all of its columns instead of failing.

    ``.json`` is appended to OUTPUT_PATH when it is missing.
    """
    alpaca_columns = ["instruction", "output"]
    ds = datasets.load_from_disk(ds_path)

    if format_to_alpaca:
        ds = ds.map(
            lambda x: {
                "instruction": x["prompt_dialg"][0]["content"],
                "output": x["responses"][0],
            }
        )
        ds = ds.select_columns(alpaca_columns)
    elif all(column in ds.column_names for column in alpaca_columns):
        # Already in alpaca layout: keep only the two alpaca columns.
        ds = ds.select_columns(alpaca_columns)

    if not output_path.endswith(".json"):
        output_path += ".json"
    ds.to_json(output_path)
    print(f"Dataset saved to {output_path}")


# Script entry point: hand control to the click command group so the
# sub-command named on the command line is executed.
if __name__ == "__main__":
    cli()
