import json
import os
import re
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import click
import datasets
import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    pipeline,
)


@click.group()
def cli():
    """Annotate preference datasets with reward-model scores (`annotate`) and merge
    sharded annotation outputs (`merge`)."""
    pass


def _load_split(ds_path: str, split: str) -> datasets.Dataset:
    """Load one split of a dataset, trying ``load_dataset`` first and disk second.

    Parameters
    ----------
    ds_path : str
        a name/path understood by ``datasets.load_dataset``, or a directory
        written by ``save_to_disk``.
    split : str
        split to load. For on-disk data it is applied only when the saved object
        is a ``DatasetDict``; a saved single ``Dataset`` is returned as is.

    Returns
    -------
    datasets.Dataset
    """
    try:
        ds = load_dataset(ds_path, split=split)
    except Exception:
        # Not loadable via `load_dataset`; treat it as a `save_to_disk` directory.
        # (A bare `except:` would also swallow KeyboardInterrupt / SystemExit.)
        ds = load_from_disk(ds_path)
    if isinstance(ds, datasets.DatasetDict):
        # `load_from_disk` ignores `split`; without this, len(ds) would count the
        # splits and `ds[i]` would raise KeyError.
        ds = ds[split]
    return ds


def _resolve_output_path(
    output_dir: str,
    ds_path: str,
    split: str,
    local_rank: Optional[int],
    world_size: Optional[int],
) -> str:
    """Decide (and create, if needed) where `annotate` saves its result.

    - distributed run (``local_rank`` and ``world_size`` both given): always
      ``<output_dir>/<src>_annotated_<split>_r<local_rank>_w<world_size>``, with
      ``output_dir`` created if missing. Every shard therefore lands inside
      ``output_dir``, no matter which worker happens to create it first.
    - single run, ``output_dir`` is an existing directory:
      ``<output_dir>/<src>_annotated_<split>``.
    - single run, ``output_dir`` does not exist: it is created and used directly.

    ``<src>`` is the last path component of ``ds_path``.

    Raises
    ------
    FileExistsError
        if ``output_dir`` exists but is not a directory.
    """
    out_dir = Path(output_dir)
    if out_dir.exists() and not out_dir.is_dir():
        raise FileExistsError(
            f"output_dir {output_dir!r} exists and is not a directory"
        )
    src_fname = Path(ds_path).name
    if world_size is not None and local_rank is not None:
        out_dir.mkdir(parents=True, exist_ok=True)
        return str(
            out_dir / f"{src_fname}_annotated_{split}_r{local_rank}_w{world_size}"
        )
    if out_dir.is_dir():
        return str(out_dir / f"{src_fname}_annotated_{split}")
    out_dir.mkdir(parents=True, exist_ok=True)
    return output_dir


def _render_chat(tokenizer, messages: list, bos_token: str) -> str:
    """Render ``messages`` to text with ``tokenizer``'s chat template, removing
    every occurrence of ``bos_token`` (a no-op when ``bos_token`` is "")."""
    # NOTE(review): BOS is presumably stripped because the pipeline's tokenizer
    # adds its own BOS when encoding the text — confirm per reward model.
    return tokenizer.apply_chat_template(messages, tokenize=False).replace(
        bos_token, ""
    )


@cli.command()
@click.option(
    "--ds_path",
    type=str,
    required=True,
    help="the path to datasets.Dataset. either local or remote is supported. the ds should be in standard preference dataset format, i.e., containing 'chosen' and 'rejected' columns which are lists of conversations, or in DPO format, i.e., containing 'prompt_dialg' and 'responses' columns.",
)
@click.option(
    "--rm_path",
    type=str,
    required=True,
    help="the path to the reward model. either local or remote is supported. the model should be compatible with the Hugging Face Transformers library.",
)
@click.option(
    "--output_dir",
    type=str,
    default="rm_annotated_data",
    help="the output directory / path to save the annotated dataset, by default 'rm_annotated_data'",
)
@click.option(
    "--split",
    type=str,
    default="train",
    help="split of the dataset to be annotated, by default 'train'",
)
@click.option(
    "--local_rank",
    type=int,
    default=None,
    help="local rank for distributed running on gpus, by default None",
)
@click.option(
    "--world_size",
    type=int,
    default=None,
    help="world size for distributed running on gpus, by default None",
)
def annotate(
    ds_path: str,
    rm_path: str,
    output_dir: str = "rm_annotated_data",
    split: str = "train",
    local_rank: Optional[int] = None,
    world_size: Optional[int] = None,
) -> datasets.Dataset:
    """generate reward scores for the given dataset using reward model.

    Parameters
    ----------
    ds_path : str
        the path to datasets.Dataset. either local or remote is supported. each row is handled as follows:
        - "chosen"/"rejected" are strings: each is appended as an assistant turn to the "prompt_dialg" messages.
        - "chosen"/"rejected" are one-message lists: each is appended to "prompt_dialg"; the columns are
          replaced by that message's content.
        - "chosen"/"rejected" are longer conversations: the first two messages are scored; the columns are
          replaced by the second message's content.
        - otherwise (DPO format, "prompt_dialg" + "responses"): every response is scored; the best becomes
          "chosen", the worst "rejected", and all scores are stored in "annot_scores".
    rm_path : str
        the path to the reward model. either local or remote is supported. the model should be compatible with the Hugging Face Transformers library.
    output_dir : str
        the output directory / path to save the annotated dataset, by default "rm_annotated_data".
        in distributed mode every shard is saved as
        "<output_dir>/<ds name>_annotated_<split>_r<local_rank>_w<world_size>".
    split : str, optional
        split of the dataset to be annotated, by default "train"
    local_rank : int, optional
        index of this worker's contiguous shard; must be given together with `world_size`.
    world_size : int, optional
        total number of shards; must be given together with `local_rank`.

    Returns
    -------
    datasets.Dataset
        the annotated dataset with "chosen_score" and "rejected_score" columns added.

    Raises
    ------
    click.UsageError
        if only one of `local_rank` / `world_size` is given, or `local_rank` is out of range.
    FileExistsError
        if `output_dir` exists and is not a directory.
    TypeError
        if a row's "chosen" is neither a string nor a list.
    ValueError
        if a DPO-format row has an empty "responses" list.

    NOTE: the DPO trainer would perform tokenization in its `__init__`,
    so we do not need to include the `input_ids` and `attention_mask` columns in the dataset.
    """
    if (local_rank is None) != (world_size is None):
        raise click.UsageError(
            "--local_rank and --world_size must be given together"
        )
    distributed = world_size is not None
    if distributed and not 0 <= local_rank < world_size:
        raise click.UsageError(
            f"--local_rank must be in [0, {world_size}), got {local_rank}"
        )

    ds = _load_split(ds_path, split)
    if distributed:
        # each worker takes one contiguous slice of the row indices
        selected_idxs = np.array_split(np.arange(len(ds)), world_size)[local_rank]
        ds = ds.select(selected_idxs)

    # Resolve the output location before the expensive annotation loop so a bad
    # `output_dir` fails immediately instead of after all scoring is done.
    output_path_str = _resolve_output_path(
        output_dir, ds_path, split, local_rank, world_size
    )

    rm_tokenizer = AutoTokenizer.from_pretrained(rm_path)
    rm_model = AutoModelForSequenceClassification.from_pretrained(
        rm_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="eager",
        num_labels=1,
    )
    rm_pipeline = pipeline(
        "sentiment-analysis",
        model=rm_model,
        tokenizer=rm_tokenizer,
        device_map="auto",
        model_kwargs={"torch_dtype": torch.bfloat16},
        truncation=True,
    )
    pipe_kwargs = {
        "top_k": None,
        "function_to_apply": "none",
        "batch_size": 1,
    }

    bos_token = (
        rm_tokenizer.bos_token if rm_tokenizer.bos_token is not None else ""
    )

    def render(messages: list) -> str:
        return _render_chat(rm_tokenizer, messages, bos_token)

    def score(text: str) -> float:
        # single-label head (num_labels=1) with function_to_apply="none":
        # the raw, un-activated model output is used as the reward
        return rm_pipeline(text, **pipe_kwargs)[0]["score"]

    desc = f"Annotating dataset {ds_path} / split {split}"
    if distributed:
        desc += f" (worker {local_rank}/{world_size})"

    ds_processed = []
    for i in tqdm(range(len(ds)), desc=desc):
        item = ds[i]
        if "chosen" in item and "rejected" in item and "responses" not in item:
            chosen, rejected = item["chosen"], item["rejected"]
            if isinstance(chosen, str):
                chosen_text = render(
                    item["prompt_dialg"] + [{"role": "assistant", "content": chosen}]
                )
                rejected_text = render(
                    item["prompt_dialg"]
                    + [{"role": "assistant", "content": rejected}]
                )
            elif isinstance(chosen, list) and len(chosen) == 1:
                chosen_text = render(item["prompt_dialg"] + chosen)
                rejected_text = render(item["prompt_dialg"] + rejected)
                item.update(
                    {
                        "chosen": chosen[0]["content"],
                        "rejected": rejected[0]["content"],
                    }
                )
            elif isinstance(chosen, list):
                # full conversation: score only the first (prompt, response) pair
                chosen_text = render(chosen[:2])
                rejected_text = render(rejected[:2])
                item.update(
                    {
                        "chosen": chosen[1]["content"],
                        "rejected": rejected[1]["content"],
                    }
                )
            else:
                # Without this, the texts from the previous row would be reused
                # silently (or be unbound on the first row).
                raise TypeError(
                    f"item {i}: unsupported type for 'chosen': {type(chosen).__name__}"
                )

            item.update(
                {
                    "chosen_score": score(chosen_text),
                    "rejected_score": score(rejected_text),
                }
            )
        else:
            responses = item["responses"]
            if not responses:
                raise ValueError(
                    f"item {i} has an empty 'responses' list; nothing to score"
                )
            scores = [
                score(
                    render(
                        item["prompt_dialg"]
                        + [{"role": "assistant", "content": response}]
                    )
                )
                for response in responses
            ]
            # argmax/argmin pick the first occurrence, so if all scores tie the
            # chosen and rejected responses are the same one.
            chosen_idx, rejected_idx = np.argmax(scores), np.argmin(scores)
            item.update(
                {
                    "chosen_score": scores[chosen_idx],
                    "rejected_score": scores[rejected_idx],
                    "chosen": responses[chosen_idx],
                    "rejected": responses[rejected_idx],
                    "annot_scores": scores,
                }
            )
        ds_processed.append(item)

    ds_to_save = datasets.Dataset.from_list(ds_processed)
    ds_to_save.save_to_disk(output_path_str)

    print(f"Annotated dataset saved to {output_path_str}")
    return ds_to_save


@cli.command()
@click.argument("split_ds_path", type=str)
@click.argument("output_ds_path", type=str)
def merge(split_ds_path: str, output_ds_path: str):
    """Merge multiple datasets into a single dataset.

    Every sub-directory of SPLIT_DS_PATH whose name looks like
    "<...annotated...>_r<rank>_w<world_size>" is loaded with `load_from_disk`.
    The shards are concatenated in (name prefix, world size, rank) order, so
    contiguous rank shards come back in their original row order. The result
    is saved to OUTPUT_DS_PATH, and the shard directories are then removed on
    a best-effort basis.
    """
    # Match on the directory *name* only: the parent path itself (e.g. the default
    # "rm_annotated_data") may contain "annotated" and would match anything.
    shard_re = re.compile(
        r"^(?P<prefix>.*annotated.*)_r(?P<rank>\d+)_w(?P<world>\d+)$"
    )
    shards = []
    for p in Path(split_ds_path).iterdir():
        m = shard_re.match(p.name)
        if m and p.is_dir():
            shards.append((m["prefix"], int(m["world"]), int(m["rank"]), p))
    if not shards:
        raise ValueError(f"No split dataset found in {split_ds_path}")
    print(f"Found {len(shards)} split datasets in {split_ds_path}")

    # iterdir() order is arbitrary; sort so the concatenation is deterministic
    # and follows rank order within each source.
    shards.sort(key=lambda s: s[:3])

    ranks_seen = {}
    for prefix, world, rank, _ in shards:
        ranks_seen.setdefault((prefix, world), set()).add(rank)
    for (prefix, world), ranks in ranks_seen.items():
        missing = sorted(set(range(world)) - ranks)
        if missing:
            print(
                f"Warning: {prefix} (world size {world}) is missing ranks {missing}; "
                "the merged dataset will be incomplete"
            )

    ds_list = [datasets.load_from_disk(str(p)) for *_, p in shards]
    merged_ds = datasets.concatenate_datasets(ds_list)
    merged_ds.save_to_disk(output_ds_path)
    # remove the split datasets (best effort: failures are reported, not raised)
    for *_, p in shards:
        try:
            shutil.rmtree(str(p))
        except OSError as e:
            print(f"An error occurred: {e}")
            print(f"Failed to remove {p}, please remove it manually")


if __name__ == "__main__":
    # Script entry point: dispatch to the click group's subcommands (annotate / merge).
    cli()
