from typing import Dict, Tuple, Any
from cachetools import LRUCache
from collections import defaultdict
from genlm.bytes import ByteBeamState, BeamParams
from genlm.backend import load_model_by_name

import numpy as np

# Cache key for beams: a tuple whose elements are consumed one at a time
# (each passed through ``int()``) when extending a beam.
NestedCtx = Tuple[Any, ...]
# Log-probability used for unavailable / impossible next bytes.
NEG_INF: float = -float("inf")
# Slot for end-of-text in dense next-byte arrays; bytes 0..255 occupy 0..255.
EOT_IDX: int = 256

class GenLMRealpha:
    """Next-byte log-probabilities from a genlm model, with per-context beam caching.

    A ``ByteBeamState`` is kept for every context tuple seen so far, in a
    bounded LRU cache. Extending a context by one element therefore reuses the
    parent's beam instead of replaying the whole prefix.
    """

    @classmethod
    async def create(
        cls,
        model_name: str,
        llm=None,
        backend="hf",
        K: int = 5,
        prune_threshold: float = 0.01,
        verbose: bool = False,
        cache_size: int = 100000,
    ) -> "GenLMRealpha":
        """Load a model if needed and build an instance around its initial beam.

        Args:
            model_name: Model loaded via ``load_model_by_name`` when ``llm`` is None.
            llm: Pre-loaded model. If given, ``model_name`` and ``backend`` are unused.
            backend: Backend name forwarded to ``load_model_by_name``.
            K: Forwarded to ``BeamParams`` (presumably the beam width).
            prune_threshold: Forwarded to ``BeamParams``.
            verbose: Print a warning whenever an error is swallowed.
            cache_size: Maximum number of context beams kept in the LRU cache.
        """
        if llm is None:
            llm = load_model_by_name(model_name, backend=backend)
        root_beam = await ByteBeamState.initial(
            llm, BeamParams(K=K, prune_threshold=prune_threshold)
        )
        return cls(llm, root_beam, K, prune_threshold, verbose, cache_size=cache_size)

    def __init__(
        self,
        llm: Any,
        root_beam: ByteBeamState,
        K: int,
        prune_threshold: float,
        verbose: bool = False,
        cache_size: int = 100000,
    ):
        self.llm = llm
        self.K = K
        self.prune_threshold = prune_threshold
        self.root_beam = root_beam
        # Context tuple -> beam after consuming that context. Bounded so that
        # memory does not grow without limit over a long session.
        self._beams: LRUCache = LRUCache(maxsize=cache_size)
        self._beams[()] = root_beam
        self._ctx: NestedCtx = ()
        self.verbose = verbose

    def empty_cache(self):
        """Drop every cached beam except the root beam."""
        self._beams.clear()
        # Re-seed the root: every context is built by extending ``()``.
        self._beams[()] = self.root_beam

    async def logp_next_for_dict(self, ctx) -> dict:
        """Return next-byte log-probabilities for ``ctx`` as a ``defaultdict``.

        Keys are whatever the materialized chart's ``to_dict()`` yields;
        ``_dense_logps`` treats them as byte ints, with ``None`` for
        end-of-text. Missing keys read as ``NEG_INF``. If genlm raises
        ``AssertionError`` or ``ValueError``, the result is an empty mapping,
        so every lookup returns ``NEG_INF``.
        """
        try:
            beam = await self._beam_for(ctx)
            logp_next = await beam.logp_next()  # LazyByteProbs
            base = logp_next.materialize().to_dict()  # Chart -> Dict
        except (AssertionError, ValueError) as e:
            if self.verbose:
                print("WARNING: Caught genlm", e, "…")
            base = {}

        return defaultdict(lambda: NEG_INF, base)

    async def logp_next_for(self, ctx, *, dtype=np.float32, include_eot=True, copy=False) -> np.ndarray:
        """Return a dense array of next-byte log-probabilities for ``ctx``.

        Index ``b`` (0..255) holds byte ``b``, and index ``EOT_IDX`` holds
        end-of-text. Absent entries are ``NEG_INF``.

        Args:
            ctx: Context tuple (see ``_beam_for``).
            dtype: Output dtype.
            include_eot: If False, drop the EOT slot and return ``EOT_IDX`` entries.
            copy: If True, the result never shares memory with the beam's
                internal probability buffer.

        Returns:
            An array of length ``EOT_IDX + 1``, or ``EOT_IDX`` when
            ``include_eot`` is False. It is all ``NEG_INF`` when a CUDA
            illegal-memory-access ``RuntimeError`` or a genlm
            ``AssertionError``/``ValueError`` was swallowed.

        Raises:
            RuntimeError: Non-CUDA runtime errors from ``beam.logp_next()``,
                and errors from ``_beam_for``, propagate to the caller.
        """
        try:
            beam = await self._beam_for(ctx)
            try:
                lbp = await beam.logp_next()
            except RuntimeError as e:
                msg = str(e)
                # PyTorch reports illegal accesses as RuntimeError with
                # "CUDA error: an illegal memory access...". Only those are
                # treated as best-effort failures; anything else is a real bug.
                if "CUDA error" not in msg and "illegal memory access" not in msg.lower():
                    raise
                if self.verbose:
                    print("WARNING: CUDA illegal memory access in logp_next:", e)
                arr = np.full(EOT_IDX + 1, NEG_INF, dtype=dtype)
            else:
                arr = self._dense_logps(lbp, dtype, copy)
        except (AssertionError, ValueError) as e:
            if self.verbose:
                print("WARNING: Caught genlm", e, "…")
            arr = np.full(EOT_IDX + 1, NEG_INF, dtype=dtype)
        return arr if include_eot else arr[:EOT_IDX]

    @staticmethod
    def _dense_logps(lbp, dtype, copy: bool) -> np.ndarray:
        """Convert a ``logp_next()`` result into a dense array of length ``EOT_IDX + 1``.

        Raises ``ValueError`` if ``lbp.ps`` exists but has an unexpected shape.
        """
        ps = getattr(lbp, "ps", None)
        if ps is not None:
            # Fast path: reuse the dense vector. Only copy on request, because
            # asarray may hand back the beam's own buffer.
            arr = np.array(ps, dtype=dtype, copy=True) if copy else np.asarray(ps, dtype=dtype)
            if arr.shape != (EOT_IDX + 1,):
                raise ValueError("Unexpected LazyByteProbs.ps shape")
            return arr
        arr = np.full(EOT_IDX + 1, NEG_INF, dtype=dtype)
        for k, v in lbp.materialize().items():
            # A ``None`` key is end-of-text; every other key is a byte index.
            arr[EOT_IDX if k is None else k] = v
        return arr

    async def _beam_for(self, ctx: NestedCtx) -> ByteBeamState:
        """Return the beam for ``ctx``, building and caching missing prefixes.

        The method walks back to the longest cached prefix of ``ctx``. If none
        is cached, it starts from ``self.root_beam`` and restores it under
        ``()``, since the root may have been evicted or cleared. It then extends
        one element at a time via ``beam.prune() << int(element)`` and caches
        every intermediate beam. The loop is iterative, so long uncached
        contexts cannot exceed the recursion limit.
        """
        cached = len(ctx)
        while cached > 0 and ctx[:cached] not in self._beams:
            cached -= 1

        if cached > 0:
            beam = self._beams[ctx[:cached]]
        else:
            beam = self.root_beam
            self._beams[()] = beam

        for i in range(cached, len(ctx)):
            beam = await (beam.prune() << int(ctx[i]))
            self._beams[ctx[: i + 1]] = beam
        return beam

    @staticmethod
    def _materialize(beam: ByteBeamState) -> Dict[str, float]:
        """Return next-byte log-probabilities for ``beam``, keyed by readable strings.

        Each byte is decoded as UTF-8. Bytes that are not valid UTF-8 on
        their own (0x80-0xFF) become backslash escapes such as ``"\\x80"``
        instead of raising ``UnicodeDecodeError``. End-of-text (``None``)
        becomes ``"EOT"``.
        """

        def label(x):
            if x is None:
                return "EOT"
            return bytes([x]).decode("utf-8", errors="backslashreplace")

        return beam.logp_next_sync().materialize().map_keys(label).to_dict()

    async def cleanup(self):
        """Release resources held by the root beam."""
        await self.root_beam.cleanup()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.cleanup()
