import abc
import torch
import json
from pprint import pprint
from pathlib import Path

from core.utils.buf import TokenBuffer
from .reasoners.generative import GenerativeReasoner, GenHook, Callbacks
from .rm import RewardModel


from typing import Literal, Callable, Any, Iterable


# Alias for token data given either as a pair of tensors or as a list of tensors.
type _Tokens = tuple[torch.Tensor, torch.Tensor] | list[torch.Tensor]
# Alias for a callable that maps a tuple of ints to a string-keyed dict.
type _RefFn = Callable[[tuple[int, ...]], dict[str, Any]]


# `_RefFn` that ignores its argument and always returns a new empty dict.
_empty_ref: _RefFn = lambda idx: {}


class Evaluator[KCache](Callbacks[KCache], abc.ABC):

    def attach(self, host: GenerativeReasoner):
        super().attach(host)
        self.reset()
    
    @abc.abstractmethod
    def reset(self):
       ...

    @abc.abstractmethod
    def get_results(self) -> Any:
        ...

    def present_results(self):
        pprint(self.get_results())

    def save(self, path: Path | str):
        results = self.get_results()
        with open(path, 'wt') as f:
            json.dump(results, f, indent=4)



class CumulatedReward(Evaluator[Literal['cum_rewards', 'process_mask', 'cum_discount']],
                      GenHook):
    """Generation hook that sums reward-model rewards into a cached tensor.

    Process rewards are added in `after_context` and outcome rewards in
    `after_reasoning`. When ``discount != 1`` a second cached tensor holds a
    running discount factor that scales the process rewards.
    """

    def __init__(self, reward_model: RewardModel, discount: float = 1):
        """Store the reward model and the per-step discount factor."""
        super().__init__()
        self.discount = discount
        self.reward_model = reward_model

    @property
    def _cumr(self) -> torch.Tensor:
        """The tensor stored in the cache under ``'cum_rewards'``."""
        value = self.cache['cum_rewards']
        assert isinstance(value, torch.Tensor)
        return value

    @property
    def _cum_discount(self) -> torch.Tensor:
        """The tensor stored in the cache under ``'cum_discount'``."""
        value = self.cache['cum_discount']
        assert isinstance(value, torch.Tensor)
        return value

    def reset(self):
        """Do nothing: the accumulators are created in `before_reasoning`."""

    def clear_(self, mask: torch.Tensor | None = None):
        """Reset the accumulators in place.

        With no *mask* every entry is reset; otherwise only the entries
        selected by *mask*. Rewards go back to 0 and, when discounting is
        active, the running discount goes back to 1.
        """
        rewards_total = self._cumr
        discounting = self.discount != 1
        if mask is None:
            rewards_total.zero_()
            if discounting:
                self._cum_discount.fill_(1)
            return
        rewards_total.masked_fill_(mask, 0)
        if discounting:
            self._cum_discount.masked_fill_(mask, 1)

    def before_reasoning(self, input: TokenBuffer):
        """Create the accumulator tensors in the cache via the host context."""
        context = self.host.context
        reward_dtype = self.reward_model.dtype
        self.cache["cum_rewards"] = context.make_tensor((), reward_dtype, 0)
        if self.discount == 1:
            # No discounting: the running discount factor is never needed.
            return
        self.cache["cum_discount"] = context.make_tensor((), reward_dtype, 1)

    def after_context(self, prompt_lengths: torch.Tensor):
        """Add the process rewards, masked to the non-terminated entries."""
        reasoner = self.host
        active = ~reasoner.terminated
        rewards_total = self._cumr
        step_rewards = self.reward_model.process_rewards(
            reasoner.ref,
            reasoner.llm.preprocessor,
            reasoner.context,
            prompt_lengths,
            mask=active,
        )
        if self.discount == 1:
            rewards_total.add_(step_rewards)
            return
        factor = self._cum_discount
        rewards_total.add_(step_rewards * factor)
        # Only entries that are still running have their discount advanced.
        factor[active] *= self.discount

    def after_reasoning(self, thought, outcome: TokenBuffer):
        """Add the outcome rewards, masked to the terminated entries."""
        reasoner = self.host
        finished = reasoner.terminated
        rewards_total = self._cumr
        final_rewards = self.reward_model.outcome_rewards(
            reasoner.ref,
            reasoner.llm.preprocessor,
            outcome,
            mask=finished,
        )
        rewards_total.add_(final_rewards)

    def get_results(self) -> torch.Tensor:
        """Return a copy of the accumulated rewards tensor."""
        snapshot = self._cumr
        return snapshot.clone()


class AverageReturn(CumulatedReward):
    """Cumulated reward that also keeps a running mean of the returns.

    After each `after_reasoning` call the mean of the accumulated rewards
    tensor is folded into `mean`, weighted by its element count.
    """

    def __init__(
        self,
        reward_model: RewardModel,
        cnt_in_result: bool = True,
        discount: float = 1,
    ):
        """Set up the reward accumulation and an empty running mean.

        When *cnt_in_result* is true, `get_results` returns a dict with both
        the count and the average; otherwise it returns the average alone.
        """
        super().__init__(reward_model, discount)

        self.cnt_in_result = cnt_in_result
        self.cnt: int = 0
        self.mean: float | None = None

    def reset(self):
        """Forget the running mean and its sample count."""
        self.cnt, self.mean = 0, None

    @staticmethod
    def _update_mean(old_n: int, old_mean: float | None, n: int, mean: float):
        """Merge a batch mean into a running mean.

        Returns the ``(count, mean)`` pair obtained by combining *old_n*
        samples of mean *old_mean* with *n* samples of mean *mean*.
        """
        if n == 0:
            # Empty batch: nothing changes.
            return old_n, old_mean
        if old_n == 0:
            # First non-empty batch: adopt its statistics directly.
            return n, mean
        assert old_mean is not None
        total = n + old_n
        combined = old_mean * (old_n / total) + mean * (n / total)
        return total, combined

    def _update(self):
        """Fold the current accumulated rewards into the running mean."""
        batch_returns = self._cumr
        batch_size = batch_returns.nelement()
        batch_mean = float(batch_returns.mean())
        merged = self._update_mean(self.cnt, self.mean, batch_size, batch_mean)
        self.cnt, self.mean = merged

    def after_reasoning(self, thought, outcome):
        """Add the outcome rewards, then update the running mean."""
        super().after_reasoning(thought, outcome)
        self._update()

    def get_results(self) -> Any:
        """Return the average, optionally with the sample count."""
        if not self.cnt_in_result:
            return self.mean
        return {'count': self.cnt, 'average': self.mean}



class DictEvaluator(Evaluator[Any]):
    """Evaluator that merges the results of several child evaluators.

    `get_results` collects each child's results and passes them to
    ``merge_fn``: children given positionally become positional arguments,
    children given by keyword become keyword arguments with the same name.
    """

    def __init__(self,
                 merge_fn: Callable[..., Any],
                 *evaluators: Evaluator,
                 **kwevaluators: Evaluator):
        """Store the merge function and the child evaluators."""
        # Run base-class initialisation, as the other evaluators in this
        # module do.
        super().__init__()
        self._evaluators = evaluators
        self._kwevaluators = kwevaluators
        self._merge_fn = merge_fn

    def __iter__(self):
        """Yield the positional children, then the keyword children."""
        yield from self._evaluators
        yield from self._kwevaluators.values()

    def reset(self):
        """Reset every child evaluator."""
        for e in self:
            e.reset()

    # Not abstract: this is a complete implementation, and marking it
    # abstract would make the class impossible to instantiate.
    def get_results(self) -> Any:
        """Return ``merge_fn`` applied to the children's results."""
        args = tuple(e.get_results() for e in self._evaluators)
        kwargs = {k: e.get_results() for k, e in self._kwevaluators.items()}
        return self._merge_fn(*args, **kwargs)
