from __future__ import annotations

import re
import shutil
from pathlib import Path
from typing import *

import click
import datasets
import numpy as np
import torch
from alpaca_eval.decoders.huggingface_local import ListDataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


@click.group()
def cli():
    """Root command group; the functions decorated with ``@cli.command()`` in this file are its subcommands."""
    pass


def _load_prompt_dataset(ds_path: str, name: Optional[str], split: Optional[str]):
    """Load the prompt dataset from a ``save_to_disk`` directory or from a dataset name/path.

    An existing local directory is read with ``datasets.load_from_disk``.
    Anything else is first tried with ``datasets.load_dataset``; if that fails,
    ``load_from_disk`` is attempted as a fallback and, when both fail, the
    original ``load_dataset`` error is re-raised so the real cause is visible.
    A ``DatasetDict`` result is flattened by concatenating all of its splits.
    """
    if Path(ds_path).is_dir():
        ds = datasets.load_from_disk(ds_path)
    else:
        try:
            ds = datasets.load_dataset(
                ds_path, name=name, split=split, trust_remote_code=True
            )
        except Exception as load_err:
            try:
                ds = datasets.load_from_disk(ds_path)
            except Exception as disk_err:
                # Surface the primary failure instead of the fallback's error.
                raise load_err from disk_err

    if isinstance(ds, datasets.DatasetDict):
        ds = datasets.concatenate_datasets(list(ds.values()))
    return ds


def _select_subset(
    ds,
    max_samples: Optional[int],
    n_iter: Optional[int],
    world_size: Optional[int],
    local_rank: Optional[int],
):
    """Restrict ``ds`` to the requested chunk and to this worker's shard.

    Raises
    ------
    ValueError
        if ``n_iter`` is smaller than 1 or addresses a chunk past the end of the dataset
    """
    if max_samples is not None:
        chunk = min(max_samples, len(ds))
        if n_iter is None:
            ds = ds.select(range(chunk))
        else:
            start = chunk * (n_iter - 1)
            if n_iter < 1 or start >= len(ds):
                raise ValueError(
                    f"n_iter={n_iter} with max_samples={max_samples} is out of range "
                    f"for a dataset of {len(ds)} samples"
                )
            # Clamp the end so the last (possibly shorter) chunk stays in bounds.
            ds = ds.select(range(start, min(chunk * n_iter, len(ds))))

    if world_size is not None and local_rank is not None:
        shard_idxs = np.array_split(np.arange(len(ds)), world_size)[local_rank]
        # Every worker shuffles with the same seed, so the index shards taken
        # from the shuffled dataset are disjoint across workers.
        ds = ds.shuffle(seed=42).select(shard_idxs)
    return ds


def _to_chat_prompts(ds, tokenizer, prompt_column: str, is_preference_ds: bool):
    """Add ``prompt_dialg`` (single-turn dialog) and ``prompt`` (chat-templated text) columns.

    The prompt type is detected from the first row only. For preference
    datasets the first message of the ``chosen`` column is used and
    ``prompt_column`` is ignored.

    Raises
    ------
    ValueError
        if a plain-text prompt already starts with the tokenizer's BOS token,
        or the prompt column holds neither strings nor lists
    """

    def render(dialog):
        return {
            "prompt_dialg": dialog,
            "prompt": tokenizer.apply_chat_template(
                dialog, tokenize=False, add_generation_prompt=True
            ),
        }

    if is_preference_ds:
        print(
            "Converting preference dataset to chat dataset, overriding `prompt_column`"
        )
        return ds.map(lambda x: render([x["chosen"][0]]))

    first_prompt = ds[0][prompt_column]
    if isinstance(first_prompt, str):
        print(
            "Detected string (plain text) prompt, convert to standard dialog format"
        )
        bos_token = tokenizer.bos_token
        # Some tokenizers define no BOS token; only check when there is one.
        if bos_token and first_prompt.startswith(bos_token):
            raise ValueError(
                f"Prompts in column {prompt_column!r} already start with the BOS token "
                f"{bos_token!r}; expected raw text without a chat template applied"
            )
        return ds.map(
            lambda x: render([{"role": "user", "content": x[prompt_column]}])
        )
    if isinstance(first_prompt, list):
        print("Detected list prompt, apply chat template to ONLY first turn")
        return ds.map(lambda x: render([x[prompt_column][0]]))

    raise ValueError(f"Unsupported prompt type: {type(first_prompt)}")


def _resolve_save_name(
    ds_path: str,
    output_dir: str,
    output_dataname: Optional[str],
    world_size: Optional[int],
    local_rank: Optional[int],
) -> str:
    """Build the output dataset path, suffixed with ``_r{rank}_w{world}`` for sharded runs."""
    if output_dataname is not None:
        save_name = str(Path(output_dir) / f"{output_dataname}")
    else:
        # Path(...).name (unlike split("/")[-1]) also handles a trailing slash.
        save_name = str(
            Path(output_dir) / f"{Path(ds_path).name}_with_responses"
        )

    if world_size is not None and local_rank is not None:
        save_name += f"_r{local_rank}_w{world_size}"
    return save_name


@cli.command()
@click.argument("ds_path", type=str)
@click.argument("model_path", type=str)
@click.option(
    "--prompt_column",
    type=str,
    default="prompt",
    help="the name of the column containing the prompts",
)
@click.option(
    "--output_dir",
    type=str,
    default="generated_data",
    help="saving directory for generated data",
)
@click.option(
    "--output_dataname",
    type=str,
    default=None,
    help="the name for generated data",
)
@click.option(
    "--n_generation_per_prompt",
    type=int,
    default=5,
    help="the number of generations per prompt",
)
@click.option(
    "--max_new_tokens",
    type=int,
    default=2048,
    help="the maximum number of new tokens to generate per generation",
)
@click.option(
    "--do_sample",
    type=bool,
    default=True,
    help="whether to use sampling or greedy decoding",
)
@click.option(
    "--batch_size",
    type=int,
    default=8,
    help="the batch size for generation",
)
@click.option(
    "--repetition_penalty",
    type=float,
    default=1.5,
    help="the penalty for repeated tokens",
)
@click.option(
    "--temperature",
    type=float,
    default=1.0,
    help="the temperature for sampling",
)
@click.option(
    "--top_p",
    type=float,
    default=None,
    help="the top-p for sampling",
)
@click.option(
    "--top_k",
    type=int,
    default=None,
    help="the top-k for sampling",
)
@click.option(
    "--num_beams",
    type=int,
    default=1,
    help="the number of beams for beam search",
)
@click.option(
    "--max_samples",
    type=int,
    default=None,
    help="the maximum number of samples to generate, if None, generate all samples",
)
@click.option(
    "--world_size",
    type=int,
    default=None,
    help="the number of processes to use for generation, if None, use onlt 1",
)
@click.option(
    "--local_rank",
    type=int,
    default=None,
    help="the rank of the current process, if None, use only 1 process",
)
@click.option(
    "--n_iter",
    type=int,
    default=None,
    help="the number of iterations, i.e., the ds will be split into len(ds) / max_samples iters",
)
@click.option(
    "--is_preference_ds",
    type=bool,
    default=False,
    help="whether the dataset is a preference dataset",
)
@click.option(
    "--name",
    type=str,
    default=None,
    help="the ds name used in remote datasets",
)
@click.option(
    "--split",
    type=str,
    default=None,
    help="the split to use for generation, if None, use all splits",
)
@click.option(
    "--resume_from_ckpt",
    type=str,
    default=None,
    help="the path to the checkpoint to resume from",
)
@click.option(
    "--save_every",
    type=int,
    default=100,
    help="the number of samples to save per iteration",
)
def generate(
    ds_path: str,
    model_path: str,
    prompt_column: str = "prompt",
    output_dir: str = "generated_data",
    output_dataname: Optional[str] = None,
    n_generation_per_prompt: int = 5,
    max_new_tokens: int = 2048,
    do_sample: bool = True,
    batch_size: int = 8,
    repetition_penalty: float = 1.2,
    temperature: float = 1.0,
    top_p: Optional[float] = 0.95,
    top_k: Optional[int] = 50,
    num_beams: int = 5,
    max_samples: Optional[int] = None,
    world_size: Optional[int] = None,
    local_rank: Optional[int] = None,
    n_iter: Optional[int] = None,
    is_preference_ds: bool = False,
    name: Optional[str] = None,
    split: Optional[str] = None,
    resume_from_ckpt: Optional[str] = None,
    save_every: int = 100,
) -> None:
    """Generate responses to prompts in a dataset using a language model.

    The result is saved with ``save_to_disk`` and has the columns ``prompt``
    (chat-templated text), ``prompt_dialg`` (single-turn dialog) and
    ``responses`` (list of generated strings per prompt).

    Parameters
    ----------
    ds_path : str
        a local ``save_to_disk`` directory or a huggingface dataset name/path; it must
        contain ``prompt_column`` (or ``chosen`` when ``is_preference_ds`` is set)
    model_path : str
        a local or remote huggingface model path, must be a causal language model
    prompt_column : str, optional
        the name of the column containing the prompts (plain strings or
        lists of chat messages), by default "prompt"
    output_dir : str, optional
        saving directory for generated data, by default "generated_data"
    output_dataname : Optional[str], optional
        the name of the saved dataset inside ``output_dir``; if None,
        ``<last component of ds_path>_with_responses`` is used
    n_generation_per_prompt : int, optional
        the number of generations per prompt, by default 5
    max_new_tokens : int, optional
        the maximum number of new tokens to generate per generation, by default 2048
    do_sample : bool, optional
        whether to use sampling or greedy decoding, by default True
    batch_size : int, optional
        the batch size for generation, by default 8
    repetition_penalty : float, optional
        the penalty for repeated tokens, CLI default 1.5
    temperature : float, optional
        the temperature for sampling, by default 1.0
    top_p : Optional[float], optional
        the top-p for sampling, CLI default None
    top_k : Optional[int], optional
        the top-k for sampling, CLI default None
    num_beams : int, optional
        the number of beams for beam search, CLI default 1
    max_samples : Optional[int], optional
        the maximum number of samples to generate, if None, generate all samples, by default None
    world_size : Optional[int], optional
        the number of workers the dataset is sharded across; only used together with ``local_rank``
    local_rank : Optional[int], optional
        the index of the shard this worker processes; only used together with ``world_size``
    n_iter : Optional[int], optional
        1-based index of the ``max_samples``-sized chunk to process; only used together
        with ``max_samples``
    is_preference_ds : bool, optional
        if True, the prompt is the first message of the ``chosen`` column and
        ``prompt_column`` is ignored, by default False
    name : Optional[str], optional
        the configuration name passed to ``datasets.load_dataset``
    split : Optional[str], optional
        the split to load; if None, all splits are concatenated
    resume_from_ckpt : Optional[str], optional
        path of a previously saved output dataset; prompts already present in it are skipped
    save_every : int, optional
        save an intermediate checkpoint every this many prompts, by default 100;
        0 disables intermediate checkpoints

    Notes
    -----
    When invoked through the command line the click option defaults apply; for
    ``repetition_penalty``, ``top_p``, ``top_k`` and ``num_beams`` they differ from the
    defaults in the Python signature.

    For details on the other parameters, see
    https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
    """
    ds = _load_prompt_dataset(ds_path, name, split)
    ds = _select_subset(ds, max_samples, n_iter, world_size, local_rank)

    if resume_from_ckpt is not None:
        ds_resume = datasets.load_from_disk(resume_from_ckpt)
    else:
        ds_resume = None

    # Left padding keeps the generated continuation adjacent to the prompt in
    # batched decoding. dtype/device placement only apply to the model.
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="auto"
    ).eval()

    ds = _to_chat_prompts(ds, tokenizer, prompt_column, is_preference_ds)

    if ds_resume is not None:
        # Build the lookup once; membership tests against a set are O(1).
        done_prompts = set(ds_resume["prompt"])
        ds_to_annotate = ds.filter(
            lambda x: x["prompt"] not in done_prompts, num_proc=8
        )
    else:
        ds_to_annotate = ds

    # Prompts are deliberately NOT sorted by length: completions are matched
    # back to `prompt` / `prompt_dialg` purely by position.
    prompts = list(ds_to_annotate["prompt"])
    prompt_dialgs = list(ds_to_annotate["prompt_dialg"])

    if ds_resume is not None:
        print(
            f"Resuming from {len(ds_resume)} existing samples, {len(prompts)} to generate"
        )

    # `is None` rather than truthiness: 0 is a valid pad token id.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token

    default_kwargs = dict(
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        batch_size=batch_size,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        num_return_sequences=n_generation_per_prompt,
        eos_token_id=list({tokenizer.eos_token_id, tokenizer.pad_token_id}),
        early_stopping=True,
    )

    print(f"Model memory: {model.get_memory_footprint() / 1e9} GB")
    print(default_kwargs)

    response_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        **default_kwargs,
        trust_remote_code=True,
    )

    # Fail early on an unusable output directory instead of after generation.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    save_name = _resolve_save_name(
        ds_path, output_dir, output_dataname, world_size, local_rank
    )

    completions = []
    for output in tqdm(
        response_pipeline(
            ListDataset(prompts),
            return_full_text=False,
            pad_token_id=tokenizer.pad_token_id,
        ),
        desc=(
            f"generating responses for {len(prompts)} prompts (worker {local_rank}/{world_size})"
            if world_size is not None
            else f"generating responses for {len(prompts)} prompts"
        ),
        total=len(prompts),
    ):
        completions.append([item["generated_text"] for item in output])
        n_done = len(completions)
        if save_every and n_done % save_every == 0:
            # Intermediate checkpoints hold only the newly generated rows;
            # rows from `resume_from_ckpt` are merged in the final save below.
            checkpoint_ds = datasets.Dataset.from_dict(
                {
                    "prompt": prompts[:n_done],
                    "prompt_dialg": prompt_dialgs[:n_done],
                    "responses": completions,
                }
            )
            checkpoint_ds.save_to_disk(save_name)
            print(
                f"Save to {save_name} after {n_done} samples, ds now contains {len(checkpoint_ds)} samples."
            )

    output_ds = datasets.Dataset.from_dict(
        {
            "prompt": prompts,
            "prompt_dialg": prompt_dialgs,
            "responses": completions,
        }
    )

    # In sharded runs only rank 0 re-attaches the resumed rows, so they are
    # not duplicated once the shards are merged.
    if ds_resume is not None and (world_size is None or local_rank == 0):
        output_ds = datasets.concatenate_datasets([ds_resume, output_ds])

    output_ds.save_to_disk(save_name)


@cli.command()
@click.argument("split_ds_path", type=str)
@click.argument("output_ds_path", type=str)
def merge(split_ds_path: str, output_ds_path: str):
    """Merge multiple datasets into a single dataset.

    Every sub-directory of ``split_ds_path`` whose name ends in
    ``_r<rank>_w<world_size>`` is loaded, the shards are concatenated in
    (name prefix, world size, rank) order and saved to ``output_ds_path``.
    The shard directories are removed afterwards on a best-effort basis.

    Raises
    ------
    ValueError
        if no shard directory is found in ``split_ds_path``
    """
    shard_pattern = re.compile(r"_r(\d+)_w(\d+)$")

    keyed_dirs = []
    for p in Path(split_ds_path).iterdir():
        match = shard_pattern.search(p.name)
        if match is not None and p.is_dir():
            sort_key = (
                p.name[: match.start()],
                int(match.group(2)),
                int(match.group(1)),
            )
            keyed_dirs.append((sort_key, p))

    # iterdir() yields entries in arbitrary order; sort so that the merged
    # dataset has the same row order on every run and every filesystem.
    keyed_dirs.sort(key=lambda item: item[0])
    matched_dirs = [p for _, p in keyed_dirs]

    if not matched_dirs:
        raise ValueError(f"No split dataset found in {split_ds_path}")
    print(f"Found {len(matched_dirs)} split datasets in {split_ds_path}")

    ds_list = [datasets.load_from_disk(str(p)) for p in matched_dirs]
    merged_ds = datasets.concatenate_datasets(ds_list)
    merged_ds.save_to_disk(output_ds_path)

    # remove the split datasets; a failure here must not fail the merge
    for p in matched_dirs:
        try:
            shutil.rmtree(str(p))
        except OSError as e:
            print(f"An error occurred: {e}")
            print(f"Failed to remove {p}, please remove it manually")


@cli.command()
@click.argument("ds_path", type=str)
@click.argument("output_path", type=str)
@click.option(
    "--format_to_alpaca",
    type=bool,
    default=False,
    help="whether to convert the dataset to alpaca format",
)
@click.option(
    "--lines",
    type=bool,
    default=False,
    help="whether to save the dataset as a single line per example",
)
def convert_to_json(
    ds_path: str,
    output_path: str,
    format_to_alpaca: bool = False,
    lines: bool = False,
):
    """Convert a dataset to a json file.

    Parameters
    ----------
    ds_path : str
        path of a dataset saved with ``save_to_disk``
    output_path : str
        destination file; ``.json`` is appended when missing
    format_to_alpaca : bool, optional
        if True, derive ``instruction`` (content of the first ``prompt_dialg``
        message), ``output`` (first entry of ``responses``) and ``generator``
        (last component of ``ds_path``) and export only these columns
    lines : bool, optional
        forwarded to ``Dataset.to_json`` as ``lines``, by default False

    Without ``format_to_alpaca`` the export is reduced to the alpaca columns
    only when the dataset already has all of them; otherwise every column is
    written.
    """
    alpaca_columns = ["instruction", "output", "generator"]
    ds = datasets.load_from_disk(ds_path)

    if format_to_alpaca:
        # Path(...).name (unlike split("/")[-1]) also handles a trailing slash.
        generator_name = Path(ds_path).name
        ds = ds.map(
            lambda x: {
                "instruction": x["prompt_dialg"][0]["content"],
                "output": x["responses"][0],
                "generator": generator_name,
            }
        )
        ds = ds.select_columns(alpaca_columns)
    elif all(column in ds.column_names for column in alpaca_columns):
        ds = ds.select_columns(alpaca_columns)

    if not output_path.endswith(".json"):
        output_path += ".json"
    ds.to_json(output_path, lines=lines)
    print(f"Dataset saved to {output_path}")


# Run the click command group only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()
