import torch
import functools
from typing import Literal, Any, cast, Callable, Protocol
from contextlib import contextmanager
from core.utils.buf import TokenBuffer
from core.utils.th import nonzero_indices, fetch, clone_to, NestedTensorDict
from core.utils.kv import iterate_kv_tensors
from core.utils import iterate
from core.inference import Context, Session, Sampling
from core.model import ValueModel, Preprocessor

from ...reasoners.generative import MTPReasoner, GenHook, Callbacks
from .utils import IndividualContext, ContextStack, check_conflict
from .types import Symbol, PathLike, BinaryVerifier


# Event markers the reflective reasoner records on a context:
#   "__gen__"        - set right before a new step is generated;
#   "__out__"        - set at launch and again whenever a step is accepted;
#   "__reflection__" - set where an LLM reflection starts.
type ReflEvent = Literal["__gen__", "__out__", "__reflection__"]
# Per-sequence counters kept in the context info.
type ReflInfo = Literal["__n_step__", "__attempt__"]
# Names of the counters accumulated by `_ReflStatistics`.
type ReflStat = Literal[
    "refl_tokens",  # tokens counted after the "__reflection__" event
    "refl_steps",  # number of reflective steps
    "refl_freq",  # frequency of valid reflection (rejected, or more than 2 reflection tokens)
    "refl_freq_fp",  # false positive: the oracle reports an error but the step was not rejected
    "refl_freq_fn",  # false negative: the oracle reports no error but the step was rejected
    "pi_freq_neg",  # frequency of negative step (the oracle reports an error)
    "refl_freq_rej",  # frequency of rejection
]
# Keys the reflective reasoner may store in the session cache.
type ReflCache = ReflStat | Literal["__stack__", "__random_reject_ratio__"]


class ReflHook(Protocol):
    """Structural interface for hooks that want to observe reflection.

    Both callbacks receive ``mask`` (the sequences submitted to reflection)
    and ``reject`` (the sequences whose step was rejected). The default
    implementations do nothing.
    """

    def on_reflection(self, mask: torch.Tensor, reject: torch.Tensor):
        """Notification issued by the reasoner once a reflection verdict exists."""

    def after_reflection(self, mask: torch.Tensor, reject: torch.Tensor):
        """Notification issued by the solver after the reflection call returned."""


def convert_reasoner_impl[KEvent, KInfo, Thought](
    base: type[MTPReasoner[KEvent, KInfo, Any, Thought]],
    budget: int | None = None,
    context_length: int | None = None,
    reject_mode: Literal["retry", "revise"] = "retry",
    max_retry: int | None = None,
    external_verifier: Literal["vf", "oracle", "random"] | None = None,
    revise_temperature: float | None = None,
    reflect_temperature: float = 0,
    enable_statistics: bool = False,
    _oracle: Callable[[str, str], bool] | None = None,  # for oracle reflector and statistics
    _vf_path: PathLike | None = None,  # for value reflector
    _vf_tolerance: float | None = None,  # for value reflector
    _random_reject_ratio: float | None = None,  # for random reflector
    _get_reject_ratio_map: Callable[[PathLike], Callable[..., float]] | None = None,  # for random reflector
    _error_as_context_prompt: bool = True,  # For RL to omit rejected steps.
) -> type[MTPReasoner[KEvent, KInfo, ReflCache | Any, Thought]]:
    """
    Convert any recursive reasoner class into a reflective reasoner class.
    """

    class ReflectionReasoner(base):

        type ReflContext = Context[ReflEvent | KEvent, ReflInfo | KInfo, Any]
        type ReflSession = Session[ReflEvent | KEvent, ReflInfo | KInfo, Any, ReflCache]
        type TracebackStack = ContextStack[ReflEvent | KEvent, ReflInfo | KInfo, Any]

        @property
        def _stack(self) -> TracebackStack:
            """Traceback stack stored in the session cache under ``"__stack__"``."""
            cache = self._castsess(self.session).cache
            entry = cache["__stack__"]
            return cast(ReflectionReasoner.TracebackStack, entry)
        
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._refl_symbol_seqs = {
                v: self._as_seq(v) for k, v in Symbol.__members__.items()
            }
            # eos = self._as_seq(self.llm.preprocessor.tokenizer.eos_id)
            assert all(v.shape == (1,) for v in self._refl_symbol_seqs.values())
            self._eos.append(self._refl_symbol_seqs[Symbol.failure])
            self._raw_context_length = self.session.context_length
            if context_length is not None:
                self.session.context_length = context_length  # use the extended context length
            if budget is not None:
                self.max_step = self.max_step + budget

            if external_verifier == "oracle":
                if _oracle is None:
                    raise ValueError("Missing oracle verifier.")
                self._oracle = _oracle
            elif external_verifier == "random":
                if _random_reject_ratio is not None:
                    self._random_reject_ratio = _random_reject_ratio
                    self._random_reject_ratio_map = None
                elif _get_reject_ratio_map is not None:
                    if (ckpt_dir := self.llm.checkpoint_dir) is None:
                        raise ValueError("reuqire checkpoint directory.")
                    self._random_reject_ratio = None
                    self._random_reject_ratio_map = _get_reject_ratio_map(ckpt_dir)
                else:
                    raise ValueError("Can not infer rejection ratio for random reflection.")
            elif external_verifier == "vf": 
                self.vf = self.llm.get_value_model(_vf_path, trained_only=True)
                if _vf_tolerance is None:
                    raise ValueError("`vf_tolerance` must be specified for using value-based reflection.")
                self.vf_tolerance = _vf_tolerance
            elif external_verifier is None:
                self.reflective_inference = Sampling(self.llm, temperature=reflect_temperature)
            else:
                raise NotImplementedError(f"\"{external_verifier}\" reflector is not supported.")
            
            self.__revise_temperature = revise_temperature
            if isinstance(self.__revise_temperature, int):
                self.__revise_temperature = float(self.__revise_temperature)
            if revise_temperature is not None:
                if not isinstance(self.inference, Sampling):
                    raise ValueError("Set revision temperature only when decoding by sampling.")
                base_temperature = self.inference.temperature
                assert not isinstance(base_temperature, torch.Tensor)
                if isinstance(base_temperature, int):
                    base_temperature = float(base_temperature)
                self.__base_temperature = base_temperature
            
            self.__in_context_reflect = (reject_mode == "retry") and external_verifier is None
            self.__require_traceback = (reject_mode == "revise") or max_retry is not None

            self._statistics = None
            if enable_statistics:
                self._statistics = _ReflStatistics(_oracle, _ith_attempt_only=0)
                self._statistics.attach(self)
        
        def _castctx(self, ctx: Context):
            """Return ``ctx`` unchanged, typed as a reflective context (typing aid only)."""
            refl_ctx = cast(ReflectionReasoner.ReflContext, ctx)
            return refl_ctx

        def _castsess(self, sess: Session):
            """Return ``sess`` unchanged, typed as a reflective session (typing aid only)."""
            refl_sess = cast(ReflectionReasoner.ReflSession, sess)
            return refl_sess
        
        @contextmanager
        def _pause(self, mask: torch.Tensor):
            old = self.terminated
            self.terminated = terminated = old | mask
            try:
                yield terminated
            finally:
                self.terminated = torch.where(mask, old, self.terminated)

        def _launch(self, input: TokenBuffer) -> None:
            """Launch the base reasoner and set up the reflection bookkeeping.

            Connects the reflective sampler (LLM reflection only), marks the
            ``"__out__"`` event, creates the step and attempt counters, and
            fills the session cache with the traceback stack and the
            per-sequence random rejection ratios when those are in use.
            """
            super()._launch(input)
            if external_verifier is None:
                self.reflective_inference.connect(self.session)

            context = self._castctx(self.context)
            session = self._castsess(self.session)
            context.set_event("__out__")
            for counter in ("__n_step__", "__attempt__"):
                context.info[counter] = context.make_tensor((), dtype=torch.int32)

            if self.__require_traceback:
                session.cache["__stack__"] = ContextStack.empty(context.shape, indices_device='cpu')

            # The ratio map only exists for the random verifier.
            ratio_map = self._random_reject_ratio_map if external_verifier == "random" else None
            if ratio_map is None:
                return
            ratios = context.make_tensor((), dtype=torch.float32)
            for index in iterate.indices(context.shape):
                reference = self.ref(index)
                ratios[index] = ratio_map(**reference)
            session.cache["__random_reject_ratio__"] = ratios

        def after_reasoning(self, thought: Thought, outcome: TokenBuffer):
            """Run the base post-processing, then release the reflective sampler if one is used."""
            super().after_reasoning(thought, outcome)
            uses_llm_reflection = external_verifier is None
            if uses_llm_reflection:
                self.reflective_inference.release()

        def _debug_print(self, tokens: TokenBuffer | None = None):
            """Print ``tokens`` as text, or the whole context plus its state when omitted.

            Special tokens are kept in the printed text.
            """
            if tokens is not None:
                print(self.llm.detokenize(tokens, skip_special_tokens=False))
                return

            text = self.llm.detokenize(self.context, skip_special_tokens=False)
            print(text)
            print("terminated:", self.terminated)
            steps = self._castctx(self.context).info["__n_step__"]
            print("steps:", steps)

        def _finish_context(self, output: KEvent | ReflEvent | torch.Tensor,
                            reflective: bool = False):
            """Finish the context through the base class, unless reflection will do it.

            With in-context reflection enabled, only calls made with
            ``reflective=True`` are forwarded to the base implementation;
            other calls are ignored.
            """
            deferred = self.__in_context_reflect and not reflective
            if not deferred:
                super()._finish_context(output)  # type: ignore[arg-type]

        def __set_revision_temperature(self, revision_mask: torch.Tensor):
            """Use the revision temperature where ``revision_mask`` is set, the base one elsewhere.

            Does nothing when no revision temperature was configured.
            """
            revise_temperature = self.__revise_temperature
            if revise_temperature is None:
                return
            sampler = self.inference
            assert isinstance(sampler, Sampling)
            sampler.temperature = torch.where(
                revision_mask,
                revise_temperature,
                self.__base_temperature,
            )

        def _get_individual_context(self, idx: tuple[int, ...], device=None) -> IndividualContext:
            """Snapshot the state of one sequence so it can be restored later.

            The snapshot holds copies of the sequence's tokens, its event and
            info entries, its per-token data, and the key/value cache entries of
            every model in the session, truncated to the current token count.

            Args:
                idx: index of the sequence inside the batched context.
                device: where the copies are stored. ``None`` clones each
                    tensor without a device argument; anything else is passed
                    to ``torch.device`` and the copies are made with
                    ``clone_to``.

            Returns:
                An ``IndividualContext`` with one KV list per session model.
            """
            context = self.context
            session = self.session

            F: Callable[[torch.Tensor], torch.Tensor]
            if device is None:
                F = torch.clone
            else:
                target = torch.device(device)

                def F(x: torch.Tensor) -> torch.Tensor:
                    return clone_to(x, device=target)

            tokens = F(context.tokens_at(*idx))
            T = len(tokens)

            # Slot of this sequence inside the KV cache tensors.
            idx_kv = self._stack.idxmap(idx)
            kvs: list[list[tuple[torch.Tensor, torch.Tensor]]] = []
            for m in session.models():
                # NOTE(review): assumes the third axis of each cache tensor is
                # the token position — confirm against the KV layout.
                kvs_m: list[tuple[torch.Tensor, torch.Tensor]] = [
                    (F(k[idx_kv, :, :T]), F(v[idx_kv, :, :T]))
                    for k, v in iterate_kv_tensors(m)
                ]
                # Append once per model. Appending after the loop kept only the
                # last model's cache and failed when the session had no model.
                kvs.append(kvs_m)

            return IndividualContext(
                tokens,
                events={k: F(v[idx]) for k, v in context.events.items()},
                info={k: F(v[idx]) for k, v in context.info.items()},
                data={k: F(v[(*idx, slice(T))]) for k, v in context.data.items()},
                kvs=tuple(kvs),
            )
        
        def _restore_individual_context(self, i: tuple[int, ...], ctxi: IndividualContext):
            """Write a saved single-sequence snapshot back into slot ``i``.

            Restores, in this order: length and tokens, info entries, per-token
            data, events (creating any event missing from the live context),
            and finally the key/value cache of each session model.
            """
            context = self.context
            session = self.session
            n_tokens = len(ctxi.tokens)
            prefix = (*i, slice(n_tokens))  # the first `n_tokens` positions of slot `i`

            # tokens
            context.lengths[i] = n_tokens
            context.tokens[prefix] = ctxi.tokens

            # info
            for name, value in ctxi.info.items():
                context.info[name][i] = value

            # per-token data
            for name, value in ctxi.data.items():
                context.data[name][prefix] = value

            # events
            for name, value in ctxi.events.items():
                if name not in context.events:
                    context.set_event(name, when=-1)
                context.events[name][i] = value

            # KV cache of every model
            kv_slot = self._stack.idxmap(i)
            for model, saved in zip(session.models(), ctxi.kvs):
                for (k, v), (k_saved, v_saved) in zip(iterate_kv_tensors(model), saved):
                    k[kv_slot, :, :n_tokens] = k_saved
                    v[kv_slot, :, :n_tokens] = v_saved
        
        def _solve_retry(self):
            """Solve in "retry" mode: regenerate a step until reflection accepts it.

            Every iteration rolls the thinking sequences back to their last
            ``"__out__"`` event, generates one step, reflects on it and lets the
            accepted sequences transit while the rejected ones are paused. When
            traceback is enabled, accepted states are pushed onto the stack, and
            a sequence whose attempt count reaches ``max_retry`` is restored
            from an earlier stacked state.

            The steps run in a loop rather than through self-recursion, so the
            number of steps is not limited by the interpreter recursion limit
            and the tensors of finished steps are not kept alive by old frames.

            Returns:
                The ``(thought, outcome)`` pair produced by ``self._extract()``
                once every sequence is terminated or has used ``max_step``
                steps.
            """
            while True:
                sess = self._castsess(self.session)
                ctx = sess.context
                thinking = ~self.terminated
                attempt = ctx.info["__attempt__"]
                # Snapshot of the step counter: the live counter is rebuilt
                # from it at the end of the iteration, after a possible
                # traceback has restored older info entries.
                n_step = ctx.info["__n_step__"].clone()

                if torch.all(self.terminated | (n_step >= self.max_step)):
                    self.inference.submit()
                    thought, outcome = self._extract()
                    return thought, outcome

                self.__set_revision_temperature(attempt > 0)
                sess.traceback("__out__", cond=thinking)
                ctx.set_event("__gen__")
                self._generate()

                # make reflection on valid steps.
                # If the context overflows, we enforce it to be an error.
                stopped = self.inference.stopped
                overflow = ~stopped
                reflmask = thinking if budget is None else thinking & (n_step < budget)
                reject = self._reflect(reflmask)
                self._after_reflection(reflmask, reject)
                reject = reject | overflow

                attempt[thinking] += 1

                # push the pending cases into the stack
                if self.__require_traceback:
                    pending = (thinking & ~reject)
                    stack = self._stack
                    for i in nonzero_indices(pending):
                        stack[i].append(self._get_individual_context(i, device="cpu"))

                with self._pause(reject):
                    self._transit()

                attempt[~reject] = 0

                ctx.set_event("__out__", thinking & ~reject)

                # traceback on failures, if any
                if self.__require_traceback:
                    assert max_retry is not None
                    stack = self._stack
                    attempt = sess.context.info["__attempt__"]
                    failed = thinking & (attempt >= max_retry)
                    for i in nonzero_indices(failed):
                        # find context to traceback: the most recent stacked
                        # state that has not itself used up its retries
                        top = None
                        while (top := stack.pop(i)) is not None:
                            if top.info["__attempt__"] < max_retry:
                                break
                        if top is None:  # empty stack: nowhere to traceback, simply retry
                            continue
                        self._restore_individual_context(i, top)
                    sess.relocate_pos_write(failed)

                # count valid steps
                ctx.info["__n_step__"] = n_step + torch.where(thinking, 1, 0)

        def _solve_revise(self):
            """Solve in "revise" mode: rejected or failed steps are followed by a REVISE token.

            One call performs one generate / reflect / transit round and then
            calls itself again; the recursion ends when every sequence is
            terminated or has used ``max_step`` steps.

            Returns:
                The ``(thought, outcome)`` pair produced by ``self._extract()``.

            NOTE(review): the recursion depth equals the number of rounds, and
            ``n_step`` only advances for sequences that did not fail, so a
            sequence that keeps failing is bounded only by the interpreter
            recursion limit — confirm this is intended.
            NOTE(review): here the attempt counter is incremented before
            ``_reflect`` and the budget is compared against the already
            incremented ``n_step``, whereas ``_solve_retry`` does both the
            other way round — confirm the difference is intended.
            """
            REVISE = self._refl_symbol_seqs[Symbol.revise]
            FAILURE = self._refl_symbol_seqs[Symbol.failure]

            sess = self._castsess(self.session)
            ctx = sess.context
            thinking = ~self.terminated
            attempt = ctx.info["__attempt__"]
            # Snapshot of the step counter; written back at the end of the round.
            n_step = ctx.info["__n_step__"].clone()

            if torch.all(self.terminated | (n_step >= self.max_step)):
                self.inference.submit()
                thought, outcome = self._extract()
                return thought, outcome

            self.__set_revision_temperature(attempt > 0)
            ctx.set_event("__gen__")
            self._generate()

            # detect failure: the FAILURE token appears in the newly generated part
            failed = ctx.contains(FAILURE, since=ctx.when("__gen__", default=0))
            # if the context overflows, we enforce the case to fail
            _stopped = self.inference.stopped
            assert torch.all(_stopped | (ctx.lengths == ctx.max_length))
            failed = failed | (~_stopped)
            pending = thinking & (~failed)

            # count valid steps
            n_step = n_step + torch.where(pending, 1, 0)
            attempt[pending] += 1
            # Drop the references; both are fetched again from the session
            # after reflection.
            del ctx, attempt

            # reflect over valid steps
            reflmask = pending if budget is None else pending & (n_step < budget)
            reject = self._reflect(reflmask)
            self._after_reflection(reflmask, reject)

            # update context info after reflection.
            pending = pending & ~reject
            ctx = sess.context
            attempt = ctx.info["__attempt__"]

            # push the pending cases into the stack
            assert self.__require_traceback
            stack = self._stack
            for i in nonzero_indices(pending):
                stack[i].append(self._get_individual_context(i, device="cpu"))

            # Rejected and failed sequences are paused during the transition.
            to_revise = reject | failed
            with self._pause(to_revise):
                self._transit()

            attempt[pending] = 0
            ctx.set_event("__out__", pending)

            # traceback on failures, if any
            assert self.__require_traceback
            stack = self._stack
            for i in nonzero_indices(failed):
                # find context to traceback
                top = stack.pop(i)
                if top is None:  # empty stack: nowhere to traceback, simply retry
                    continue
                else:
                    self._restore_individual_context(i, top)

            # Written after the traceback, so restored info entries do not
            # replace the step count of this round.
            ctx.info["__n_step__"] = n_step
            ctx.append_(REVISE, to_revise)
            sess.relocate_pos_write(to_revise)
            return self._solve_revise()

        def _on_reflection(self, mask: torch.Tensor, reject: torch.Tensor):
            """Pass ``mask`` and ``reject`` to the ``on_reflection`` callback of every hook."""
            for hook in self._hooks:
                listener = cast(ReflHook, hook)
                listener.on_reflection(mask, reject)
        
        def _after_reflection(self, mask: torch.Tensor, reject: torch.Tensor):
            """Pass ``mask`` and ``reject`` to the ``after_reflection`` callback of every hook."""
            for hook in self._hooks:
                listener = cast(ReflHook, hook)
                listener.after_reflection(mask, reject)

        def _reflect(self, mask: torch.Tensor) -> torch.Tensor:
            """Decide which of the sequences selected by ``mask`` have their step rejected.

            The verdict comes from the configured external verifier (oracle,
            value model or random draw) or, when none is configured, from the
            LLM itself through ``self.reflective_inference``. Every path
            reports the verdict to the hooks via ``_on_reflection``.

            Args:
                mask: boolean tensor selecting the sequences to reflect on.

            Returns:
                Boolean tensor flagging the rejected sequences.
            """
            sess = self._castsess(self.session)
            ctx = sess.context

            if external_verifier == "oracle":
                error = _reflect_by_binary_verifier(self.llm.preprocessor, self._oracle, ctx, mask)
                self._on_reflection(mask, error)
                return error
            elif external_verifier == "vf":
                error = _reflect_by_value(self.vf, ctx, mask, self.vf_tolerance, sess.pad_index)
                self._on_reflection(mask, error)
                return error
            elif external_verifier == "random":
                # Threshold is either one fixed ratio or the per-sequence
                # ratios stored in the session cache at launch.
                if self._random_reject_ratio is not None:
                    threshold = self._random_reject_ratio
                else:
                    threshold = cast(torch.Tensor, sess.cache["__random_reject_ratio__"])
                error = mask & (torch.rand_like(mask, dtype=torch.float32) <= threshold)
                self._on_reflection(mask, error)
                return error
            else:
                assert external_verifier is None

            # use LLM to reflect
            inference = self.reflective_inference
            BEGIN = self._refl_symbol_seqs[Symbol.begin]
            REJECT = self._refl_symbol_seqs[Symbol.reject]
            ACCEPT = self._refl_symbol_seqs[Symbol.accept]
            assert inference is not None
            if reject_mode == "retry":
                # Reflect directly in the live context.
                ctx.set_event("__reflection__", mask=mask)
                ctx.append_(BEGIN, cond=mask)
                # Only the selected sequences take part in the inference.
                inference.stopped = ~mask
                inference.infer_sequence(self._eos + [REJECT, ACCEPT])
                # Rejected: the REJECT token occurs after the reflection event.
                error = ctx.contains(
                    REJECT,
                    since=ctx.when("__reflection__", default=ctx.lengths)
                )
                self._on_reflection(mask, error)

                # Position handed to `_finish_context`: the reflection event
                # for rejected sequences, the generation event otherwise
                # (always "__gen__" when `_error_as_context_prompt` is off).
                context_output = torch.where(error, ctx.when("__reflection__"), ctx.when("__gen__")) \
                    if _error_as_context_prompt else "__gen__"
                self._finish_context(context_output, reflective=True)
                # Roll the selected sequences back to the reflection event and
                # forget the event.
                sess.traceback("__reflection__", cond=mask)
                ctx.remove_event("__reflection__")
            else:
                # Reflect inside a session fork: the generated segment is
                # re-appended after the last "__out__" event and judged there.
                _proposal = sess.context.eseg("__gen__", None)
                with sess.tryfork(tokens=True,
                                  event=["__out__"],
                                  info=["__attempt__"],
                                  cache=True,
                                  state_device='cpu', restore_model='kv',
                                  update_cache=True) as _state:
                    _ctx_refl = sess.context
                    sess.traceback("__out__", cond=mask)
                    _ctx_refl.append_from_(_proposal, cond=mask)
                    _ctx_refl.set_event("__reflection__", mask=mask)
                    _ctx_refl.append_(BEGIN, cond=mask)
                    inference.stopped = ~mask
                    inference.infer_sequence(self._eos + [REJECT, ACCEPT])
                    error = _ctx_refl.contains(
                        REJECT,
                        since=_ctx_refl.when("__reflection__", default=_ctx_refl.lengths)
                    )
                    # Hooks are notified inside the fork.
                    self._on_reflection(mask, error)
                    # self._finish_context("__reflection__", reflective=True)

            return error

        def __call__(self, input: TokenBuffer) -> tuple[Thought, TokenBuffer]:
            """Reason over ``input`` and return ``(thought, outcome)``.

            A non-positive ``budget`` disables reflection and delegates to the
            base ``_solve``; otherwise the solver matching ``reject_mode`` is
            used.

            Raises:
                ValueError: if ``reject_mode`` is neither "retry" nor "revise".
                    Previously this surfaced as an ``UnboundLocalError`` on
                    ``thought``.
            """
            self.before_reasoning(input)
            if budget is not None and budget <= 0:
                thought, outcome = self._solve()
            elif reject_mode == "retry":
                thought, outcome = self._solve_retry()
            elif reject_mode == "revise":
                thought, outcome = self._solve_revise()
            else:
                raise ValueError(f"Unsupported reject_mode: {reject_mode!r}")
            self.after_reasoning(thought, outcome)
            return thought, outcome

    return ReflectionReasoner


# ===========================
# Hooks
# ===========================


class _ReflStatistics(Callbacks[ReflCache], GenHook, ReflHook):
    """Hook that accumulates per-sequence reflection counters in its cache.

    With an oracle, the reflection verdicts are additionally compared against
    the oracle's judgement. ``_ith_attempt_only`` restricts counting to the
    sequences whose ``"__attempt__"`` info equals that value.
    """

    def __init__(self, oracle: Callable[[str, str], bool] | None = None,
                 _ith_attempt_only: int | None = None) -> None:
        """Store the optional oracle and the optional attempt filter."""
        super().__init__()
        self._ith_attempt_only = _ith_attempt_only
        self._oracle = oracle

    def before_reasoning(self, input: TokenBuffer):
        """Create the zero-filled int64 counters for a new reasoning run."""
        context = self.host.context
        cache = self.cache
        names = ["refl_steps", "refl_tokens", "refl_freq", "refl_freq_rej"]
        if self._oracle is not None:
            # Oracle-dependent counters only exist when an oracle is given.
            names += ["refl_freq_fp", "refl_freq_fn", "pi_freq_neg"]
        for name in names:
            cache[name] = context.make_tensor((), torch.int64, 0)

    def on_reflection(self, mask: torch.Tensor, reject: torch.Tensor):
        """Count reflected steps, reflection tokens, valid reflections and rejections."""
        host = self.host
        ctx = cast(Context[ReflEvent, ReflInfo, Any], host.context)
        counters = refl_stat_tensors(self.cache)

        # Tokens after the reflection event; zero where no event is set.
        refl_start = ctx.when("__reflection__", default=ctx.lengths)
        refl_len = (ctx.lengths - refl_start).clamp_min(0)
        is_valid = reject | (refl_len > 2)

        only = self._ith_attempt_only
        if only is not None:
            skipped = (ctx.info["__attempt__"] != only)
            mask = mask.masked_fill(skipped, False)
            refl_len = refl_len.masked_fill(skipped, 0)
            is_valid = is_valid.masked_fill(skipped, False)
            reject = reject.masked_fill(skipped, False)

        counters["refl_steps"] += mask.type(torch.int64)
        counters["refl_tokens"] += refl_len
        counters["refl_freq"] += is_valid.type(torch.int64)
        counters["refl_freq_rej"] += reject.type(torch.int64)

    def after_reflection(self, mask: torch.Tensor, reject: torch.Tensor):
        """Compare the verdicts with the oracle; no-op without an oracle."""
        oracle = self._oracle
        if oracle is None:
            return

        host = self.host
        ctx = host.context
        counters = refl_stat_tensors(self.cache)
        error = _reflect_by_binary_verifier(host.llm.preprocessor, oracle, ctx, mask)

        tallies = {
            # oracle reports an error, step not rejected
            "refl_freq_fp": (error & (~reject)).type(torch.int64),
            # oracle reports no error, step rejected
            "refl_freq_fn": ((~error) & reject).type(torch.int64),
            # oracle reports an error
            "pi_freq_neg": error.type(torch.int64),
        }

        only = self._ith_attempt_only
        if only is not None:
            skipped = (ctx.info["__attempt__"] != only)
            tallies = {name: count.masked_fill(skipped, 0) for name, count in tallies.items()}

        for name, count in tallies.items():
            counters[name] += count


# ------------------------------- #
# ------ Utility Functions ------ #
# ------------------------------- #

def _reflect_by_binary_verifier[KEvent, KInfo](
    preprocessor: Preprocessor,
    binary_verifier: BinaryVerifier,
    ctx: Context[ReflEvent | KEvent, ReflInfo | KInfo, Any],
    mask: torch.Tensor
) -> torch.Tensor:
    out = ctx.make_flag(value = False)
    prompt_lengths = ctx.when("__out__", default=0)
    gen_start = ctx.when("__gen__", default=ctx.lengths)
    decode = functools.partial(preprocessor.decode,
                               skip_special_tokens=False)
    for idx in nonzero_indices(mask):
        prompt_length = int(prompt_lengths[idx])
        output_start = int(gen_start[idx])
        prompt = decode(ctx.tokens_at(*idx, stop=prompt_length))
        output = decode(ctx.tokens_at(*idx, start=output_start))
        error = not binary_verifier(prompt, output)
        out[idx] = error
    return out


def _reflect_by_value[KEvent, KInfo](
    vf: ValueModel, 
    ctx: Context[ReflEvent | KEvent, ReflInfo | KInfo, Any], 
    mask: torch.Tensor,
    tolerance: float,
    _pad_index: int = 0,
):
    _shape = ctx.shape
    prompts = ctx.eseg(None, "__out__")
    outputs = ctx.eseg("__gen__", None)
    buf = TokenBuffer.concat(prompts, outputs, pad_index=_pad_index)
    T = buf.max_length
    v = vf.forward(buf.tokens.reshape(-1, T)).view(_shape + (T,))
    assert torch.all((~mask) | (prompts.lengths > 0))
    t_output = torch.where(mask, buf.lengths - 1, T - 1)
    t_prompt = torch.where(mask, prompts.lengths - 1, T - 1)
    v_prompt = fetch(v, t_prompt)
    v_output = fetch(v, t_output)
    adv = v_output - v_prompt
    reject = adv < tolerance
    return reject

def refl_stat_tensors(cache: NestedTensorDict):
    """Return ``cache`` unchanged, typed as a mapping from statistic names to tensors."""
    typed_cache = cast(NestedTensorDict[ReflStat, torch.Tensor], cache)
    return typed_cache
