import torch
import random
from core.model import LLM, ValueModel
import torch
import abc
import dataclasses as dc
from typing import NamedTuple, Final, Any, Literal
from core.utils.buf import TokenBuffer, Seqs
from core.utils.th import NamedDataset, NestedTensorList
from core.utils.iterate import indices as iterate_indices
from core.utils import th
from core.reasoning.rm import RewardModel
from core.reasoning.reasoners.generative import GenHook, GenerativeReasoner, Callbacks
from core.inference import Inference
from core.inference.sampling import Sampling


@dc.dataclass
class InferenceBuffer(TokenBuffer):
    """Snapshot of the generation context taken at one reasoning step.

    Extends ``TokenBuffer`` with the extra per-sequence information that
    ``Collector.after_context`` records: prompt lengths, a liveness mask and,
    optionally, logits, token probabilities and process rewards.
    """

    # NOTE(review): presumably tells TokenBuffer which of the extra fields are
    # laid out along the token axis -- confirm against core.utils.buf.
    __sequential__ = ['logits', 'probs']

    # Prompt length of each sequence in the batch.
    prompt_lengths: torch.Tensor
    # True for sequences that were not yet terminated when the snapshot was
    # taken (``Collector`` stores ``~host.terminated`` here).
    mask: torch.Tensor
    # Copy of the context logits; None when logits were not requested.
    logits: torch.Tensor | None = None
    # Output of ``th.probabilities`` for this step; None when not requested.
    probs: torch.Tensor | None = None
    # Per-sequence process rewards from the reward model; None when absent.
    process_rewards: torch.Tensor | None = None


class ReplayDataset(NamedDataset):
    """Flat, column-oriented store of recorded reasoning steps.

    ``Collector.after_reasoning`` appends one entry to every enabled column
    for each (sequence, step) pair whose mask was set, so row ``i`` of each
    column describes the same recorded step. Columns that were not requested
    through ``Collector.Require`` are not created by ``Collector.reset``.
    """

    # Token row of the context at the recorded step.
    tokens: list[torch.Tensor]
    # Context length value of the sequence at the recorded step.
    lengths: list[int]
    # Prompt length of the sequence.
    prompt_lengths: list[int]
    # True only on the last recorded step of a sequence that terminated.
    terminated: list[bool]
    # True only on the last recorded step of a sequence that did not terminate.
    truncated: list[bool]
    # Logits captured at the recorded step.
    logits: list[torch.Tensor]
    # Probabilities captured at the recorded step.
    probs: list[torch.Tensor]
    # Process reward of the step; the outcome reward is added to the last step.
    rewards: list[float]


class Collector(Callbacks[Literal['trajectory']], GenHook):
    """Generation hook that records reasoning trajectories into a dataset.

    During reasoning, ``after_context`` stores a snapshot (``InferenceBuffer``)
    of the host's context in the ``"trajectory"`` cache entry. When reasoning
    finishes, ``after_reasoning`` flattens those snapshots into the
    ``ReplayDataset`` exposed through ``data``, one record per live
    (sequence, step) pair.

    NOTE(review): the hook methods are presumably invoked by the attached
    ``GenerativeReasoner``; the call protocol is defined outside this file.
    """

    # Reward model used for process/outcome rewards and abortions. It must be
    # set before use when rewards are required or abortion is allowed.
    reward_model: RewardModel | None = None
    _data: ReplayDataset

    class Require(NamedTuple):
        """Flags selecting which columns the collector records."""

        content: bool = True      # tokens, lengths and prompt_lengths
        truncated: bool = True
        terminated: bool = True
        logits: bool = False
        probs: bool = False
        rewards: bool = False

    def __init__(
        self,
        device: torch.device | str | None = None,
        require: Require = Require(),
        *,
        prob_bias: float = 1e-2,
        allow_abortion: bool = False,
    ):
        """Create a collector.

        Args:
            device: Device onto which recorded tensors are moved; ``None``
                leaves them on their current device.
            require: Which columns to record.
            prob_bias: Forwarded as ``bias`` to ``th.probabilities``.
            allow_abortion: When true, the reward model is asked for
                abortions, which are OR-ed into the host's ``terminated``.
        """
        # Normalise the device argument to a torch.device (or None).
        self._device = (
            None if device is None
            else device if isinstance(device, torch.device)
            else torch.device(device)
        )

        self.require: Final = require
        # Probabilities are derived from logits, so either flag needs logits
        # to be present in the context.
        self._require_logits: Final = require.logits or require.probs
        self._allow_abortion = allow_abortion

        self._prob_bias = prob_bias
        # -1 means "not observed yet"; fixed by the first after_context call.
        self._context_length: int = -1
        # NOTE(review): only ever reset to 0 in the code visible here.
        self._cnt: int = 0
        self._data = self._init_data()

    def _init_data(self) -> ReplayDataset:
        """Build the (empty) dataset object that will hold the records."""
        return ReplayDataset()

    def reset(self) -> None:
        """Discard all recorded data and recreate the enabled columns."""
        self._context_length = -1
        self._cnt = 0

        data = self._data
        data.clear()
        # Only the columns requested in ``self.require`` are created.
        if self.require.content:
            data.tokens = []
            data.lengths = []
            data.prompt_lengths = []
        if self.require.terminated:
            data.terminated = []
        if self.require.truncated:
            data.truncated = []
        if self.require.logits:
            data.logits = []
        if self.require.probs:
            data.probs = []
        if self.require.rewards:
            data.rewards = []

    def attach(self, host: GenerativeReasoner) -> None:
        """Attach to ``host`` and start from an empty dataset."""
        super().attach(host)
        self.reset()

    def after_context(self, prompt_lengths: torch.Tensor) -> None:
        """Record a snapshot of the host context for the current step.

        Args:
            prompt_lengths: Prompt length of each sequence in the context.

        Raises:
            IndexError: If the context's ``max_length`` differs from the one
                seen on the first call since the last reset.
            ValueError: If logits/probabilities are required but the context
                has no logits, if probabilities are required with a
                non-``Sampling`` inference, or if rewards/abortions are
                needed but no reward model is set.
        """
        device = self._device
        host = self.host
        ctx = host.context
        tokens = ctx.tokens
        logits = ctx.data.get('logits')
        terminated = host.terminated
        # Sequences still running at snapshot time (taken before abortions
        # from this step are applied below).
        mask = ~terminated

        # All snapshots must share one context length so token rows can be
        # validated against it in after_reasoning.
        if self._context_length < 0:
            self._context_length = ctx.max_length
        elif ctx.max_length != self._context_length:
            raise IndexError("Incongruous context length.")

        if self.require.logits:
            if logits is None:
                raise ValueError("the logits is required by the buffer.")
            logits_ = logits.to(device=device).clone()
        else:
            logits_ = None

        if self.require.probs:
            if logits is None:
                raise ValueError("Logits is required to compute probabilities.")
            if not isinstance(self.host.inference, Sampling):
                raise ValueError("Probability can be computed only when using the sampling inference.")
            temperature = self.host.inference.temperature
            probs = th.probabilities(logits, tokens, temperature, bias=self._prob_bias)
            probs = probs.to(device=device)
        else:
            probs = None

        if self.require.rewards or self._allow_abortion:
            if (rm := self.reward_model) is None:
                raise ValueError("Process supervision requires a reward model.")
            process_supervision = rm.supervise_process(
                self.host.ref,
                self.host.llm.preprocessor, ctx, prompt_lengths, mask,
                rewards=self.require.rewards,
                abortions=self._allow_abortion,
            )
            # Either entry may be missing from the returned mapping.
            process_rewards = process_supervision.get("rewards")
            abortions = process_supervision.get("abortions")
        else:
            process_rewards = None
            abortions = None

        # Clone so later in-place updates of the context do not alter the
        # recorded snapshot.
        buf = InferenceBuffer(
            tokens.to(device=device).clone(),
            ctx.lengths.to(device=device).clone(),
            prompt_lengths.to(device=device).clone(),
            mask.to(device=device).clone(),
            logits=logits_,
            probs=probs,
            process_rewards=process_rewards,
        )
        self.trajectory.append(buf)

        # Sequences aborted by the reward model count as terminated from now on.
        if abortions is not None:
            self.host.terminated = self.host.terminated | abortions

    @property
    def trajectory(self) -> NestedTensorList[InferenceBuffer]:
        """Snapshots recorded since the last ``before_reasoning`` call."""
        out = self.cache["trajectory"]
        assert isinstance(out, NestedTensorList)
        return out

    def before_reasoning(self, input: TokenBuffer) -> None:
        """Prepare the context and start a fresh trajectory.

        When logits or probabilities are required and the context has no
        ``'logits'`` entry yet, one is created with ``context.make_sequence``
        using the LLM's vocabulary size.
        """

        context = self.host.context
        if self._require_logits and context.data.get('logits') is None:
            llm = self.host.llm
            context.data['logits'] = context.make_sequence((llm.vocab_size,), torch.float)

        self.cache["trajectory"] = NestedTensorList()

    def after_reasoning(self, thought, outcome) -> None:
        """Flatten the recorded trajectory into the dataset.

        For every index over the snapshots' shape, the snapshots are visited
        in recording order and one record is appended for each step whose
        mask is set. The last record of each non-empty sequence then gets its
        terminated/truncated flags from ``host.terminated`` and, when rewards
        are required, the outcome reward added to its reward.

        Raises:
            ValueError: If rewards are required but no reward model is set.
        """
        data = self._data
        terminated = self.host.terminated
        trajectory = self.trajectory
        if len(trajectory) == 0:
            return
        shape = trajectory[0].shape
        assert all(seq.shape == shape for seq in trajectory)

        # Compute outcome rewards once for the whole batch.
        if self.require.rewards:
            if (rm := self.reward_model) is None:
                raise ValueError("Rewards are required while missing a reward model.")
            outcome_rewards = rm.outcome_rewards(
                self.host.ref, self.host.llm.preprocessor, outcome, terminated
            ).to(device=self._device)
        else:
            outcome_rewards = None

        for idx in iterate_indices(shape):
            # Tracks whether this sequence contributed at least one record,
            # so that the [-1] patches below touch its own last record.
            _case_is_not_empty = False
            for t, buf in enumerate(trajectory):
                valid = bool(buf.mask[idx])
                if not valid:
                    continue
                _case_is_not_empty = True

                if self.require.content:
                    tokens = buf.tokens[idx].clone()
                    length = int(buf.lengths[idx])
                    prompt_length = int(buf.prompt_lengths[idx])
                    assert tokens.shape == (self._context_length,)
                    assert 0 <= prompt_length <= self._context_length
                    data.tokens.append(tokens)
                    data.lengths.append(length)
                    data.prompt_lengths.append(prompt_length)
                # Flags default to False; the final step is fixed up below.
                if self.require.terminated:
                    data.terminated.append(False)
                if self.require.truncated:
                    data.truncated.append(False)
                if self.require.logits:
                    assert buf.logits is not None
                    data.logits.append(buf.logits[idx].clone())
                if self.require.probs:
                    assert buf.probs is not None
                    data.probs.append(buf.probs[idx].clone())
                if self.require.rewards:
                    assert buf.process_rewards is not None
                    reward = float(buf.process_rewards[idx])
                    data.rewards.append(reward)
            if _case_is_not_empty:
                # The last record of the sequence ends the episode: it is
                # either terminated or, otherwise, truncated.
                if self.require.truncated:
                    data.truncated[-1] = not bool(terminated[idx])
                if self.require.terminated:
                    data.terminated[-1] = bool(terminated[idx])
                if self.require.rewards:
                    assert outcome_rewards is not None
                    data.rewards[-1] += float(outcome_rewards[idx])

    @property
    def data(self) -> ReplayDataset:
        """The dataset holding every record collected since the last reset."""
        return self._data


def split_data[T](data: list[T], ratio: float):

    data = data.copy()
    random.shuffle(data)
    n = int(len(data) * ratio)
    return data[:n], data[n:]


def average_return(data: ReplayDataset, terminated_only=False):
    """Average the per-episode sum of rewards stored in ``data``.

    Records are read in order; rewards accumulate until a record flagged as
    terminated or truncated closes the episode. Terminated episodes always
    count; truncated ones count only when ``terminated_only`` is false.
    Rewards after the last closing record are ignored.

    Returns:
        The mean episode return, or NaN when no episode was counted.
    """
    rewards = data.rewards
    terminated_flags = data.terminated
    truncated_flags = data.truncated

    episode_returns: list[float] = []
    running = 0.
    for step in range(len(data)):
        running += rewards[step]
        ended = terminated_flags[step]
        cut_off = truncated_flags[step]
        if not (ended or cut_off):
            # Episode still in progress: keep accumulating.
            continue
        if ended or not terminated_only:
            episode_returns.append(running)
        running = 0.

    if not episode_returns:
        return float('nan')
    return sum(episode_returns) / len(episode_returns)


def split_by_suffixes(buf: InferenceBuffer, suffixes: Seqs | None):
    """Yield ``(prompt, segment)`` pairs for every sequence in ``buf``.

    Without ``suffixes`` each sequence yields a single pair: the tokens
    before its prompt length and the tokens after it. With ``suffixes`` the
    part after the prompt is cut at every position where one of the suffixes
    matches; each yielded segment runs from the end of the previous match up
    to the start of the matched suffix, so the suffix tokens themselves are
    excluded. Tokens after the last match are not yielded.
    """
    if suffixes is None:
        for index, row in buf.enumerate():
            cut = int(buf.prompt_lengths[index])
            yield row[:cut], row[cut:]
        return

    patterns, ignore_mask = th._combine_suffixes(*suffixes)
    pattern_lengths = torch.tensor(
        [len(pattern) for pattern in suffixes],
        device=buf.device,
        dtype=torch.int64,
    )
    n_patterns = patterns.size(0)
    hits = th.match_suffix(buf.tokens, patterns, ignore_mask)
    assert hits.shape == (*buf.shape, buf.max_length, n_patterns)
    # Matches that fall inside the prompt must not split anything.
    hits[buf._mask(stop=buf.prompt_lengths)] = False

    for index, row in buf.enumerate():
        cut = int(buf.prompt_lengths[index])
        prompt = row[:cut]

        # Rows of ``found`` are (position, suffix index) pairs.
        found = torch.nonzero(hits[index])
        assert found.shape == (found.size(0), 2)

        if n_patterns > 1:
            found = found[torch.argsort(found[:, 0])]

        # Position just past each match, and the length of the suffix matched.
        ends = (found[:, 0] + 1).tolist()
        widths = pattern_lengths[found[:, 1]].tolist()

        start = cut
        for end, width in zip(ends, widths):
            yield prompt, row[start: end - width]
            start = end


def estimate_probs(
    data: ReplayDataset,
    llm: LLM,
    temperature: float,
    batch_size: int,
    device: str | torch.device | None,
):
    """Recompute probabilities for all recorded token rows with ``llm``.

    The rows in ``data.tokens`` are stacked in chunks of at most
    ``batch_size``, passed through ``llm.model.forward`` and converted with
    ``th.probabilities`` at the given ``temperature``. Each chunk's result is
    moved to ``device`` and the chunks are concatenated along the first axis.

    Returns:
        A tensor whose first dimension equals ``len(data)``.
    """
    total = len(data)
    net = llm.model
    chunks = []
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        tokens = torch.stack(data.tokens[start:stop])
        logits = net.forward(tokens)
        chunk = th.probabilities(logits, tokens, temperature)
        chunks.append(chunk.to(device=device))

    result = torch.cat(chunks)
    assert result.size(0) == total
    return result


def estimate_values(
    vf: ValueModel,
    data: ReplayDataset,
    batch_size: int,
    device: torch.device | str | None,
):
    """Evaluate the value model on all recorded token rows.

    The rows in ``data.tokens`` are stacked in chunks of at most
    ``batch_size`` and passed through ``vf.forward``; each chunk's output is
    moved to ``device`` and the outputs are concatenated along the first axis.

    Returns:
        A tensor whose first dimension equals ``len(data)``.
    """
    total = len(data)
    chunks = []
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        batch = torch.stack(data.tokens[start:stop])
        estimate = vf.forward(batch)  # one row of values per stacked sequence
        chunks.append(estimate.to(device=device))

    values = torch.cat(chunks)  # first axis indexes the recorded rows
    assert values.size(0) == total
    del chunks

    return values
