# Deprecated! This module is neither debugged nor maintained.

import torch
import dataclasses as dc
import itertools
from typing import NamedTuple, Callable
from pathlib import Path

from core.inference.base import Session
from core.utils.buf import TokenBuffer, Seqs
from core.utils import kv, th
from core.model import ValueModel, LLM
from .base import Inference, Context
from .sampling import Sampling
from math import inf as INF


class Assessment(NamedTuple):

    scores: torch.Tensor  # (*shape)
    choices: torch.Tensor | None = None  # (*shape)


class BeamSearch(Inference):
    """Token-level beam search keeping ``k`` hypotheses per input element.

    `launch` appends a beam dimension of size ``k`` to the input context.
    Each `infer_token` call scores the candidate next tokens of every beam
    (`assess`), keeps the ``k`` best (prefix, token) pairs per input element
    and reorders the session buffers so that beam ``i`` holds the prefix of
    the ``i``-th best pair.

    NOTE: the file header marks this module as deprecated and not debugged.
    """

    # Boolean flag per beam, shape (*input_shape, k). Only beam 0 is flagged
    # right after `launch`, because all k beams start as copies of one prefix.
    _activated: torch.Tensor
    # Candidate cumulative log-probabilities cached by `assess` and consumed
    # (then deleted) by `infer_token`.
    _logprob_choices: torch.Tensor

    def __init__(self, llm: LLM, /, k: int, n_choices: int | None = None, submit_best: bool = True):
        """
        Args:
            llm: language model, forwarded to the base class.
            k: beam width (number of hypotheses kept per input element).
            n_choices: number of highest-probability tokens considered per
                beam; `None` considers the whole vocabulary.
            submit_best: when True, `submit` keeps only beam 0 of every
                input element.
        """
        super().__init__(llm)
        self.k = k
        self.n_choices = n_choices
        self.submit_best = submit_best

    def release(self):
        """Release base-class resources and drop the cached per-beam tensors.

        Tolerates attributes that were never set or were already removed.
        """
        super().release()
        # Delete each attribute on its own: a single `del a, b` stops at the
        # first missing name and would leave the second tensor alive.
        for name in ('_activated', '_logprob_choices'):
            try:
                delattr(self, name)
            except AttributeError:
                pass

    def assess(self):
        """Score the candidate next tokens of every beam.

        Returns:
            `Assessment(scores, choices)`. `scores` holds the cumulative
            log-probability of each (beam, candidate) pair and has a trailing
            candidate dimension; `choices` holds the candidate indices from
            `torch.topk`, or None when `n_choices` is None (the candidate
            index is then the position in the last logits dimension).

        Side effects:
            Pops ``'logprob'`` from the context info and caches the returned
            scores in `self._logprob_choices` for `infer_token`.
        """
        session = self.session
        ctx = session.context
        active = ~self.stopped
        logits = self._predict_logits(None)  # TODO: use active mask to reduce computational cost
        # Cumulative log-probability of each beam so far; `infer_token`
        # assigns the re-selected values back to `self.logprob`.
        logprobs = ctx.info.pop('logprob')
        logprobs = logprobs.unsqueeze(-1)  # add the candidate dimension

        write_mask = session._write_mask(active).unsqueeze(-1)
        # Log-probabilities predicted at the last position of the logits.
        logprobs_write = torch.log_softmax(logits[..., -1, :], -1)

        if self.n_choices is None:
            choices = None
        else:
            logprobs_write, choices = torch.topk(logprobs_write, k=self.n_choices)

        # Beams selected by the write mask accumulate the candidate's
        # log-probability; all other beams keep their current value.
        logprobs_choices = torch.where(write_mask, logprobs + logprobs_write, logprobs)
        self._logprob_choices = logprobs_choices

        return Assessment(logprobs_choices, choices)

    def launch(self, input: TokenBuffer):
        """Start a search: replicate `input` over `k` beams and reset state.

        Args:
            input: prompt buffer; its shape becomes the leading dimensions of
                the beam context, which has shape (*input.shape, k).
        """
        super().launch(input)

        ctx = self.session.empty(input.shape, device=self._llm.device)
        ctx.append_from_(input)
        # Append the beam dimension; `clone` materializes the expanded view.
        ctx = ctx.reshape(*input.shape, 1).expand(*input.shape, self.k).clone()

        self.session.prepare(ctx)
        self.llm.kv_reset(ctx.nelem, ctx.max_length)

        self.logprob = 0
        self._activated = ctx.make_flag()
        # All beams are identical copies at this point, so only the first one
        # is flagged; the others are masked out of the first selection.
        self._activated[..., 0] = True

    def submit(self):
        """Reduce the session to beam 0 of every element if `submit_best` is set.

        `torch.topk` returns its results sorted in descending order, so beam 0
        is the highest-scoring hypothesis of the last selection.
        """
        if not self.submit_best:
            return
        ctx = self.session.context
        # Index that keeps all leading dimensions and picks entry 0 of the last.
        _idx = (*(slice(None) for _ in range(ctx.ndim - 1)), 0)
        self.session.apply_transform(lambda x: x[_idx].clone())

    def infer_token(self) -> torch.Tensor:
        """Run one beam-search step.

        Scores all candidates, keeps the `k` best (prefix, token) pairs per
        input element, reorders the session buffers in place to the selected
        prefixes and updates `self._activated` and `self.logprob`.

        Returns:
            The selected next token of every beam, shape (*input_shape, k),
            cast to the dtype of `ctx.tokens`.
        """

        session = self.session
        ctx = self.context
        d_ctx = ctx.ndim
        scores, choices = self.assess()
        llm = self._llm

        # NOTE(review): `assess` passes `~self.stopped` to `_write_mask`, while
        # this call passes `self.stopped`. One of the two is presumably
        # inverted; confirm against `Session._write_mask` before relying on it.
        write_mask = session._write_mask(self.stopped)
        next_activated = (self._activated & write_mask)
        next_activated = next_activated.unsqueeze(-1).expand(scores.shape).clone()
        next_activated[..., 0] |= self._activated  # make sure that there is at least one activated successor.

        # Candidates that are not activated can never be selected.
        scores = torch.where(next_activated, scores, -INF)
        n_choices = llm.vocab_size if self.n_choices is None else self.n_choices

        # check shape: (*input_shape, k, n_choices)
        assert scores.shape[-2:] == (self.k, n_choices) and scores.shape[:-1] == ctx.shape
        assert choices is None or choices.shape == scores.shape

        scores = scores.flatten(start_dim=-2)  # (*input_shape, k * n_choices)
        _, topk_indices = torch.topk(scores, self.k, -1)
        # A flat index encodes (beam, candidate) as beam * n_choices + candidate.
        prefix_indices = topk_indices // n_choices

        def choose_next(x: torch.Tensor):
            # Select the winning (beam, candidate) entries of a tensor shaped
            # (*input_shape, k, n_choices); the result is (*input_shape, k).
            return x.flatten(-2).gather(-1, topk_indices)

        def choose_current(x: torch.Tensor):
            # In-place reorder along the beam dimension:
            # x[*i, k, *j] = x[*i, idx[*i, k], *j]

            idx = prefix_indices.reshape(ctx.shape + (1,) * (x.ndim - d_ctx))  # [*input_shape, k, 1, ..., 1]
            idx = idx.expand(x.shape).to(device=x.device)
            x.copy_(x.gather(d_ctx - 1, idx))

        def debug_texts(*idx: int):
            # Debugging aid (not called): print the decoded text of every
            # activated beam of the element at `idx`.
            for k in range(self.k):
                idx_k = k if len(idx) == 0 else idx + (k,)
                if self._activated[idx_k]:
                    tokens = ctx.tokens_at(*idx, k)
                    text = llm.preprocessor.decode(tokens)
                    print(f"#{k}: \"{text}\"")

        # apply operations to the current buffers
        # (`assess` has popped 'logprob' from the context info beforehand.)
        assert self.logprob is None
        session.apply_operation(choose_current)

        # choose the next tokens
        if choices is None:
            next_tokens = topk_indices % n_choices
        else:
            next_tokens = choose_next(choices)

        self._activated = choose_next(next_activated)
        self.logprob = choose_next(self._logprob_choices)
        del self._logprob_choices

        next_tokens = next_tokens.to(dtype=ctx.tokens.dtype)
        return next_tokens


class TokenVFBeamSearch(BeamSearch):
    """Beam search whose candidate scores mix a value model with log-probabilities.

    For every (beam, candidate token) pair the score is
    ``weight_vf * value + weight_pr * logprob`` (see `assess`), where `value`
    comes from evaluating the value model on the candidate token.
    """

    def __init__(self, llm: LLM, /, k: int,
                 n_choices: int | None = None,
                 submit_best: bool = True,
                 vf: ValueModel | Path | str | None = None,
                 vf_batch_size: int = 1,
                 weight_vf: float = 1., weight_pr: float = 0.):
        """
        Args:
            llm: language model, forwarded to `BeamSearch`.
            k: beam width.
            n_choices: candidates per beam; `None` uses the whole vocabulary.
            submit_best: forwarded to `BeamSearch`.
            vf: a `ValueModel` instance, or anything else, which is passed to
                `llm.get_value_model(vf, trained_only=True)` to obtain one.
            vf_batch_size: number of candidates evaluated per value-model call.
            weight_vf: weight of the value-model output in the score.
            weight_pr: weight of the cumulative log-probability in the score.
        """

        super().__init__(llm, k, n_choices, submit_best)

        if not isinstance(vf, ValueModel):
            vf = llm.get_value_model(vf, trained_only=True)

        self._vf = vf
        self._vf_batch_size = vf_batch_size
        self._weight_vf = weight_vf
        self._weight_prob = weight_pr

    def connect(self, session: Session):
        """Attach to `session` and let it manage the value model on the LLM's device."""
        super().connect(session)
        self.vf = session.manage_model(self._vf, device=self._llm.device)

    def release(self):
        """Release inherited resources and drop the managed value-model handle.

        Safe to call when `connect` never ran or after a previous `release`,
        matching the tolerant behaviour of `BeamSearch.release`.
        """
        super().release()
        try:
            del self.vf
        except AttributeError:
            # `vf` only exists between `connect` and the first `release`.
            pass

    def launch(self, input: TokenBuffer):
        """Start a search and reset the value model's key/value cache."""
        super().launch(input)
        ctx = self.context
        # NOTE(review): if `self.context` is the beam-expanded context prepared
        # by `BeamSearch.launch`, `ctx.nelem` presumably already includes the
        # factor `k`, which would make this cache k times larger than needed.
        batch_size = ctx.nelem * self.k * self._vf_batch_size
        # NOTE(review): this method addresses the raw `self._vf`, while
        # `_verify` calls the managed `self.vf`; confirm both refer to the
        # same underlying model.
        kv.reset(self._vf, batch_size, ctx.max_length, device=ctx.device)
        self._vf.eval()

    def assess(self):
        """Combine value-model outputs with log-probabilities into scores.

        Returns:
            `Assessment(scores, choices)` with
            ``scores = weight_vf * v + weight_pr * logprobs``; `choices` is
            passed through unchanged from `BeamSearch.assess`.
        """
        logprobs, choices = super().assess()
        v = self._verify(choices)
        scores = self._weight_vf * v + self._weight_prob * logprobs

        return Assessment(scores, choices)

    def _verify(self, choices: torch.Tensor | None):
        """Evaluate the value model on every candidate next token.

        Args:
            choices: candidate token indices with a trailing candidate
                dimension, or None to evaluate every index in
                ``range(vocab_size)``.

        Returns:
            Tensor of shape ``ctx.shape + (n_choices,)`` (asserted below).
        """
        # Bring the value model's cache up to the current write position.
        session = self.session
        ctx = session.context
        vf_batch_size = self._vf_batch_size
        t1 = self.vf.pos
        t2 = session.pos_write

        if t1 < t2:
            self.vf.__call__(t1, t2, None)  # TODO: use active mask to reduce cost

        # the model has predicted with `(*batch_shape, k)`
        # now we expand it to (vf_batch_size, *batch_shape, k)
        kv.kv_apply(
            self._vf,
            lambda x: x.expand(vf_batch_size, *x.shape),
            in_shape=ctx.shape,
            pos=slice(t2)
        )

        n_choices = self._llm.vocab_size if self.n_choices is None else self.n_choices
        if choices is not None:
            # Move the trailing candidate dimension to the front.
            vf_inputs = torch.permute(choices, (ctx.ndim, *range(ctx.ndim)))
        else:
            # Enumerate every index and broadcast it over the context shape.
            vf_inputs = torch.arange(n_choices, device=ctx.device, dtype=torch.int32)
            vf_inputs = vf_inputs.reshape((n_choices,) + (1,) * ctx.ndim).expand(n_choices, *ctx.shape)
        assert vf_inputs.shape == (n_choices,) + ctx.shape

        # Evaluate the candidates in chunks of `vf_batch_size`.
        # NOTE(review): the last chunk is shorter when `n_choices` is not a
        # multiple of `vf_batch_size`, while the cache above was expanded to
        # exactly `vf_batch_size`; confirm `predict` accepts the smaller batch.
        vs = []
        for i in range(0, n_choices, vf_batch_size):
            choices_batch = vf_inputs[i: i + vf_batch_size]
            v = self.vf.predict(choices_batch.unsqueeze(-1), ctx.pos_idx[t2: t2+1])
            v = v.squeeze(-1)
            vs.append(v)
        v = torch.cat(vs, dim=0)
        assert v.shape == vf_inputs.shape

        # Move the candidate dimension back to the end.
        v = torch.permute(v, (*range(1, ctx.ndim+1), 0))
        assert v.shape == ctx.shape + (n_choices,)

        return v


class SeqVfBeamSearch(Inference):
    """Sequence-level beam search scored by a value model.

    Each `infer_sequence` call samples `n_choices` continuations per beam with
    a `Sampling` proposer, scores every continuation as
    ``weight_vf * value + weight_pr * logprob`` and keeps the `k` best
    continuations per input element.

    NOTE: the file header marks this module as deprecated and not debugged.
    """

    @dc.dataclass
    class Proposal(TokenBuffer):
        """Token buffer of proposed continuations plus per-proposal statistics."""

        # The three fields below are allocated in `__post_init__` through the
        # buffer's own `make_tensor` / `make_flag` helpers; presumably one
        # entry per buffer element (confirm against `TokenBuffer`).
        logprob: torch.Tensor = dc.field(init=False)
        stopped: torch.Tensor = dc.field(init=False)
        score: torch.Tensor = dc.field(init=False)

        def __post_init__(self):
            # Allocate the extra fields, then run the parent's post-init.
            self.logprob = self.make_tensor((), torch.float32)
            self.stopped = self.make_flag()
            self.score = self.make_tensor((), torch.float32)

            return super().__post_init__()

    def __init__(self, llm: LLM, /, k: int, n_choices: int,
                 submit_best: bool = True,
                 vf: ValueModel | Path | str | None = None,
                 weight_vf: float = 1., weight_pr: float = 0.,
                 propose_size: int = 1,
                 temperature: float = 1,
                 top_k: int | None = None,
                 top_p: float = 1):
        """
        Args:
            llm: language model, forwarded to the base class and the proposer.
            k: beam width (continuations kept per input element).
            n_choices: number of continuations proposed per beam.
            submit_best: when True, `submit` keeps only beam 0.
            vf: a `ValueModel` instance, or anything else, which is passed to
                `llm.get_value_model(vf, trained_only=True)` to obtain one.
            weight_vf: weight of the value-model output in the score.
            weight_pr: weight of the log-probability in the score.
            propose_size: number of proposals generated per pass (the
                `batchsize` handed to `_propose`).
            temperature, top_k, top_p: forwarded to the `Sampling` proposer.
        """

        super().__init__(llm)

        self.k = k
        self.n_choices = n_choices
        self.submit_best = submit_best

        if not isinstance(vf, ValueModel):
            vf = llm.get_value_model(vf, trained_only=True)

        self._vf = vf
        self.propose_size = propose_size
        self._weight_vf = weight_vf
        self._weight_prob = weight_pr
        self.proposer = Sampling(llm, temperature, top_k, top_p)

    def connect(self, session: Session):
        """Attach this strategy and its proposer to `session`; register the value model."""
        super().connect(session)
        self.proposer.connect(session)
        self.vf = session.manage_model(self._vf, device=self._llm.device)

    def launch(self, input: TokenBuffer):
        """Start a search: replicate `input` over `k` beams and reset both caches."""
        super().launch(input)

        ctx = self.session.empty(input.shape, device=self._llm.device)
        ctx.append_from_(input)
        # Append the beam dimension; `clone` materializes the expanded view.
        ctx = ctx.reshape(*input.shape, 1).expand(*input.shape, self.k).clone()
        self.session.prepare(ctx)

        # Both caches are sized for `propose_size` proposals per beam.
        self.llm.kv_reset(ctx.nelem * self.propose_size, ctx.max_length)
        self.vf.kv_reset(ctx.nelem * self.propose_size, ctx.max_length)

        self.logprob = 0

    def infer_token(self) -> torch.Tensor:
        """Delegate single-token inference to the sampling proposer."""
        return self.proposer.infer_token()

    def _propose(
        self,
        n: int,
        stop_seqs: Seqs,
        batchsize: int = 1,
        batchdim: int = 0,
    ):
        """
        Propose `n` subsequent sequences for each prefix in the session context.

        Args:
            n: number of proposals per prefix.
            stop_seqs: forwarded to `self.proposer.infer_sequence`.
            batchsize: number of proposals generated per pass.
            batchdim: position at which the proposal dimension is inserted
                into the context shape; a negative value counts from the end
                of the resulting `ndim + 1` dimensions.

        Returns:
            A `Proposal` whose `tokens`, `lengths`, `logprob` and `score` are
            filled for all `n` proposals.

        Raises:
            IndexError: if `batchdim` is outside ``[-(ndim + 1), ndim]``.
        """

        session = self.session
        context = session.context
        ndim = context.ndim
        shape = context.shape
        weight_v = self._weight_vf
        weight_p = self._weight_prob

        if batchdim < 0:
            batchdim = ndim + 1 + batchdim
        if not (0 <= batchdim < ndim + 1):
            raise IndexError("Can not insert an addition dimension %d (dimension out of range)" % batchdim)

        proposal_shape = (*shape[:batchdim], n, *shape[batchdim:])
        # Index prefix selecting everything in front of the proposal dimension.
        slices = (slice(None),) * batchdim

        # NOTE: `expand_fn` reads `proposal_shape` at call time (late binding).
        # It is called right below with the size-`n` shape to allocate the
        # result buffer; `proposal_shape` is then rebound to the size-
        # `batchsize` shape, which is what `session.apply_transform(expand_fn)`
        # further down uses. Reordering these statements changes behaviour.
        expand_fn: Callable[[torch.Tensor], torch.Tensor] = lambda x: \
            x.unsqueeze(batchdim).expand(proposal_shape + x.shape[ndim:]).clone()
        proposal = self.Proposal(expand_fn(context.tokens), expand_fn(context.lengths))
        proposal_shape = (*shape[:batchdim], batchsize, *shape[batchdim:])

        # first: load the common prefix
        pos_write = session.pos_write
        assert self.llm is self.proposer.llm
        for m in (self.vf, self.llm):
            if m.pos < pos_write:
                m.__call__(m.pos, pos_write, None)  # TODO: use active mask to reduce cost

        # Insert a proposal dimension of size `batchsize` into the session.
        session.apply_transform(expand_fn)

        for i in range(0, n, batchsize):
            # The last pass may need fewer than `batchsize` proposals.
            b = min(batchsize, n - i)
            dest = (*slices, slice(i, i+b))  # where this pass lands in `proposal`
            src = (*slices, slice(b))        # the first `b` results of this pass

            # Every pass runs inside its own fork of the session.
            with session.tryfork(info=True):

                self.proposer.infer_sequence(stop_seqs)
                context = session.context

                # evaluate value function
                end = int(context.lengths.max())
                start = min(self.vf.pos, end - 1)
                assert start >= 0
                # Offset of each sequence's last token within [start, end).
                last_indices = context.lengths - start - 1
                v = self.vf.__call__(start, end, None)  # TODO: use active mask to reduce cost
                assert v.shape == (*context.shape, end - start)
                v = th.fetch(v, last_indices)

                # compute score
                logp = self.logprob
                assert logp is not None
                score = v * weight_v + logp * weight_p

                # NOTE(review): `proposal.stopped` is never written here, yet
                # `infer_sequence` reads it; it keeps whatever `make_flag()`
                # returned. Confirm whether that is intended.
                proposal.tokens[dest] = context.tokens[src]
                proposal.lengths[dest] = context.lengths[src]
                proposal.logprob[dest] = logp[src]
                proposal.score[dest] = score[src]

        return proposal

    def infer_sequence(self, stop_seqs):
        """Run one sequence-level beam-search step.

        Proposes `n_choices` continuations per beam, keeps the `k` best-scoring
        ones per input element, reorders the session buffers to the selected
        prefixes and replaces the context's tokens and lengths (and
        `self.stopped` / `self.logprob`) with those of the winners.

        Args:
            stop_seqs: forwarded to `_propose`.
        """
        session = self.session
        ctx = self.context
        d_ctx = ctx.ndim
        k = self.k
        n = self.n_choices

        # The proposal dimension is inserted after the beam dimension
        # (`batchdim = d_ctx`), so `proposal` has shape (*ctx.shape, n).
        with session.tryfork(info=['stopped', 'logprob'], tokens=True):
            proposal = self._propose(n, stop_seqs, self.propose_size, d_ctx)

        def debug_print():
            # Debugging aid (not called): print every proposal with its
            # score and log-probability.
            preprocessor = self._llm.preprocessor
            for i, tokens in proposal.enumerate():
                score = proposal.score[i].item()
                logp = proposal.logprob[i].item()
                print('[' +
                      '-'.join(map(str, i)) +
                      "|score=%.3f, logp=%.3f]" % (score, logp) +
                      f"\"{preprocessor.decode(tokens)}\"")

        d_proposal = proposal.ndim
        assert ctx is self.context

        scores = proposal.score
        assert proposal.shape[-2:] == (k, n) and proposal.shape[:-1] == ctx.shape

        scores = scores.flatten(start_dim=-2)  # (*input_shape, k * n_choices)
        topk_scores, topk_indices = torch.topk(scores, k, -1)
        # A flat index encodes (beam, proposal) as beam * n + proposal.
        prefix_indices = topk_indices // n

        def choose_next(x: torch.Tensor):
            # Pick the winning proposals from a tensor shaped
            # (*input_shape, k, n, *elem_shape); the result is
            # (*input_shape, k, *elem_shape).
            elem_shape = x.shape[d_proposal:]
            x = x.reshape(*proposal.shape[:-2], k * n, *elem_shape)
            idx = topk_indices.reshape(ctx.shape + (1,) * (x.ndim - d_ctx))  # [*input_shape, k, 1, ..., 1]
            idx = idx.expand(ctx.shape + elem_shape).to(device=x.device)
            return x.gather(d_ctx - 1, idx)

        def choose_current(x: torch.Tensor):
            # In-place reorder along the beam dimension:
            # x[*i, k, *j] = x[*i, idx[*i, k], *j]

            idx = prefix_indices.reshape(ctx.shape + (1,) * (x.ndim - d_ctx))  # [*input_shape, k, 1, ..., 1]
            idx = idx.expand(x.shape).to(device=x.device)
            x.copy_(x.gather(d_ctx - 1, idx))

        # apply operations to the current buffers
        # ('logprob' is removed first and re-assigned from the winners below.)
        ctx.info.pop('logprob')
        session.apply_operation(choose_current)

        # choose the next tokens
        ctx.tokens = choose_next(proposal.tokens)
        ctx.lengths = choose_next(proposal.lengths)
        self.stopped = choose_next(proposal.stopped)
        self.logprob = choose_next(proposal.logprob)

        if __debug__:
            self.session._check_shape()

    def submit(self):
        """Reduce the session to beam 0 of every element if `submit_best` is set."""
        if not self.submit_best:
            return
        ctx = self.session.context
        # Index that keeps all leading dimensions and picks entry 0 of the last.
        _idx = (*(slice(None) for _ in range(ctx.ndim - 1)), 0)
        self.session.apply_transform(lambda x: x[_idx].clone())

    def release(self):
        """Release base-class resources, then the proposer's."""
        super().release()
        self.proposer.release()
