import abc

from rewards.llm_response_processors import get_extract_code_answer_func
from rewards.test_runner import run_tests, io_run_tests


class RewardBase(abc.ABC):
    COLUMN_NAMES = tuple()  # TODO: Rethink this design choice

    # TODO: ADD a scores postrprocessor
    def __init__(self,
                 column_names: tuple[str] | None = None,
                 code_extractor_fn_name: str = 'default_completion', ):
        if column_names is None:
            column_names = self.COLUMN_NAMES
        self.column_names = column_names
        self.code_extractor_fn_name = code_extractor_fn_name
        self.code_extractor_fn = get_extract_code_answer_func(code_extractor_fn_name)

    @abc.abstractmethod
    def get_reward_for_one(self, idx: int, prompt, completion, *args, **kwargs) -> float:
        pass

    @property
    def __name__(self) -> str:
        return type(self).__name__

    def get_rewards(self, prompts, completions, **kwargs) -> list[float]:
        """
        function that must be compatible with TRL, returns a list of scores
        :param prompts: TRL-compatible
        :param completions: TRL-compatible
        :param kwargs: TRL-compatible
        :return scores: TRL-compatible
        """
        scores = list()

        try:
            [kwargs[_cn] for _cn in self.column_names]
        except KeyError as e:
            raise KeyError(f"column name {e} not found in the dataset")

        # transform kwargs to a list of dicts
        kwargs_t = [dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())]
        for prompt, completion, kwarg in zip(prompts, completions, kwargs_t):
            # args = args_list[idx]
            scores.append(self.get_reward_for_one(prompt, completion, **kwarg))
        return scores

    def __call__(self, prompts, completions, **kwargs):
        return self.get_rewards(prompts, completions, **kwargs)


class DockerTestReward(RewardBase):
    """Reward equal to the success rate reported by ``run_tests``.

    Each sample must provide a ``test_list`` column (one test string or a list
    of them). Code is pulled out of the completion with ``code_extractor_fn``
    and the reward is ``run_tests(...)['summary']['success_rate']``.
    """
    COLUMN_NAMES = ('test_list',)

    def __init__(self,
                 column_names: tuple[str] | None = None,
                 code_extractor_fn_name: str = 'default_completion',
                 run_tests_separately: bool = True):
        super().__init__(column_names=column_names,
                         code_extractor_fn_name=code_extractor_fn_name)
        # Forwarded verbatim to run_tests' ``run_tests_separately`` flag.
        self.run_tests_separately = run_tests_separately

    def get_reward_for_one(self, prompt, completion, **kwargs) -> float:
        """Run this sample's tests against the code extracted from ``completion``.

        :param prompt: passed to the code extractor together with ``completion``.
        :param completion: model output containing the candidate code.
        :param kwargs: must contain ``test_list``; a bare string is wrapped
            into a one-element list.
        :return: ``success_rate`` from the ``summary`` of the test report.
        """
        tests = kwargs['test_list']
        tests = [tests] if isinstance(tests, str) else tests
        extracted = self.code_extractor_fn(prompt, completion)
        report = run_tests(extracted, tests,
                           run_tests_separately=self.run_tests_separately)
        summary = report['summary']
        return summary['success_rate']


class DockerIOTestReward(RewardBase):
    """Reward equal to the success rate reported by ``io_run_tests``.

    Each sample must provide a ``test_list`` column (one test or a list of
    them). Code is pulled out of the completion with ``code_extractor_fn``
    (default extractor name: ``'default_chat'``) and the reward is
    ``io_run_tests(...)['summary']['success_rate']``.
    """
    COLUMN_NAMES = ('test_list',)

    def __init__(self,
                 column_names: tuple[str, ...] | None = None,
                 code_extractor_fn_name: str = 'default_chat',
                 run_tests_separately: bool = True):
        # RewardBase.__init__ already resolves self.code_extractor_fn from
        # code_extractor_fn_name, so it is not looked up again here.
        super().__init__(column_names, code_extractor_fn_name)
        # Forwarded verbatim to io_run_tests' ``run_tests_separately`` flag.
        self.run_tests_separately = run_tests_separately

    def get_reward_for_one(self, prompt, completion, **kwargs) -> float:
        """Run this sample's I/O tests against the code extracted from ``completion``.

        :param prompt: unused; accepted for the RewardBase interface.
        :param completion: model output containing the candidate code.
        :param kwargs: must contain ``test_list``; a bare string is wrapped
            into a one-element list.
        :return: ``success_rate`` from the ``summary`` of the test report.
        """
        test_list = kwargs['test_list']
        if isinstance(test_list, str):
            test_list = [test_list]
        # NOTE(review): unlike DockerTestReward, the extractor receives only the
        # completion — presumably matching the 'default_chat' extractor's
        # signature; confirm against get_extract_code_answer_func.
        code = self.code_extractor_fn(completion)
        test_res = io_run_tests(code, test_list, run_tests_separately=self.run_tests_separately)
        return test_res['summary']['success_rate']
