import torch
import numpy as np
import dataclasses as dc
import functools

from core.model import Preprocessor
from core.utils import TokenBuffer
from core.utils.th import nonzero_indices

from typing import Any, Sequence, Callable, Literal, overload, NamedTuple


# Callback mapping a selected position's index tuple (as yielded by
# `nonzero_indices(mask)`) to a reference dict; the dict is merged with the
# parser output and passed as keyword arguments to the reward / abort hooks.
type _RefFn = Callable[[tuple[int, ...]], dict[str, Any]]
# Keys of the dict returned by `RewardModel.supervise_process`.
type _ProcessSignal = Literal["rewards", "abortions"]


class RewardModel:
    """Base class for outcome (ORM) and process (PRM) reward models.

    Subclasses customize scoring by overriding the per-sequence hooks
    `_parse_outcome`, `_parse_process`, `outcome_reward`, `process_reward` and
    `abort_process`. The batch-level methods (`outcome_rewards`,
    `process_rewards`, `supervise_process`) decode each position selected by a
    mask, call those hooks, and write the results into dense tensors of shape
    `tokens.shape`.
    """

    dtype: torch.dtype = torch.float64
    "The data type of reward values"

    # When True, `outcome_rewards` returns its output tensor without scoring.
    disable_orm: bool = False
    # When True, process rewards are never computed (reward tensors stay zero).
    disable_prm: bool = False
    # Forwarded to `decoder.decode` as its `skip_special_tokens` argument.
    skip_special_tokens: bool = False

    def outcome_rewards(
        self,
        ref_fn: _RefFn,
        decoder: Preprocessor,
        tokens: TokenBuffer,
        mask: torch.Tensor,
        out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Accumulate outcome rewards for every position selected by `mask`.

        Args:
            ref_fn: Maps an index tuple to the reference dict for that position.
            decoder: Supplies `device` and `decode` (tokens -> text).
            tokens: Token storage; `tokens.tokens_at(*idx)` is the sequence at `idx`.
            mask: Positions to score, enumerated with `nonzero_indices`.
            out: Tensor to accumulate into in place. If omitted, a zero tensor
                of shape `tokens.shape` and dtype `self.dtype` is created.

        Returns:
            `out`, with each selected entry incremented by `outcome_reward(...)`.
            Returned unchanged when `disable_orm` is set.
        """
        if out is None:
            out = torch.zeros(tokens.shape, device=decoder.device, dtype=self.dtype)
        if self.disable_orm:
            return out
        decode = self._bind_decode(decoder)
        for idx in nonzero_indices(mask):
            outcome = decode(tokens.tokens_at(*idx))
            ref = ref_fn(idx) | self._parse_outcome(outcome)
            out[idx] += self.outcome_reward(outcome, **ref)
        return out

    def process_rewards(
        self,
        ref_fn: _RefFn,
        decoder: Preprocessor,
        tokens: TokenBuffer,
        prompt_lengths: torch.Tensor,
        mask: torch.Tensor,
        out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Accumulate process rewards for every thought context selected by `mask`.

        Each selected sequence is split at `prompt_lengths[idx]` into a prompt
        part and an output part before being passed to `process_reward`.

        Args:
            ref_fn: Maps an index tuple to the reference dict for that position.
            decoder: Supplies `device` and `decode` (tokens -> text).
            tokens: Token storage; `tokens.tokens_at(*idx, start=..., stop=...)`
                slices the sequence at `idx`.
            prompt_lengths: Prompt length (in tokens) of each position.
            mask: Positions to score, enumerated with `nonzero_indices`.
            out: Tensor to accumulate into in place. If omitted, a zero tensor
                of shape `tokens.shape` and dtype `self.dtype` is created.

        Returns:
            `out`, with each selected entry incremented by `process_reward(...)`.
            Returned unchanged when `disable_prm` is set.
        """
        if out is None:
            out = torch.zeros(tokens.shape, device=decoder.device, dtype=self.dtype)
        if self.disable_prm:
            return out
        decode = self._bind_decode(decoder)
        for idx in nonzero_indices(mask):
            prompt, output, ref = self._decode_process_context(
                ref_fn, decode, tokens, prompt_lengths, idx)
            out[idx] += self.process_reward(prompt, output, **ref)
        return out

    def supervise_process(
        self,
        ref_fn: _RefFn,
        decoder: Preprocessor,
        tokens: TokenBuffer,
        prompt_lengths: torch.Tensor,
        mask: torch.Tensor,
        rewards: bool = True,
        abortions: bool = True,
    ) -> dict[_ProcessSignal, torch.Tensor]:
        """Compute process rewards and/or abortion flags in a single decoding pass.

        Args:
            ref_fn, decoder, tokens, prompt_lengths, mask: As in `process_rewards`.
            rewards: Include a `"rewards"` tensor (dtype `self.dtype`). It stays
                all-zero when `disable_prm` is set.
            abortions: Include an `"abortions"` bool tensor holding the result of
                `abort_process` for each selected position.

        Returns:
            A dict containing only the requested signals, each of shape `tokens.shape`.
        """
        device = decoder.device
        compute_rewards = rewards and not self.disable_prm
        if rewards:
            out_rewards = torch.zeros(tokens.shape, device=device, dtype=self.dtype)
        if abortions:
            out_abort = torch.zeros(tokens.shape, device=device, dtype=torch.bool)

        # Decoding and parsing are only needed when some signal is actually
        # computed; otherwise the zero-filled tensors are already the answer.
        if compute_rewards or abortions:
            decode = self._bind_decode(decoder)
            for idx in nonzero_indices(mask):
                prompt, output, ref = self._decode_process_context(
                    ref_fn, decode, tokens, prompt_lengths, idx)
                if compute_rewards:
                    out_rewards[idx] += self.process_reward(prompt, output, **ref)
                if abortions:
                    out_abort[idx] = self.abort_process(prompt, output, **ref)

        out: dict[_ProcessSignal, torch.Tensor] = {}
        if rewards:
            out["rewards"] = out_rewards
        if abortions:
            out["abortions"] = out_abort
        return out

    def _bind_decode(self, decoder: Preprocessor) -> Callable[..., str]:
        """Return `decoder.decode` with this model's `skip_special_tokens` bound."""
        return functools.partial(decoder.decode,
                                 skip_special_tokens=self.skip_special_tokens)

    def _decode_process_context(
        self,
        ref_fn: _RefFn,
        decode: Callable[..., str],
        tokens: TokenBuffer,
        prompt_lengths: torch.Tensor,
        idx: tuple[int, ...],
    ) -> tuple[str, str, dict[str, Any]]:
        """Decode the thought context at `idx` into `(prompt, output, references)`.

        The sequence is split at `prompt_lengths[idx]`. The references are
        `ref_fn(idx)` merged with `_parse_process(prompt, output)`; on key
        collisions, the parser's values win.
        """
        prompt_len = int(prompt_lengths[idx])
        prompt = decode(tokens.tokens_at(*idx, stop=prompt_len))
        output = decode(tokens.tokens_at(*idx, start=prompt_len))
        ref = ref_fn(idx) | self._parse_process(prompt, output)
        return prompt, output, ref

    def _parse_outcome(self, outcome: str) -> dict[str, Any]:
        """Parse the outcome into a dictionary containing information
        required to compute the outcome rewards."""
        return {}

    def _parse_process(self, prompt: str, output: str) -> dict[str, Any]:
        """Parse a thought context in the reasoning process into a dictionary
        containing information required to compute the process rewards / abortions."""
        return {}

    def outcome_reward(self, _outcome: str, **references) -> float:
        """Compute the outcome reward after the outcome is produced. `references` include
        `task.get_ref_dict(cot)` and the outputs of `_parse_outcome`."""
        return 0.0

    def process_reward(self, prompt: str, output: str, **references) -> float:
        """Compute the process reward after each thought context. `references` include
        the `task.get_ref_dict(cot)` and the outputs of `_parse_process`."""
        return 0.0

    def abort_process(self, prompt: str, output: str, **references) -> bool:
        """This detects whether reasoning process should be aborted when collecting trajectories for RL.
        If RL enables abortions and `abort_process` returns `True`, then the trajectory is enforced
        to terminate. This ensures Markov property of reward signals. `references` include
        `task.get_ref_dict(cot)` and the outputs of `_parse_process`."""
        return False


class RewardWrapper(RewardModel):
    """A `RewardModel` that delegates every configuration flag and hook to `base`.

    The batch-level methods inherited from `RewardModel` call the hooks on
    `self`. A subclass can therefore override a single hook (e.g.
    `process_reward`) to adjust the wrapped model, while everything else stays
    delegated to `base`.
    """

    def __init__(self, base: RewardModel):
        self.base = base
        super().__init__()

    @property
    def dtype(self) -> torch.dtype:
        return self.base.dtype

    @property
    def disable_orm(self) -> bool:
        return self.base.disable_orm

    @property
    def disable_prm(self) -> bool:
        return self.base.disable_prm

    @property
    def skip_special_tokens(self) -> bool:
        return self.base.skip_special_tokens

    def _parse_outcome(self, outcome: str) -> dict[str, Any]:
        return self.base._parse_outcome(outcome)

    def _parse_process(self, prompt: str, output: str) -> dict[str, Any]:
        return self.base._parse_process(prompt, output)

    def process_reward(self, prompt: str, output: str, **references) -> float:
        return self.base.process_reward(prompt, output, **references)

    def outcome_reward(self, _outcome: str, **references) -> float:
        return self.base.outcome_reward(_outcome, **references)

    def abort_process(self, prompt: str, output: str, **references) -> bool:
        # Forwarded like every other hook. Without this override, a wrapper
        # would fall back to `RewardModel.abort_process` (always False) and
        # silently drop the wrapped model's abortion signal.
        return self.base.abort_process(prompt, output, **references)
