import torch
import abc
from dataclasses import dataclass
from core.utils.buf import TokenBuffer, Seqs, SequencesLike
from core.utils.th import fetch
from core.model import LLM
from .base import Inference, Context
from typing import Any, Literal


def _multinomial(probs: torch.Tensor, n: int = 1, replace=False) -> torch.Tensor:
    if torch._dynamo.is_compiling() and n == 1:
        # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
        distribution = torch.empty_like(probs).exponential_(1)
        return torch.argmax(probs / distribution, dim=-1, keepdim=True)
    
    if probs.ndim <= 2:
        return torch.multinomial(probs, num_samples=n, replacement=replace)
    else:
        shape = probs.shape[:-1]
        probs = probs.view(-1, probs.size(-1))
        out = torch.multinomial(probs, num_samples=n, replacement=replace)
        return out.view(*shape, n)


def batched_greedy(logits: torch.Tensor, n: int | None = None, _clone=True):
    if n is None:
        return torch.argmax(logits, dim=-1)
    else:
        out = torch.argmax(logits, dim=-1, keepdim=True).expand(*logits.shape[:-1], n)
        if _clone:
            return out.clone()
        else:
            return out


def batched_sample(
    logits: torch.Tensor,
    temperature: float | torch.Tensor = 1.0,
    top_k: int | None = None,
    top_p: float = 1.0,
    n: int | None = None,
    replace: bool = False,
) -> torch.Tensor:
    """Sample indices from `logits` with temperature, top-k and top-p filtering.

    Args:
        logits: Unnormalised scores; the last dimension indexes the candidates.
        temperature: Scalar, or a tensor broadcastable to `logits.shape[:-1]`
            giving one temperature per row. Values `<= 0` select the argmax
            (greedy) for the corresponding rows.
        top_k: If given, only the `top_k` highest logits of each row are kept.
        top_p: Nucleus threshold in `[0, 1]`. `0` means greedy for every row,
            `1` disables nucleus filtering.
        n: Number of indices to draw per row. If `None`, one index is drawn and
            the trailing sample dimension is squeezed away.
        replace: Whether to draw with replacement when `n` is given.

    Returns:
        Integer tensor of shape `logits.shape[:-1]` when `n` is `None`,
        otherwise `(*logits.shape[:-1], n)`.

    Raises:
        ValueError: If `top_p` is outside `[0, 1]`.
    """
    if top_p < 0.0 or top_p > 1.0:
        raise ValueError(f"top_p must be in [0, 1], got {top_p}")

    if top_p <= 0:
        return batched_greedy(logits, n)

    # optionally crop the logits to only the top k options
    if top_k is not None:
        v, i = torch.topk(logits, min(top_k, logits.size(-1)))  # (b, k)
        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
        logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)

    # Rows that must be decoded greedily (tensor temperature only), and their tokens.
    greedy = None
    greedy_tokens = None

    # scale logits with temperature.
    if isinstance(temperature, torch.Tensor):
        if temperature.shape != logits.shape[:-1]:
            temperature = temperature.expand(logits.shape[:-1])
        greedy_mask = temperature <= 0
        if bool(torch.any(greedy_mask)):
            greedy = greedy_mask
            greedy_tokens = batched_greedy(logits[greedy], n, _clone=False)
            # Greedy rows are still pushed through the sampler and overwritten
            # afterwards, so they only need *some* valid temperature. Use the
            # neutral value 1: an infinite temperature would turn every -inf
            # logit (e.g. the ones produced by top-k above) into NaN
            # (-inf / inf) and make the softmax / multinomial below fail.
            temperature = temperature.masked_fill(greedy, 1.0)
        logits = logits / temperature.unsqueeze(-1)
    else:
        if temperature <= 0:
            return batched_greedy(logits, n)
        elif temperature != 1:
            logits = logits / temperature

    if top_p < 1:
        logits = _logits_top_p(logits, top_p)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    if n is None:
        tokens = _multinomial(probs).squeeze(-1)
    else:
        tokens = _multinomial(probs, n, replace)

    if greedy is not None:
        # Replace the throw-away samples of the greedy rows with their argmax.
        tokens[greedy] = greedy_tokens

    return tokens


def _logits_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Example:
    # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
    # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Keep at least 1 token always to prevent the case where no token is selected
    # In this case the most probable one is always kept
    sorted_indices_to_remove[..., -1:] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, float("-inf"))
    return logits


class Sampling(Inference[Any, Any]):
    """Inference strategy that samples the next token with `batched_sample`.

    Sampling is controlled by `temperature`, `top_k` and `top_p`. The
    temperature is either a scalar kept on the instance, or a tensor kept in
    the context's `info` mapping under a key unique to this instance.
    """

    def __init__(
        self, llm: LLM, /,
        temperature: float = 1.0,
        top_k: int | None = None,
        top_p: float = 1.0,
    ):
        """Store the sampling hyper-parameters.

        Args:
            llm: Model forwarded to the base class constructor.
            temperature: Default scalar temperature.
            top_k: Optional top-k cutoff passed to `batched_sample`.
            top_p: Nucleus threshold passed to `batched_sample`.
        """
        super().__init__(llm)

        self.top_k = top_k
        self.top_p = top_p
        # Key under which a tensor temperature is stored in `context.info`;
        # includes `id(self)` so that each instance uses its own entry.
        self.__key_temperature = "Sampling@%d.temperature" % id(self)
        self.__temperature = temperature

    @property
    def temperature(self):
        """Current temperature: the context override if present, else the scalar default."""
        fallback = self.__temperature
        try:
            return self.context.info.get(self.__key_temperature, fallback)
        except AttributeError:  # context not initialized
            return fallback

    @temperature.setter
    def temperature(self, value: float | torch.Tensor):
        """Set the temperature.

        A `float`/`int` becomes the new scalar default and removes any tensor
        override stored in the context. Any other value is treated as a tensor:
        it is expanded to `context.shape` if needed and stored in the context.
        """
        if not isinstance(value, (float, int)):
            target_shape = self.context.shape
            if not value.shape == target_shape:
                # `clone()` materialises the expanded view into its own storage.
                value = value.expand(target_shape).clone()
            self.context.info[self.__key_temperature] = value
            return

        # Scalar: drop the override if there is one. A missing context
        # (AttributeError) or a missing key (KeyError) both mean nothing to drop.
        try:
            self.context.info.pop(self.__key_temperature)
        except (AttributeError, KeyError):
            pass
        self.__temperature = value

    def _sample_next_tokens(self, logits: torch.Tensor, active: torch.Tensor | None = None):
        """Sample one token per row from the last position of `logits`.

        Args:
            logits: Model outputs; the second-to-last dimension is indexed at
                `-1` to pick the position to sample from.
            active: Optional boolean mask. When given, a tensor temperature is
                restricted to the masked entries, and the sampled tokens are
                written into the masked slots of a tensor shaped like `active`
                that is otherwise filled with `session.pad_index`.

        Returns:
            `int32` tensor of tokens.
        """
        temp = self.temperature
        masked = active is not None
        if masked and isinstance(temp, torch.Tensor):
            temp = temp[active]

        last_logits = logits[..., -1, :]
        sampled = batched_sample(last_logits, temp, self.top_k, self.top_p)
        sampled = sampled.to(dtype=torch.int32)
        if not masked:
            return sampled

        padded = torch.full_like(active, self.session.pad_index, dtype=torch.int32)
        padded[active] = sampled
        return padded

    def launch(self, input: TokenBuffer):
        """Run the base `launch`, then load `input` into the session.

        Copies `input` into an empty buffer obtained from the session, hands
        that buffer to `session.prepare`, and calls `llm.kv_reset` with the
        buffer's `nelem` and `max_length`.
        """
        super().launch(input)

        buffer = self.session.empty(input.shape, device=self._llm.device)
        buffer.append_from_(input)
        self.session.prepare(buffer)
        self.llm.kv_reset(buffer.nelem, buffer.max_length)

    def infer_token(self) -> torch.Tensor:
        """Predict logits for the non-stopped entries and sample the next tokens.

        If `self.logprob` is set, it is updated by adding the log-probability
        of the sampled tokens wherever the session's write mask is true.

        Returns:
            The sampled next tokens.
        """
        # Snapshot the stop mask before predicting.
        stopped = self.stopped
        active = ~stopped
        logits = self._predict_logits(active)
        next_tokens = self._sample_next_tokens(logits, active)

        logp = self.logprob
        if logp is None:
            return next_tokens

        log_probs = torch.log_softmax(logits[..., -1, :], -1)
        delta = fetch(log_probs, next_tokens)
        write_mask = self.session._write_mask(stopped)
        logp = logp.type_as(logits)
        self.logprob = torch.where(write_mask, logp + delta, logp)
        return next_tokens
