import json
import os
import re
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import click
import datasets
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    pipeline,
)


def _read_json_records(path: str) -> pd.DataFrame:
    """Load a generations file stored either as a JSON document or as JSON Lines.

    ``pd.read_json`` without ``lines=True`` raises ``ValueError`` on
    line-delimited input, so only that error triggers the JSON Lines
    fallback. Anything else (e.g. ``KeyboardInterrupt``) propagates.
    """
    try:
        return pd.read_json(path)
    except ValueError:
        return pd.read_json(path, lines=True)


def _format_for_rm(tokenizer, instruction: str, response: str, bos_token: str) -> str:
    """Render one user/assistant exchange as text using the RM's chat template.

    Every occurrence of ``bos_token`` is removed from the rendered text.
    NOTE(review): presumably this avoids a doubled BOS once the pipeline's
    tokenizer adds its own; confirm against the RM's tokenizer config.
    """
    rendered = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": instruction},
            {"role": "assistant", "content": response},
        ],
        tokenize=False,
    )
    return rendered.replace(bos_token, "") if bos_token else rendered


def _rm_score(rm_pipeline, text: str, pipe_kwargs: dict):
    """Return the reward-model score for a single formatted text."""
    # A one-element list input yields one list of label/score dicts per text;
    # the first entry's "score" is the scalar reward.
    return rm_pipeline([text], **pipe_kwargs)[0][0]["score"]


@click.command()
@click.argument("policy_gen_json_path", type=str)
@click.argument("ref_gen_json_path", type=str)
@click.argument("rm_path", type=str)
def main(policy_gen_json_path: str, ref_gen_json_path: str, rm_path: str):
    """Score policy vs. reference generations with a reward model.

    Both generation files (JSON or JSON Lines) must contain "instruction"
    and "output" columns. Rows are paired by exact instruction match
    (inner join). Each response is scored by the reward model at RM_PATH.
    The fraction of pairs where the policy scores strictly higher is
    printed (ties count as losses). Per-pair scores are written to
    merged.jsonl in the policy file's directory.
    """
    policy_df = _read_json_records(policy_gen_json_path)
    ref_df = _read_json_records(ref_gen_json_path)

    # An inner join pairs every duplicate combination, which silently
    # over-weights repeated instructions in the win rate.
    for label, frame in (("policy", policy_df), ("reference", ref_df)):
        n_dup = int(frame["instruction"].duplicated().sum())
        if n_dup:
            click.echo(
                f"Warning: {n_dup} duplicate instruction(s) in the {label} file; "
                "the join will pair every duplicate combination.",
                err=True,
            )

    merged_df = pd.merge(
        policy_df,
        ref_df,
        on="instruction",
        how="inner",
        suffixes=("_policy", "_ref"),
    )
    if merged_df.empty:
        # Without this guard the empty result crashes later with an opaque
        # KeyError, and the win rate would be 0/0.
        raise click.ClickException(
            "The policy and reference files share no instructions; nothing to compare."
        )

    # load RM
    rm_tokenizer = AutoTokenizer.from_pretrained(rm_path)
    rm_model = AutoModelForSequenceClassification.from_pretrained(
        rm_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="eager",
        num_labels=1,
    )
    rm_pipeline = pipeline(
        "sentiment-analysis",
        model=rm_model,
        tokenizer=rm_tokenizer,
        device_map="auto",
        model_kwargs={"torch_dtype": torch.bfloat16},
        truncation=True,
    )
    # Raw (un-normalized) scalar scores, one text per forward pass.
    pipe_kwargs = {
        "top_k": None,
        "function_to_apply": "none",
        "batch_size": 1,
    }

    bos_token = rm_tokenizer.bos_token or ""

    records: List[dict] = []
    rows = zip(
        merged_df["instruction"],
        merged_df["output_policy"],
        merged_df["output_ref"],
    )
    for instruction, output_policy, output_ref in tqdm(
        rows,
        total=len(merged_df),
        desc="Annotating local eval ds",
    ):
        policy_text = _format_for_rm(rm_tokenizer, instruction, output_policy, bos_token)
        ref_text = _format_for_rm(rm_tokenizer, instruction, output_ref, bos_token)
        records.append(
            {
                "instruction": instruction,
                "output_policy": output_policy,
                "output_ref": output_ref,
                "rm_score_policy": _rm_score(rm_pipeline, policy_text, pipe_kwargs),
                "rm_score_ref": _rm_score(rm_pipeline, ref_text, pipe_kwargs),
            }
        )

    processed_ds = datasets.Dataset.from_list(records)

    # print the win rate (strict comparison: ties count as policy losses)
    policy_scores = np.asarray(processed_ds["rm_score_policy"])
    ref_scores = np.asarray(processed_ds["rm_score_ref"])
    win_rate = float(np.mean(policy_scores > ref_scores))
    print(f"Win rate: {win_rate:.4f}")

    # save the processed dataset next to the policy generations file
    processed_ds.to_json(
        str(Path(policy_gen_json_path).parent / "merged.jsonl"),
        orient="records",
        lines=True,
    )


if __name__ == "__main__":
    # Script entry point: click parses the three positional CLI arguments
    # (policy_gen_json_path, ref_gen_json_path, rm_path) declared on `main`.
    main()
