from abc import ABC, abstractmethod
from utils import read_json
import random
import json
from tqdm import tqdm
import os
from transformers import AutoTokenizer, pipeline, AutoModelForSequenceClassification
import torch


class BaseRM(ABC):
    """Abstract reward model that scores single responses and compares pairs."""

    @abstractmethod
    def score(
        self,
        inputs: list,
        cache_dir: str | None = None,
        verbose: bool = False,
        use_tqdm: bool = True,
        **kwargs,
    ) -> list[dict]:
        """
        Scores the inputs.

        Args:
            inputs (list[dict]): List of inputs. ``pairwise_eval`` passes dicts
                with the keys ``"instruction"`` and ``"output"``.
            cache_dir (str | None, optional): Path to the cache directory. Defaults to None.
            verbose (bool, optional): Whether to print the progress. Defaults to False.
            use_tqdm (bool, optional): Whether to use tqdm. Defaults to True.

        Returns:
            list[dict]: One result per input, in input order. ``pairwise_eval``
                reads the ``"score"`` key of every result.
        """
        pass

    def pairwise_eval(
        self,
        input_dir: str,
        output_dir: str,
        cache_dir: str | None = None,
        verbose: bool = False,
        use_tqdm: bool = True,
        **kwargs,
    ) -> tuple[int, list[int]]:
        """
        Evaluates the model on the pairwise dataset.

        Every distinct (instruction, output) pair is scored exactly once via
        ``self.score``; each record is then decided by comparing the scores of
        its two outputs. Ties are broken at random with a fixed seed.

        Args:
            input_dir (str): Path handed to ``read_json``. The loaded records are
                read through the keys ``"instruction"``, ``"output_1"`` and
                ``"output_2"``.
            output_dir (str): Path of the result file. It is passed directly to
                ``open``, so despite the name it must be a file path. One JSON
                object is written per line.
            cache_dir (str | None, optional): Path to the cache directory,
                forwarded to ``score``. Defaults to None.
            verbose (bool, optional): Whether to print the progress. Defaults to False.
            use_tqdm (bool, optional): Whether to use tqdm. Defaults to True.
            **kwargs: Forwarded unchanged to ``score``.

        Returns:
            tuple[int, list[int]]: The number of ties and the list of winners
                (1 or 2), one per record in input order.

        Raises:
            ValueError: If ``score`` does not return one result per input.
        """
        data = read_json(input_dir)

        # Gather every distinct (instruction, output) pair once, in first-seen
        # order. Tuple keys are used instead of a joined string so that pairs
        # such as ("a#b", "c") and ("a", "b#c") cannot collide.
        inputs: list[dict] = []
        input_map: dict[tuple, int] = {}
        for d in data:
            for field in ("output_1", "output_2"):
                key = (d["instruction"], d[field])
                if key not in input_map:
                    input_map[key] = len(inputs)
                    inputs.append({"instruction": key[0], "output": key[1]})

        if verbose:
            print(f"Scoring {len(inputs)} inputs")

        # score the inputs
        raw_results = self.score(
            inputs, cache_dir=cache_dir, verbose=verbose, use_tqdm=use_tqdm, **kwargs
        )
        if len(raw_results) != len(inputs):
            # A length mismatch would silently misalign scores and inputs.
            raise ValueError(
                f"score() returned {len(raw_results)} results for {len(inputs)} inputs"
            )

        # convert the scores to pairwise comparison
        results = []
        winners = []
        ties = 0
        # Private generator: same reproducible tie-breaking sequence as seeding
        # the module-level RNG with 42, without resetting global random state.
        rng = random.Random(42)
        for d in data:
            raw_result1 = raw_results[input_map[(d["instruction"], d["output_1"])]]
            raw_result2 = raw_results[input_map[(d["instruction"], d["output_2"])]]
            score1 = raw_result1["score"]
            score2 = raw_result2["score"]
            is_tie = False
            if score1 > score2:
                winner = 1
            elif score1 < score2:
                winner = 2
            else:
                # Neither score is strictly larger (equal, or unordered such as
                # NaN): pick at random and flag it, so the "is_tie" field always
                # agrees with the returned tie count.
                winner = rng.choice([1, 2])
                is_tie = True
                ties += 1
            results.append(
                {
                    "raw_result_1": raw_result1,
                    "raw_result_2": raw_result2,
                    "winner": winner,
                    "is_tie": is_tie,
                }
            )
            winners.append(winner)

        # save the results
        with open(output_dir, "w") as f:
            for result in results:
                f.write(json.dumps(result) + "\n")
        return ties, winners


class BaseRMHF(BaseRM):
    """Base class for reward models that are scored batch by batch."""

    def __init__(
        self, model_pt: str, batch_size: int = 1, device: str = "cuda"
    ) -> None:
        """
        Initializes the BaseRMHF class.

        Only stores the configuration; no model is loaded here.

        Args:
            model_pt (str): The path to the pre-trained model.
            batch_size (int, optional): The batch size. Defaults to 1.
            device (str, optional): The device to use for inference. Defaults to "cuda".
        """
        self.model_pt = model_pt
        self.device = device
        self.batch_size = batch_size

    @abstractmethod
    def score_batch(self, inputs: list[dict], **kwargs) -> list[dict]:
        """
        Scores one batch of inputs.

        Args:
            inputs (list[dict]): List of inputs.

        Returns:
            list[dict]: List of dictionaries containing the results
        """
        pass

    def score(
        self,
        inputs: list,
        cache_dir: str | None = None,
        verbose: bool = False,
        use_tqdm: bool = True,
        **kwargs,
    ) -> list[dict]:
        """
        Scores the inputs in chunks of ``self.batch_size``.

        Args:
            inputs (list[dict]): List of inputs.
            cache_dir (str | None, optional): Path to the cache directory. When
                given, it is created if missing and every (input, output) pair
                is appended as one JSON line to ``rm.jsonl`` inside it. The file
                is truncated on each call and is not read back by this method.
                Defaults to None.
            verbose (bool, optional): Accepted for interface compatibility; not
                used by this implementation. Defaults to False.
            use_tqdm (bool, optional): Whether to use tqdm. Defaults to True.
            **kwargs: Forwarded unchanged to ``score_batch``.

        Returns:
            list[dict]: List of dictionaries containing the results

        Raises:
            ValueError: If ``self.batch_size`` is smaller than 1.
        """
        if self.batch_size < 1:
            # A negative step would make range() empty and silently score nothing.
            raise ValueError(
                f"batch_size must be a positive integer, got {self.batch_size}"
            )

        outputs = []
        cache_file = None
        if cache_dir is not None:
            os.makedirs(cache_dir, exist_ok=True)
            cache_file = open(os.path.join(cache_dir, "rm.jsonl"), "w")
        try:
            for start in tqdm(
                range(0, len(inputs), self.batch_size),
                desc="Scoring",
                disable=not use_tqdm,
            ):
                batch_outputs = self.score_batch(
                    inputs[start : start + self.batch_size], **kwargs
                )
                outputs.extend(batch_outputs)
                if cache_file is not None:
                    for offset, out in enumerate(batch_outputs):
                        # Flushed per line so finished batches survive a crash.
                        print(
                            json.dumps(
                                {
                                    "input": inputs[start + offset],
                                    "output": out,
                                }
                            ),
                            file=cache_file,
                            flush=True,
                        )
        finally:
            # Close the cache file even when score_batch raises.
            if cache_file is not None:
                cache_file.close()
        return outputs


class OffsetBiasRM(BaseRMHF):
    """Reward model scored through a Hugging Face ``pipeline``."""

    def __init__(
        self, model_pt: str, batch_size: int = 1, device: str = "cuda"
    ) -> None:
        """
        Initializes the OffsetBias RM object for calling Hugging Face pipelines.

        Args:
            model_pt (str): The path to the pre-trained model.
            batch_size (int, optional): The batch size. Defaults to 1.
            device (str, optional): The device to use for inference. Defaults to "cuda".
        """
        super().__init__(model_pt, batch_size, device)
        rm_tokenizer = AutoTokenizer.from_pretrained(model_pt)
        self.rm_pipe = pipeline(
            "sentiment-analysis",
            model=model_pt,
            device=device,
            tokenizer=rm_tokenizer,
            model_kwargs={"torch_dtype": torch.bfloat16},
        )
        self.tokenizer = rm_tokenizer

    @staticmethod
    def make_input(input: dict) -> list[dict]:
        """
        Builds the two-turn chat for one input.

        Args:
            input (dict): Dict with the keys ``"instruction"`` and ``"output"``.
                The parameter name is kept for backward compatibility.

        Returns:
            list[dict]: A user message holding the instruction followed by an
                assistant message holding the output.
        """
        return [
            {"role": "user", "content": input["instruction"]},
            {"role": "assistant", "content": input["output"]},
        ]

    def score_batch(self, inputs: list[dict], **kwargs) -> list[dict]:
        """
        Scores the inputs.

        Args:
            inputs (list[dict]): List of inputs with the keys ``"instruction"``
                and ``"output"``.

        Returns:
            list[dict]: One ``{"score": ...}`` dict per input, in input order.
        """
        if not inputs:
            # Nothing to score; also avoids handing batch_size=0 to the pipeline.
            return []

        pipe_kwargs = {
            "return_all_scores": True,
            "function_to_apply": "none",
            "batch_size": len(inputs),
        }

        bos_token = self.tokenizer.bos_token
        test_texts = []
        for item in inputs:
            text = self.tokenizer.apply_chat_template(
                self.make_input(item), tokenize=False, add_generation_prompt=False
            )
            # Strip the BOS text from the rendered template; presumably the
            # pipeline's tokenizer adds its own (TODO confirm). Tokenizers
            # without a BOS token report None, which str.replace cannot take.
            if bos_token:
                text = text.replace(bos_token, "")
            test_texts.append(text)

        with torch.no_grad():
            pipe_outputs = self.rm_pipe(test_texts, **pipe_kwargs)
        # Each pipeline output is a list of label dicts; the first one is used.
        return [{"score": x[0]["score"]} for x in pipe_outputs]


class PairRM(BaseRMHF):
    """Reward model scored through a Hugging Face ``pipeline``."""

    def __init__(
        self, model_pt: str, batch_size: int = 1, device: str = "cuda"
    ) -> None:
        """
        Initializes the PairRM.

        Args:
            model_pt (str): The path to the pre-trained model.
            batch_size (int, optional): The batch size. Defaults to 1.
            device (str, optional): The device to use for inference. Defaults to "cuda".
        """
        super().__init__(model_pt, batch_size, device)
        rm_tokenizer = AutoTokenizer.from_pretrained(model_pt)
        self.rm_pipe = pipeline(
            "sentiment-analysis",
            model=model_pt,
            device=device,
            tokenizer=rm_tokenizer,
            model_kwargs={"torch_dtype": torch.bfloat16},
        )
        self.tokenizer = rm_tokenizer

    @staticmethod
    def make_input(input: dict) -> list[dict]:
        """
        Converts one input into a user/assistant message list.

        Args:
            input (dict): Dict with the keys ``"instruction"`` and ``"output"``.
                The parameter name is kept for backward compatibility.

        Returns:
            list[dict]: The instruction as the user turn and the output as the
                assistant turn.
        """
        user_turn = {"role": "user", "content": input["instruction"]}
        assistant_turn = {"role": "assistant", "content": input["output"]}
        return [user_turn, assistant_turn]

    def _render(self, messages: list[dict]) -> str:
        """Renders one chat as text with the tokenizer's BOS text removed."""
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        bos_token = self.tokenizer.bos_token
        # Tokenizers without a BOS token report None; str.replace would raise
        # TypeError on it, so the stripping step is skipped in that case.
        if not bos_token:
            return text
        return text.replace(bos_token, "")

    def score_batch(self, inputs: list[dict], **kwargs) -> list[dict]:
        """
        Scores the inputs.

        Args:
            inputs (list[dict]): List of inputs with the keys ``"instruction"``
                and ``"output"``.

        Returns:
            list[dict]: One ``{"score": ...}`` dict per input, in input order.
        """
        if not inputs:
            # Empty batch: return early instead of passing batch_size=0 on.
            return []

        pipe_kwargs = {
            "return_all_scores": True,
            "function_to_apply": "none",
            "batch_size": len(inputs),
        }
        test_texts = [self._render(self.make_input(item)) for item in inputs]
        with torch.no_grad():
            pipe_outputs = self.rm_pipe(test_texts, **pipe_kwargs)
        # The score of the first label dict of every output is reported.
        return [{"score": x[0]["score"]} for x in pipe_outputs]


class ArmoRM(BaseRMHF):
    """Reward model that reads the ``score`` field of a sequence-classification model."""

    def __init__(
        self, model_pt: str, batch_size: int = 1, device: str = "cuda", max_length: int = 4096
    ) -> None:
        """
        Initializes the ArmoRM object by loading the tokenizer and the model.

        Args:
            model_pt (str): The path to the pre-trained model.
            batch_size (int, optional): The batch size. Defaults to 1.
            device (str, optional): The device the model is moved to. Defaults to "cuda".
            max_length (int, optional): Truncation length used when tokenizing.
                Defaults to 4096.
        """
        super().__init__(model_pt, batch_size, device)
        rm_tokenizer = AutoTokenizer.from_pretrained(model_pt)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_pt,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).to(device)
        self.tokenizer = rm_tokenizer
        self.max_length = max_length

    @staticmethod
    def make_input(input: dict) -> list[dict]:
        """
        Builds the two-turn chat for one input.

        Args:
            input (dict): Dict with the keys ``"instruction"`` and ``"output"``.
                The parameter name is kept for backward compatibility.

        Returns:
            list[dict]: A user message holding the instruction followed by an
                assistant message holding the output.
        """
        return [
            {"role": "user", "content": input["instruction"]},
            {"role": "assistant", "content": input["output"]},
        ]

    def score_batch(self, inputs: list[dict], **kwargs) -> list[dict]:
        """
        Scores the inputs.

        Args:
            inputs (list[dict]): List of inputs with the keys ``"instruction"``
                and ``"output"``.

        Returns:
            list[dict]: One ``{"score": ...}`` dict per input, in input order.
        """
        if not inputs:
            return []

        conversations = [self.make_input(item) for item in inputs]
        # return_dict=True yields the attention mask alongside the token ids.
        # The batch is padded, so the mask is forwarded to keep the model from
        # attending to padding positions.
        encoded = self.tokenizer.apply_chat_template(
            conversations,
            return_tensors="pt",
            return_dict=True,
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(
                input_ids=encoded["input_ids"],
                attention_mask=encoded.get("attention_mask"),
            )
            scores = outputs.score
        # The model is loaded in bfloat16; cast before converting to a Python float.
        results = [{"score": score.float().item()} for score in scores]
        assert len(results) == len(inputs)
        return results


class SkyworkRM(BaseRMHF):
    """Reward model that reads the single logit of a sequence-classification model."""

    def __init__(
        self, model_pt: str, batch_size: int = 1, device: str = "cuda", max_length: int = 4096
    ) -> None:
        """
        Initializes the SkyworkRM object by loading the tokenizer and the model.

        Args:
            model_pt (str): The path to the pre-trained model.
            batch_size (int, optional): The batch size. Defaults to 1.
            device (str, optional): The device the model is moved to. Defaults to "cuda".
            max_length (int, optional): Truncation length used when tokenizing.
                Defaults to 4096.
        """
        super().__init__(model_pt, batch_size, device)
        rm_tokenizer = AutoTokenizer.from_pretrained(model_pt)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_pt,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            num_labels=1,
        ).to(device)
        self.tokenizer = rm_tokenizer
        self.max_length = max_length

    @staticmethod
    def make_input(input: dict) -> list[dict]:
        """
        Converts one input into a user/assistant message list.

        Args:
            input (dict): Dict with the keys ``"instruction"`` and ``"output"``.
                The parameter name is kept for backward compatibility.

        Returns:
            list[dict]: The instruction as the user turn and the output as the
                assistant turn.
        """
        user_turn = {"role": "user", "content": input["instruction"]}
        assistant_turn = {"role": "assistant", "content": input["output"]}
        return [user_turn, assistant_turn]

    def score_batch(self, inputs: list[dict], **kwargs) -> list[dict]:
        """
        Scores the inputs.

        Args:
            inputs (list[dict]): List of inputs with the keys ``"instruction"``
                and ``"output"``.

        Returns:
            list[dict]: One ``{"score": ...}`` dict per input, in input order.
        """
        if not inputs:
            return []

        chats = [self.make_input(item) for item in inputs]
        # Ask for a dict so the attention mask comes back with the token ids;
        # the batch is padded and the mask tells the model which positions
        # are padding.
        batch = self.tokenizer.apply_chat_template(
            chats,
            return_tensors="pt",
            return_dict=True,
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(
                input_ids=batch["input_ids"],
                attention_mask=batch.get("attention_mask"),
            )
            logits = outputs.logits
        # The model is loaded with num_labels=1, so the first logit of each row
        # is the score.
        results = [{"score": row[0].float().item()} for row in logits]
        assert len(results) == len(inputs)
        return results