import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import click
import datasets
import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, pipeline


def _load_split(ds_path: str, split: str) -> datasets.Dataset:
    """Load ``split`` of ``ds_path``, trying ``load_dataset`` first and ``load_from_disk`` second.

    ``load_from_disk`` returns a ``DatasetDict`` when a whole dict of splits was saved;
    in that case ``split`` is selected so the caller always receives a ``Dataset``.
    """
    try:
        return load_dataset(ds_path, split=split)
    except Exception:
        # Not loadable via `load_dataset` (e.g. a `save_to_disk` directory); fall back.
        # `Exception` rather than a bare `except:` so Ctrl-C / SystemExit still propagate.
        loaded = load_from_disk(ds_path)
    if isinstance(loaded, datasets.DatasetDict):
        loaded = loaded[split]
    return loaded


def _score_responses(rm_pipeline, prompt: str, responses, bos_token: str, pipe_kwargs: dict) -> float:
    """Return the mean reward-model score of ``prompt + response`` over all ``responses``.

    Any occurrence of ``bos_token`` is stripped from each text before scoring. Returns NaN
    (without calling the pipeline) when ``responses`` is empty.
    """
    texts = [prompt + response for response in responses]
    if bos_token:
        texts = [text.replace(bos_token, "") for text in texts]
    if not texts:
        return float("nan")
    # One call for all responses of this prompt. With `top_k=None` and a list input the
    # pipeline returns one list of {label, score} dicts per input text; take the first entry,
    # exactly as the previous per-response `[...][0][0]["score"]` indexing did.
    outputs = rm_pipeline(texts, **pipe_kwargs)
    return float(np.mean([output[0]["score"] for output in outputs]))


def _resolve_output_path(output_dir: str, src_fname: str) -> str:
    """Return the path to save to, creating ``output_dir`` if it does not exist yet.

    - existing directory -> ``<output_dir>/<src_fname>_avg_reward``
    - missing path       -> ``output_dir`` itself (created, parents included)
    - existing non-directory -> ``FileExistsError`` from ``Path.mkdir``
    """
    out = Path(output_dir)
    if out.is_dir():
        return str(out / f"{src_fname}_avg_reward")
    out.mkdir(parents=True, exist_ok=True)
    return output_dir


@click.command()
@click.option(
    "--ds_path",
    type=str,
    required=True,
    help="the path to datasets.Dataset. either local or remote is supported. the ds should be in DPO format, i.e., containing 'prompt' and 'responses' columns. rows in standard preference dataset format, i.e., containing 'chosen' and 'rejected' columns, are not supported and are skipped with a warning.",
)
@click.option(
    "--rm_path",
    type=str,
    required=True,
    help="the path to the reward model. either local or remote is supported. the model should be compatible with the Hugging Face Transformers library.",
)
@click.option(
    "--output_dir",
    type=str,
    required=True,
    help="the output directory / path to save the reward scores",
)
@click.option(
    "--split",
    type=str,
    default="train",
    help="split of the dataset to be annotated, by default 'train'",
)
def main(
    ds_path: str,
    rm_path: str,
    output_dir: str,
    split: str = "train",
) -> datasets.Dataset:
    """generate average reward score for each prompt of the given dataset using the reward model.

    Parameters
    ----------
    ds_path : str
        the path to datasets.Dataset. either local or remote is supported. the ds should be in DPO format, i.e., containing "prompt" and "responses" columns. rows in standard preference format ("chosen" and "rejected" columns) are skipped with a warning.
    rm_path : str
        the path to the reward model. either local or remote is supported. the model should be compatible with the Hugging Face Transformers library.
    output_dir : str
        the output directory / path to save the reward scores. if it is an existing directory, the result is saved to "<output_dir>/<ds name>_avg_reward"; if it does not exist, it is created and the result is saved into it.
    split : str, optional
        split of the dataset to be annotated, by default "train"

    Returns
    -------
    datasets.Dataset
        the annotated dataset with "average_score" column added and "responses" column removed.

    Raises
    ------
    NotADirectoryError
        if output_dir exists but is not a directory (checked before any annotation work).
    ValueError
        if a row has neither ("chosen", "rejected") nor ("prompt", "responses") columns.

    NOTE: the DPO trainer would perform tokenization in its `__init__`,
    so we do not need to include the `input_ids` and `attention_mask` columns in the dataset.
    """
    # Fail fast: discovering an unusable output path only after annotating the whole
    # dataset would waste all of the reward-model compute.
    if Path(output_dir).exists() and not Path(output_dir).is_dir():
        raise NotADirectoryError(
            f"output_dir {output_dir} exists but is not a directory"
        )

    ds = _load_split(ds_path, split)

    rm_tokenizer = AutoTokenizer.from_pretrained(rm_path)
    rm_pipeline = pipeline(
        "sentiment-analysis",
        model=rm_path,
        tokenizer=rm_tokenizer,
        device_map="auto",
        model_kwargs={"torch_dtype": torch.bfloat16},
        truncation=True,
    )
    pipe_kwargs = {
        "top_k": None,
        "function_to_apply": "none",  # raw reward logits, no softmax/sigmoid
        "batch_size": 1,
    }

    # Stripped from every text before scoring; the tokenizer presumably re-adds it itself.
    bos_token = rm_tokenizer.bos_token or ""

    ds_processed = []
    warned_unsupported = False
    for item in tqdm(ds, desc=f"Annotating dataset {ds_path} / split {split}"):
        if "chosen" in item and "rejected" in item:
            # Preference-pair rows are skipped; warn once instead of once per row.
            if not warned_unsupported:
                print("[WARNING] Data format not supported ('chosen'/'rejected'); skipping such rows")
                warned_unsupported = True
            continue

        if "prompt" not in item or "responses" not in item:
            raise ValueError(
                f"Dataset {ds_path} / split {split} should contain 'chosen' and 'rejected' columns or 'prompt' and 'responses' columns."
            )

        item["average_score"] = _score_responses(
            rm_pipeline, item["prompt"], item["responses"], bos_token, pipe_kwargs
        )
        item.pop("responses")
        ds_processed.append(item)

    ds_to_save = datasets.Dataset.from_list(ds_processed)
    output_path_str = _resolve_output_path(output_dir, Path(ds_path).name)
    ds_to_save.save_to_disk(output_path_str)

    print(f"Average reward saved to {output_path_str}")
    # Ignored by click in standalone mode; useful when invoked with standalone_mode=False.
    return ds_to_save


if __name__ == "__main__":
    # Script entry point: click parses sys.argv, runs the command and exits the process.
    main()
