from typing import Literal, Iterable, Any, cast
import torch
import abc
import numpy as np

from ...rm import RewardModel
from ...evaluator import CumulatedReward
from .impl import ReflStat, refl_stat_tensors
from core.utils.buf import TokenBuffer
from core.utils import iterate


type Stat = Literal["score", "count"] | ReflStat


class KwReflCumrewEvaluator[K: str](CumulatedReward, abc.ABC):
    """Cumulated rewards with reflective statistics, categorized by queries."""

    _data_by_keys: dict[K, dict[Stat, float]]

    def __init__(
        self,
        reward_model: RewardModel,
        cnt_in_reults: bool = False,
        rounding: int | None = None,
        
    ):
        super().__init__(reward_model, 1)
        self._data_by_keys = {}
        self._cnt_in_results = cnt_in_reults
        self._round_results = rounding

    def reset(self):
        super().reset()
        self._data_by_keys.clear()

    @abc.abstractmethod
    def _kwmap(self, **references) -> Iterable[K]:
        raise NotImplementedError
    
    def _get_stats(self) -> dict[Stat, torch.Tensor]:
        metrics: dict[Stat, torch.Tensor] = {}
        returns = self._cumr
        
        metrics["score"] = returns
        stat_tensors = refl_stat_tensors(self.host.session.cache)
        refl_stat_keys: list[ReflStat] = [
            "refl_tokens",
            "refl_steps",
            "refl_freq",
            "refl_freq_fp",
            "refl_freq_fn",
            "refl_freq_rej",
            "pi_freq_neg",
        ]
        for stat_key in refl_stat_keys:
            if (stat_tensor := stat_tensors.get(stat_key)) is not None:
                metrics[stat_key] = stat_tensor
        
        return metrics
    
    def after_reasoning(self, thought, outcome: TokenBuffer):
        super().after_reasoning(thought, outcome)
        host = self.host
        stats = self._get_stats()
        for idx in iterate.indices(self.host.context.shape):
            ref = host.ref(idx)
            for k in self._kwmap(**ref):
                try:
                    data_k = self._data_by_keys[k]
                except KeyError:
                    data_k = self._data_by_keys[k] = {}
                for metric, value in stats.items():
                    data_k[metric] = data_k.get(metric, 0) + value[idx].item()
                data_k["count"] = data_k.get("count", 0) + 1

    def get_results(self) -> dict[K, dict[Stat, Any]]:
        out: dict[K, dict[Stat, Any]] = {}
        for k, data in self._data_by_keys.items():
            out[k] = {}
            out[k]["count"] = n = data.get("count", 0)
            if n <= 0:
                continue
            out[k]["score"] = data["score"] / n
            if s := data.get("refl_steps"):
                for statkey in  ("refl_tokens", "refl_freq", "pi_freq_neg", "refl_freq_rej"):
                    if (stat_val := data.get(statkey)) is not None:
                        out[k][statkey] = stat_val  / s
                if (s_neg := data.get("pi_freq_neg")) is not None:
                    if (fp := data.get("refl_freq_fp")) is not None:
                        out[k]["refl_freq_fp"] = fp / s_neg if s_neg else np.nan
                    if (fn := data.get("refl_freq_fn")) is not None:
                        out[k]["refl_freq_fn"] = fn / (s - s_neg) if s - s_neg else np.nan
            if (precision := self._round_results) is not None:
                out[k] = {k: round(v, precision) if isinstance(v, float) else v
                          for k, v in out[k].items()}
        return out

    def present_results(self):
        try:
            import pandas as pd
            df = pd.DataFrame(self.get_results())
            print(df.T.to_markdown())
        except ImportError:
            super().present_results()
    
    def _reject_rate_getter(self, results: Any, default: float | None = None):
        try:
            _results = self._cast_result_type(results)
        except TypeError:
            if default is not None:
                return lambda **references: default
            else:
                raise

        def get_reject_rate(**references) -> float:
            for k in self._kwmap(**references):
                if (stats := _results.get(k)) is not None:
                    rej = stats.get("refl_freq_rej")  # the reject rate is recorded
                    if rej is None:  # try to compute using other statistics
                        try:
                            fp = stats["refl_freq_fp"]
                            fn = stats["refl_freq_fn"]
                            f = stats["pi_freq_neg"]
                        except KeyError:
                            pass
                        else:
                            if f >= 1:
                                rej = f * (1 - fp)
                            elif f <= 0:
                                rej = (1 - f) * fn
                            else:
                                rej = f * (1 - fp) + (1 - f) * fn
                            assert not np.isnan(rej)
                    if rej is not None:
                        return rej
            # returns default if no key is recorded
            if default is not None:
                return default
            else:
                raise KeyError("reject rate is not recorded in results")
        
        return get_reject_rate

    @classmethod
    def _cast_result_type(cls, results: Any, check: bool = True) -> dict[K, dict[Stat, Any]]:
        if check:
            if not isinstance(results, dict):
                raise TypeError
            elif any(not isinstance(v, dict) for v in results.values()):
                raise TypeError
        return cast(dict[K, dict[Stat, Any]], results)
