import subprocess
from pathlib import Path
from collections import defaultdict, deque
from typing import List, Tuple, Optional, Dict, Any, FrozenSet, Union, Callable
import pickle
import os
import time
import asyncio
import functools
import concurrent.futures
from cachetools import LRUCache
import copy
import bisect
from itertools import combinations, repeat

import pynini
from scipy.special import logsumexp
import numpy as np
import pathlib
from tqdm import tqdm
import torch

from tokenizer_conversion.machines.utils.config import Config, _LMBackend
from tokenizer_conversion.machines.utils.types import Packed, Block, Beam

from tokenizer_conversion.machines.utils.utils import _log, _hash_cfg
from tokenizer_conversion.machines.utils.pruning import _prune_by_logweights
from tokenizer_conversion.machines.utils.symbols import build_label_acceptor

from tokenizer_conversion.machines.ngram import train_ngram, ngram_probs, _save_ngram, _load_ngram, prepare_ngram_arrays
from tokenizer_conversion.benchmarking.utils.data_utils import load_wikitext_paragraphs_bytes


# line_profiler's `kernprof` runner makes a global `profile` decorator
# available. When this module runs without it, substitute an identity
# decorator so the `@profile` markers used below are harmless.
try:
    profile  # noqa: B018 - existence probe only
except NameError:
    def profile(fn):
        """No-op stand-in for line_profiler's ``@profile`` decorator."""
        return fn

NEG_INF: float = float("-inf")

# Positional index constants (their consumers are not visible in this chunk).
POWERSTATE = 0
BEAM_OUT = 1
YS = 0
LOGP = 1


class TransducedLM:
    def __init__(self, fst: pynini.Fst, llm_name=None, llm=None, use_genlm=False, preprocess_fst=True, eos_out="257"):
        """Wrap ``fst`` and precompute the lookup tables used during decoding.

        Args:
            fst: Transducer to wrap. Its input and output symbol tables are read
                below, and every input symbol string must parse with ``int()``
                (see ``_in_id_to_sym``).
            llm_name: Optional model name stored as ``self.llm_name``.
            llm: Optional model wrapper. ``llm.tokenizer`` supplies ``self.eos``
                and ``self.bos_id``. If ``llm_name`` is not given, the name is
                taken from ``llm.model.config._name_or_path`` when present.
            use_genlm: If true, also build the ``_next`` next-state table via
                ``_build_first_symbol_table_with_next``.
            preprocess_fst: If true, run ``_preprocess_fst`` (label/weight push
                followed by ``pynini.rmepsilon``) before anything else.
            eos_out: Output symbol string stored as ``self.eos_out``.
        """
        # Optionally pre-process FST to ensure it's output-pushed and epsilon-free
        if preprocess_fst:
            fst = self._preprocess_fst(fst)
            
        self.fst = fst
        
        # INITIALIZE LM
        # Supplying both is only warned about; both values are still used below.
        if llm_name is not None and llm is not None:
            print("Only provide one of llm_name and llm.")
        self.llm = llm
        if llm is not None:
            # NOTE: self.eos is only defined when an llm is supplied.
            self.eos = llm.tokenizer.eos_token
        if llm_name:
            self.llm_name = llm_name
        else:
            if llm and hasattr(llm, "model") and hasattr(llm.model, "config"):
                llm_name = getattr(llm.model.config, "_name_or_path", None)
                self.llm_name = llm_name
        # Final value; None when no name could be determined.
        self.llm_name = llm_name
        self.bos_id = None if self.llm is None else self.llm.tokenizer.bos_token_id

        # HANDLE FST SYMBOL MAPPINGS
        self.eps_id = 0 # MUST always be 0 in Pynini
        self.eps = self.fst.input_symbols().find(self.eps_id)
        self.eos_out = eos_out
        self.special_tokens = []

        # *_id_to_sym: label id -> symbol string (input side converted with int()).
        # *_sym_to_id: symbol string -> label id.
        self._out_id_to_sym = {sym:label for sym, label in fst.output_symbols()}
        self._in_id_to_sym  = {sym:int(label) for sym, label in fst.input_symbols()}
        self._out_sym_to_id = {label:sym for sym, label in fst.output_symbols()}
        self._in_sym_to_id  = {label:sym for sym, label in fst.input_symbols()}
        # All output label ids except epsilon (id 0).
        self._all_valid_ids = np.array([sid for sid, _ in self.fst.output_symbols() if sid != 0], dtype=np.int16)
        self.num_in_syms = len(self._in_id_to_sym.keys())
        self.num_out_syms = len(self._out_id_to_sym.keys())


        self.universal_states = {}
        self.state_names = {} # used for visualization

        # INITIALIZE CACHE
        self._logp_next_cache: LRUCache = LRUCache(maxsize=10000)
        self._cover_logp_cache: LRUCache = LRUCache(maxsize=20000)
        self._cover_beam_cache: LRUCache = LRUCache(maxsize=40000)
        self._universal_set_cache = {}
        self._logp_pending = {}
        self._universal_sets_by_size: dict[int, set[FrozenSet[int]]] = defaultdict(set)
        self._state_closure_output_syms = {}
        self._local_cache: dict[frozenset[int], tuple[dict, int]] = {}
        # Powerset interning tables; slot 0 is the empty set. Presumably extended by
        # _intern_ps (defined elsewhere) -- TODO confirm.
        self._ps2id = {} 
        self._id2ps = [frozenset()]
        self._next_by_tok = [np.zeros(self.num_in_syms, np.int32)]
        self._next_prefilled = [False]
        self._univ_arr = np.array([-1], np.int8)
        self._next_pset_after_with_y_dict = {}

        # PRECOMPUTE FST PROPERTIES
        # Order matters: the table builders below read num_in_syms/_in_id_to_sym set above.
        self.sigma_star = self.get_sigma_star()
        self.create_eps_graph()
        self.create_out_graphs()    
        if use_genlm:
            self._build_first_symbol_table_with_next()
        else:
            self._build_first_symbol_table()
        self._build_first_epsin_table()
        # True if any arc emits epsilon on the output side.
        self._has_eps_out = any(arc.olabel == self.eps_id for s in self.fst.states() for arc in self.fst.arcs(s))
                
                
        self.backtracking_stats = defaultdict(float)

    def _preprocess_fst(self, fst: pynini.Fst) -> pynini.Fst:
        """Return a pushed copy of ``fst`` with its epsilon:epsilon arcs removed.

        Labels and weights are pushed toward the initial state
        (``reweight_type="to_initial"``) and the total weight is removed;
        ``pynini.rmepsilon`` is then applied to the result.
        """
        push_options = dict(
            push_labels=True,
            push_weights=True,
            remove_total_weight=True,
            reweight_type="to_initial",
        )
        pushed = pynini.push(fst, **push_options)
        return pynini.rmepsilon(pushed)

    def _build_first_symbol_table(self):
        """Build ``self._first``: the output label emitted when reading a token.

        ``self._first`` is an int16 array of shape
        ``(num_states, num_in_syms + 1)`` indexed by
        ``[state, self._in_id_to_sym[ilabel]]``. A cell is ``-1`` when no arc
        with a non-epsilon output leaves that state on that token. Arcs with an
        epsilon output are skipped; if several arcs hit the same cell, the one
        visited last wins.

        NOTE(review): only ``olabel`` is tested, so epsilon-input arcs with a
        non-epsilon output are also written, under column
        ``self._in_id_to_sym[self.eps_id]``. Confirm that this is intended.
        """
        table = np.full((self.fst.num_states(), self.num_in_syms + 1), -1, np.int16)
        self._first = table
        eps = self.eps_id
        token_of = self._in_id_to_sym
        for state in self.fst.states():
            for arc in self.fst.arcs(state):
                if arc.olabel == eps:
                    continue
                table[state, token_of[arc.ilabel]] = arc.olabel


    def _build_first_symbol_table_with_next(self):
        """Build ``self._first`` (output labels) and ``self._next`` (next states).

        Both arrays have shape ``(num_states, num_in_syms + 1)`` and are indexed
        by ``[state, self._in_id_to_sym[ilabel]]``. ``_first`` is int16 and
        ``_next`` is int32. A cell is ``-1`` when no arc with a non-epsilon
        output leaves that state on that token. Arcs with an epsilon output are
        skipped, and later arcs overwrite earlier ones in the same cell.
        """
        shape = (self.fst.num_states(), self.num_in_syms + 1)
        self._first = out_label = np.full(shape, -1, np.int16)
        self._next = dest = np.full(shape, -1, np.int32)
        eps = self.eps_id
        token_of = self._in_id_to_sym
        for state in self.fst.states():
            for arc in self.fst.arcs(state):
                if arc.olabel == eps:
                    continue
                col = token_of[arc.ilabel]
                out_label[state, col] = arc.olabel
                dest[state, col] = arc.nextstate

    @functools.lru_cache(maxsize=200_000)
    def _states_arr(self, ps: frozenset[int]):
        """Return the members of ``ps`` as an int32 NumPy array (memoized).

        The same array object is returned for equal ``ps`` on cache hits, so
        callers should treat it as read-only.
        """
        members = list(ps)
        return np.array(members, dtype=np.int32)

    def _next_pset_after_with_y(self, ps, tok, y):
        """Return the states reached from ``ps`` on ``tok`` via arcs emitting ``y``.

        Uses the ``_first``/``_next`` tables, so it requires the
        ``use_genlm=True`` construction path. Results are memoized in
        ``self._next_pset_after_with_y_dict`` under the key ``(ps, tok, y)``.
        """
        memo = self._next_pset_after_with_y_dict
        key = (ps, tok, y)
        hit = memo.get(key)
        if hit is not None:
            return hit
        rows = self._states_arr(ps)
        emitted = self._first[rows, tok]
        targets = self._next[rows, tok]
        keep = np.logical_and(emitted == y, targets >= 0)
        result = frozenset(int(q) for q in targets[keep])
        memo[key] = result
        return result
    
    def _build_first_epsin_table(self):
        """
        Build ``self._first_epsin[q]``, the first output label reachable from
        ``q`` along input-epsilon arcs, or -1.

        For each state a breadth-first search follows epsilon:epsilon arcs one
        layer at a time. It stops at the first layer that contains a state with
        an epsilon:y arc (y non-epsilon); emitting arcs are never expanded. The
        entry is y when every emitting arc in that layer carries the same label.
        It is -1 when no emitting arc is reachable (a token must be read first)
        or when the labels in that layer disagree.
        """
        EPS = self.eps_id

        # Per state: targets of epsilon:epsilon arcs, and labels of epsilon:y arcs.
        silent_succ = []
        emitted = []
        for state in self.fst.states():
            eps_arcs = [arc for arc in self.fst.arcs(state) if arc.ilabel == EPS]
            silent_succ.append(tuple(int(a.nextstate) for a in eps_arcs if a.olabel == EPS))
            emitted.append(tuple(int(a.olabel) for a in eps_arcs if a.olabel != EPS))

        table = np.full(self.fst.num_states(), -1, dtype=np.int32)
        for state in self.fst.states():
            root = int(state)
            frontier = deque([root])
            visited = {root}
            found = None
            while frontier and found is None:
                # Drain exactly one epsilon-layer.
                for _ in range(len(frontier)):
                    u = frontier.popleft()
                    if emitted[u]:
                        if found is None:
                            found = set(emitted[u])
                        else:
                            found.update(emitted[u])
                        continue
                    for v in silent_succ[u]:
                        if v not in visited:
                            visited.add(v)
                            frontier.append(v)
            if found is not None and len(found) == 1:
                table[state] = int(next(iter(found)))
        self._first_epsin = table


    def first_symbol_epsin_set(self, subset: Union[int, frozenset[int]]) -> int:
        """Return the epsilon-first output label shared by every state in ``subset``.

        Looks up ``self._first_epsin`` for each member. Returns -1 if any member
        has no epsilon-first label, if members disagree, or if ``subset`` is
        empty. Otherwise returns the common label.
        """
        if isinstance(subset, (int, np.integer)):
            members = (int(subset),)
        else:
            members = frozenset(subset)

        chosen = -1
        for state in members:
            label = self._first_epsin[state]
            # This member must read input first, or it disagrees with an earlier one.
            if label == -1 or (chosen != -1 and label != chosen):
                return -1
            chosen = label
        return chosen
    

    def first_symbol_vectorized(self, subset: Union[int, frozenset[int]], tok_ids: np.ndarray) -> np.ndarray:
        """For each token id, return the output label ``subset`` agrees on.

        For a single state this is ``self._first[state, tok_ids]``. For a set of
        states, states with no entry (``-1``) for a token are ignored. The
        result is the common label when all remaining states agree, and ``-1``
        when none has an entry or they conflict. Returns int16.
        """
        tok_ids = np.asarray(tok_ids, dtype=np.int32)
        if tok_ids.size == 0:
            return np.empty(0, dtype=np.int16)

        if isinstance(subset, (int, np.integer)):
            return self._first[int(subset), tok_ids]

        members = subset if isinstance(subset, frozenset) else frozenset(subset)
        if not members:
            return np.full(tok_ids.size, -1, dtype=np.int16)

        rows = np.fromiter(members, dtype=np.int32)
        block = self._first[np.ix_(rows, tok_ids)].astype(np.int32, copy=False)
        if block.shape[0] == 1:
            return block[0].astype(np.int16, copy=False)

        sentinel = np.iinfo(np.int32).max
        present = block >= 0
        lowest = np.where(present, block, sentinel).min(axis=0)
        highest = np.where(present, block, -1).max(axis=0)
        # A column is determinate when at least one state has an entry and all entries match.
        agree = (lowest != sentinel) & (lowest == highest)

        result = np.full(block.shape[1], -1, dtype=np.int16)
        result[agree] = lowest[agree].astype(np.int16, copy=False)
        return result

    def is_final(self, state: int) -> bool:
        """Return True if ``state`` has a final weight other than semiring Zero.

        Comparing against Zero, rather than against the string ``"0"`` (the
        tropical/log One), keeps final states with any non-unit final weight
        final. ``pynini.push`` in ``_preprocess_fst`` can produce such weights.

        Args:
            state: State id in ``self.fst``.

        Returns:
            True iff ``state`` is a final state.
        """
        zero = pynini.Weight.zero(self.fst.weight_type())
        return self.fst.final(state).to_string() != zero.to_string()

    def visualize(self, path: str, color_universals: Optional[bool] = True):
        """Render ``self.fst`` to an SVG at ``path`` using Graphviz ``dot``.

        A ``.dot`` file with the same stem is written first. When
        ``self.state_names`` is set, it is used as the state symbol table. When
        ``color_universals`` is true, states marked truthy in
        ``self.universal_states`` are filled yellow; only this step needs
        ``pydot``.

        Args:
            path: Output SVG path.
            color_universals: Highlight universal states.
        """
        dot_path = Path(path).with_suffix(".dot")

        state_symtab = None
        if self.state_names:
            # create symbol table for state names
            state_symtab = pynini.SymbolTable(name="state_symtab")
            for state_id, state_name in self.state_names.items():
                state_symtab.add_symbol(state_name, state_id)
        self.fst.draw(
            str(dot_path),
            portrait=True,
            acceptor=False,
            show_weight_one=False,
            ssymbols=state_symtab,
        )

        if color_universals:
            import pydot  # optional dependency, only needed for highlighting

            # Node ids in the drawn graph are state ids, so compare as strings.
            # This avoids int() on non-numeric node names.
            highlight = {str(k) for k, v in (self.universal_states or {}).items() if v}
            graph = pydot.graph_from_dot_file(str(dot_path))[0]
            for node in graph.get_nodes():
                if node.get_name().strip('"') in highlight:
                    node.set_style("filled")
                    node.set_fillcolor("#ffd966")
            graph.write_raw(str(dot_path))

        try:
            subprocess.run(["dot", "-Tsvg", str(dot_path), "-o", path], check=True)
        except FileNotFoundError:
            print("Graphviz 'dot' command not found. Please install Graphviz.")
        except subprocess.CalledProcessError as e:
            print(f"Error running dot: {e}")


    def save(self, path: str):
        """
        Save the FST and its metadata into the directory ``path``.

        Always writes ``fst.fst``. Each of ``universal_states.pkl``,
        ``state_names.pkl``, ``llm_name.txt``, ``eos_out.txt`` and
        ``set_universal_states.pkl`` is written only when its data is present.

        Args:
            path (str): Directory to save into; created (with parents) if missing.
        """
        os.makedirs(path, exist_ok=True)
        # Save the FST
        self.fst.write(os.path.join(path, "fst.fst"))
        # Save the universal states
        if self.universal_states:
            with open(os.path.join(path, "universal_states.pkl"), "wb") as f:
                pickle.dump(self.universal_states, f)
        # Save the state names
        if self.state_names:
            with open(os.path.join(path, "state_names.pkl"), "wb") as f:
                pickle.dump(self.state_names, f)
        # Save the LLM name. Prefer the live model config, then fall back to
        # self.llm_name. Never write None.
        llm_name = None
        if self.llm and hasattr(self.llm, "model") and hasattr(self.llm.model, "config"):
            llm_name = getattr(self.llm.model.config, "_name_or_path", None)
        if not llm_name:
            llm_name = self.llm_name
        if llm_name:
            with open(os.path.join(path, "llm_name.txt"), "w") as f:
                f.write(str(llm_name))
        if self.eos_out:
            with open(os.path.join(path, "eos_out.txt"), "w") as f:
                f.write(str(self.eos_out))

        if self._universal_set_cache:
            with open(os.path.join(path, "set_universal_states.pkl"), "wb") as f:
                pickle.dump(self._universal_set_cache, f)
        

    @classmethod
    def load(cls, path: str, use_genlm=False, llm_name=None):
        """
        Build a new instance from a directory produced by ``save``.

        ``fst.fst`` is required. If ``llm_name`` is not given it is read from
        ``llm_name.txt`` when that file exists. Pickled universal-state maps,
        state names and ``eos_out.txt`` are restored when present.

        Args:
            path (str): Directory to load from.
            use_genlm: Forwarded to the constructor.
            llm_name: Overrides the stored model name.
        Returns:
            An instance of the class.

        NOTE: ``pickle.load`` can execute arbitrary code, so only load trusted
        directories. ``cls(...)`` is called with its default
        ``preprocess_fst=True``, so the saved FST is pushed/rmepsilon'd again.
        NOTE(review): confirm this keeps the state ids that the pickled maps
        refer to.
        """
        def in_dir(name):
            return os.path.join(path, name)

        fst = pynini.Fst.read(in_dir("fst.fst"))
        if not llm_name and os.path.exists(in_dir("llm_name.txt")):
            with open(in_dir("llm_name.txt"), "r") as fh:
                llm_name = fh.read().strip()
        instance = cls(fst, llm_name=llm_name, use_genlm=use_genlm)

        pickled_attrs = (
            ("universal_states.pkl", "universal_states"),
            ("set_universal_states.pkl", "_universal_set_cache"),
            ("state_names.pkl", "state_names"),
        )
        for filename, attr in pickled_attrs:
            if os.path.exists(in_dir(filename)):
                with open(in_dir(filename), "rb") as fh:
                    setattr(instance, attr, pickle.load(fh))

        if os.path.exists(in_dir("eos_out.txt")):
            with open(in_dir("eos_out.txt"), "r") as fh:
                instance.eos_out = fh.read().strip()
        return instance

    def apply(
            self, 
            input_str: Optional[str] = None, 
            input_tokens: Optional[List]=None, 
            join_char: Optional[bool]=True, 
            use_llm: Optional[bool]=True,
            bytes_input: Optional[bool]=False,
            visualize: Optional[bool]=False
        ):
        """
        Transduce an input through the FST and return the best-path output.

        Args:
            input_str: Raw input text, tokenized when ``input_tokens`` is empty.
            input_tokens: Pre-tokenized input. Each item must be a symbol in the
                FST's input symbol table.
            join_char: If true, concatenate the output symbols into one string.
                Otherwise return them as a list.
            use_llm: Tokenize ``input_str`` with ``self.llm.tokenizer``; token
                ids are rendered as strings.
            bytes_input: Without the LLM tokenizer, split ``input_str`` into
                UTF-8 byte values (as strings) instead of characters. Also makes
                each token acceptor be built with ``pynini.accep``.
            visualize: Also write ``apply_input_fst.svg`` and ``path.svg``.

        Returns:
            The output string when ``join_char`` is true, else a list of output
            symbols.

        Raises:
            ValueError: if no input is given, a token is missing from the input
                symbol table, or the FST yields no accepting path for the input.
        """
        if input_str is None and input_tokens is None:
            raise ValueError("Either input_str or input_tokens must be provided.")
        # Tokenize only when there is text to tokenize. An explicitly empty token
        # list with no text is treated as the empty input.
        if not input_tokens and input_str is not None:
            if use_llm:
                encoded = self.llm.tokenizer.encode(input_str, add_special_tokens=False)
                input_tokens = [str(tok) for tok in encoded]
            elif bytes_input:
                input_tokens = [str(byte) for byte in input_str.encode("utf-8")]
            else:
                input_tokens = list(input_str)

        in_syms = self.fst.input_symbols()
        acceptors = []
        for token in input_tokens:
            label = in_syms.find(token)
            if label == -1:
                raise ValueError(f"Token {token} not found in input symbols.")
            if bytes_input:
                acceptors.append(pynini.accep(token, token_type=in_syms))
            else:
                acceptors.append(build_label_acceptor(label))

        # Concatenate acceptors
        if acceptors:
            input_fst = acceptors[0]
            for tok_fst in acceptors[1:]:
                input_fst = input_fst + tok_fst
        else:
            input_fst = pynini.accep("", token_type=in_syms)

        if visualize:
            input_fst.set_input_symbols(self.fst.input_symbols())
            input_fst.set_output_symbols(self.fst.output_symbols())
            wrapped = TransducedLM(input_fst, llm=self.llm)
            wrapped.visualize("apply_input_fst.svg")
        # Construct lattice
        lattice = input_fst @ self.fst

        # Obtain shortest path
        path = pynini.shortestpath(lattice, nshortest=1, unique=True)

        if visualize:
            path.set_input_symbols(self.fst.input_symbols())
            path.set_output_symbols(self.fst.output_symbols())
            wrapped = TransducedLM(path, llm=self.llm)
            wrapped.visualize("path.svg")

        state = path.start()
        if state < 0:  # no start state: the lattice accepted nothing
            raise ValueError("Input is not accepted by the FST; no path found.")

        # The single shortest path is linear. Follow it until reaching a state
        # whose final weight is not Zero; any non-unit final weight counts too.
        zero = pynini.Weight.zero(path.weight_type()).to_string()
        output_labels = []
        while path.final(state).to_string() == zero:
            arc = next(iter(path.arcs(state)), None)
            if arc is None:
                raise ValueError("Shortest path ended at a non-final state.")
            if arc.olabel != self.eps_id:
                output_labels.append(arc.olabel)
            state = arc.nextstate

        symbol_table = path.output_symbols()
        if symbol_table is not None:
            output_tokens = [symbol_table.find(label) for label in output_labels]
        else:
            output_tokens = output_labels
        if join_char:
            # str() so that raw integer labels (no symbol table) can be joined.
            return "".join(str(tok) for tok in output_tokens)
        return output_tokens
        
    def get_sigma_star(self):
        """Build a one-state acceptor for Sigma* over ``self.fst``'s non-epsilon input labels.

        The single state is both start and final, with a self-loop for every
        non-epsilon input label id. The acceptor uses ``self.fst``'s arc type,
        so it can be combined with copies of ``self.fst`` (e.g. in
        ``pynini.difference``) even when that is not the ``"standard"``
        (tropical) arc type.
        """
        sigma = pynini.Fst(arc_type=self.fst.arc_type())
        hub = sigma.add_state()
        sigma.set_start(hub)
        sigma.set_final(hub)
        one = pynini.Weight.one(self.fst.weight_type())
        for ilabel, _ in self.fst.input_symbols():
            if ilabel == self.eps_id:
                continue
            sigma.add_arc(hub, pynini.Arc(ilabel, ilabel, one, hub))
        return sigma

    @functools.lru_cache(maxsize=200000)
    def input_epsilon_closure(self, start_state: int) -> set:
        """
        Return every state reachable from ``start_state`` through arcs whose
        input label is 0 (epsilon), including ``start_state`` itself.

        Results are memoized, and the same set object is returned on cache
        hits, so callers should not mutate it.
        """
        seen = set()
        pending = deque((start_state,))
        while pending:
            current = pending.popleft()
            if current in seen:
                continue
            seen.add(current)
            pending.extend(
                arc.nextstate for arc in self.fst.arcs(current) if arc.ilabel == 0
            )
        return seen
    
    def closure_input_reads_all_symbols(self, state: int, closure: set, verbose: Optional[bool]) -> bool:
        """Cheap necessary condition for ``state`` to be universal.

        Collects the input labels of every arc leaving a member of ``closure``
        and checks that they cover every id in the input symbol table. Epsilon
        (``self.eps_id``) is part of that requirement unless ``state`` itself is
        final: a non-final state with no epsilon arc cannot accept the empty
        string.

        Args:
            state: The state being tested (not a closure member chosen by the loop).
            closure: Its input-epsilon closure.
            verbose: Log the missing symbol ids on failure.

        Returns:
            True if no input symbol is missing, False otherwise.
        """
        # Use a distinct loop name so `state` still refers to the tested state.
        read_symbols = {arc.ilabel for member in closure for arc in self.fst.arcs(member)}
        all_symbols = {sym_id for sym_id, _ in self.fst.input_symbols()}

        # If the state is itself final, we don't need epsilon
        if self.is_final(state):
            all_symbols.discard(self.eps_id)
            read_symbols.discard(self.eps_id)

        # Coverage is a subset test: extra labels read (if any) do not disqualify.
        missing = all_symbols - read_symbols
        if not missing:
            return True
        _log(verbose, f"State {state} missing symbols: {missing}")
        return False

    def is_universal(
            self, 
            start_state: int, 
            verbose: bool= False, 
            closure: Optional[set] = None
        ) -> bool:
        
        """
        Checks whether `start_state` is universal, i.e. all strings in the
        language are accepted from `start_state`.

        First runs the cheap ``closure_input_reads_all_symbols`` precheck. It
        then restarts a copy of the FST at ``start_state``, projects it onto
        its input, strips the weights, removes epsilons, determinizes and
        minimizes it, and tests whether Sigma* minus it is empty.

        Args:
        start_state: integer state ID from which to check universality.
        verbose: log intermediate sizes / precheck failures.
        closure: precomputed input-epsilon closure of ``start_state``.

        Returns:
        True if the right input language accepted from 'start_state' is sigma star,
        False otherwise.
        """
        # We first check if we can read all input symbols from the input closure
        if closure is None:
            closure = self.input_epsilon_closure(start_state)

        # Precheck for large fsts
        if not self.closure_input_reads_all_symbols(start_state, closure, verbose=verbose):
            _log(verbose, f"State {start_state} does not read all symbols")
            return False

        if not hasattr(self, "sigma_star"):
            self.sigma_star = self.get_sigma_star()
        sub = self.fst.copy()
        sub.set_start(start_state)
        sub.project("input")
        # Universality concerns only the input language. Dropping weights makes
        # determinization terminate unconditionally, and OpenFst's difference
        # requires an unweighted second argument.
        sub = pynini.arcmap(sub, map_type="rmweight")
        sub = pynini.rmepsilon(sub)
        sub = pynini.determinize(sub)
        sub = pynini.minimize(sub, allow_nondet=False)
        _log(verbose, f"minimised sub-FSA states = {sub.num_states()}")
        diff = pynini.difference(self.sigma_star, sub)
        diff = pynini.connect(diff)
        return diff.num_states() == 0
    
    @profile
    def is_universal_set(
            self,
            start_states: Union[int, FrozenSet[int]],
            verbose: bool = False,
        ) -> bool:
        """
        Universality for a set of states: True iff the union of the input
        languages accepted from ``start_states`` is Sigma*.

        A fresh start state with epsilon arcs to every member is added to a
        copy of the FST. The input projection is stripped of weights, made
        epsilon-free, determinized and minimized, then subtracted from Sigma*.

        Args:
            start_states: A state id (Python or NumPy integer) or a
                (frozen)set of state ids.
            verbose: Log the size of the minimized automaton.

        Returns:
            True if Sigma* minus the union language is empty.

        Raises:
            TypeError: if ``start_states`` is neither an integer nor a set.
        """
        if not hasattr(self, "sigma_star"):
            self.sigma_star = self.get_sigma_star()
        if isinstance(start_states, (int, np.integer)):
            start_states = frozenset({int(start_states)})
        elif not isinstance(start_states, (set, frozenset)):
            raise TypeError("start_states must be int or (frozen)set[int]")
        sub = self.fst.copy()
        new_start = sub.add_state()
        sub.set_start(new_start)
        one = pynini.Weight.one(sub.weight_type())
        for s in start_states:
            sub.add_arc(new_start, pynini.Arc(self.eps_id, self.eps_id, one, s))
        sub.project("input")
        # Only the input language matters. Remove weights so determinization
        # always terminates and difference gets an unweighted acceptor.
        sub = pynini.arcmap(sub, map_type="rmweight")
        sub = pynini.rmepsilon(sub)
        sub = pynini.determinize(sub)
        sub = pynini.minimize(sub, allow_nondet=False)
        _log(verbose, f"minimised sub-FSA states = {sub.num_states()}")
        diff = pynini.difference(self.sigma_star, sub)
        diff = pynini.connect(diff)
        return diff.num_states() == 0

    def sort_states_by_incoming_arcs(self):
        """Return all state ids ordered by in-degree, highest first.

        Ties keep the order produced by ``self.fst.states()`` (stable sort).
        """
        indegree: Dict[int, int] = {}
        for src in self.fst.states():
            for arc in self.fst.arcs(src):
                indegree[arc.nextstate] = indegree.get(arc.nextstate, 0) + 1
        return sorted(self.fst.states(), key=lambda q: indegree.get(q, 0), reverse=True)
    
    def compute_universal_states(
            self,
            path_universal_states: Optional[str]=None,
            verbose: Optional[bool] = False
        ):
        """
        Fill ``self.universal_states`` with a universality flag for every state.

        A state is universal if its right input language is Sigma*. When
        ``path_universal_states`` is given, an existing pickle there is loaded
        first, and the completed map is pickled back to it at the end. A state
        is marked universal without a full check when some state in its
        input-epsilon closure is already known to be universal.

        NOTE: ``pickle.load`` executes arbitrary code, so only use trusted files.
        """
        if path_universal_states is not None:
            try:
                with open(path_universal_states, "rb") as fh:
                    self.universal_states = pickle.load(fh)
            except FileNotFoundError:
                print(f"File {path_universal_states} not found. Computing universal states.")
                self.universal_states = {}

        # Visit high in-degree states first. They are the likeliest to appear in
        # other states' closures, which enables the shortcut below more often.
        for state in self.sort_states_by_incoming_arcs():
            if state in self.universal_states:
                continue
            _log(verbose, f"Computing universal states for {state}/{self.fst.num_states()}")

            closure = self.input_epsilon_closure(state)
            witness = next((q for q in closure if self.universal_states.get(q)), None)
            if witness is not None:
                self.universal_states[state] = True
                _log(verbose, f"State {witness} in input closure. Marking as {state}")
            else:
                self.universal_states[state] = self.is_universal(
                    state,
                    verbose=verbose,
                    closure=closure,
                )
            _log(verbose, f"Is state {state} universal: {self.universal_states[state]}")

        if path_universal_states is not None:
            with open(path_universal_states, "wb") as fh:
                pickle.dump(self.universal_states, fh)

    #
    # ---- Helper methods for optimal decomposition
    #

    def empty_cache(self):
        """Clear the LM-distribution and cover caches and reset pending-task bookkeeping.

        ``_logp_pending`` is rebound to a fresh dict; in-flight tasks are not cancelled.
        """
        for cache in (self._logp_next_cache, self._cover_logp_cache, self._cover_beam_cache):
            cache.clear()
        self._logp_pending = {}

    async def _logp_for(self, ctx):
        """Return ``self.genlm_realpha.logp_next_for(ctx)``, memoized in ``_logp_next_cache[ctx]``."""
        try:
            return self._logp_next_cache[ctx]
        except KeyError:
            pass
        fresh = await self.genlm_realpha.logp_next_for(ctx)
        self._logp_next_cache[ctx] = fresh
        return self._logp_next_cache[ctx]

    @functools.lru_cache(maxsize=None)
    def _arcs(self, state):
        """Return ``[arcs, ilabels, in_syms, olabels, nextstates]`` for ``state``.

        ``arcs`` is the list of arc objects leaving ``state``. The rest are
        parallel NumPy arrays: input label ids (int32), input symbol values via
        ``self._in_id_to_sym`` (int32), output label ids (int16) and destination
        states (int32). Results are memoized without bound, and the cached list
        and arrays are shared between callers, so treat them as read-only.
        """
        arcs = list(self.fst.arcs(state))
        n = len(arcs)
        # Build each array once; the former list-based ilabels/olabels were immediately overwritten.
        ilabels = np.fromiter((a.ilabel for a in arcs), dtype=np.int32, count=n)
        in_syms = np.fromiter((self._in_id_to_sym[a.ilabel] for a in arcs), dtype=np.int32, count=n)
        olabels = np.fromiter((a.olabel for a in arcs), dtype=np.int16, count=n)
        nextstates = np.fromiter((a.nextstate for a in arcs), dtype=np.int32, count=n)
        return [arcs, ilabels, in_syms, olabels, nextstates]

    @profile
    async def _get_cached_log_dist(
        self,
        context: tuple[str, ...],
        *,
        cfg=None,
        cfg_hash=None,
        backend: _LMBackend=_LMBackend.GENLM_BYTES,
        thread_pool: concurrent.futures.ThreadPoolExecutor | None = None,
    ):
        """Return the next-token log-distribution for ``context`` (memoized).

        Results are cached in ``_logp_next_cache`` under ``(context, cfg_hash)``.
        Concurrent callers for the same key share one in-flight task registered
        in ``_logp_pending``, which is removed once awaited.
        ``GENLM_BYTES`` delegates to ``self._logp_for(context)``.
        ``GENLM_ASYNC`` calls ``self.llm.next_token_logprobs([bos_id, *context])``
        and converts the tensor to a float16 NumPy array.
        ``cfg`` and ``thread_pool`` are accepted but unused here.

        Raises:
            ValueError: for an unrecognized ``backend``.
        """
        key = (context, cfg_hash)
        cache = self._logp_next_cache
        if key in cache:
            return cache[key]

        pending = self._logp_pending
        task = pending.get(key)
        if task is None:
            if backend is _LMBackend.GENLM_BYTES:
                coro = self._logp_for(context)
            elif backend is _LMBackend.GENLM_ASYNC:
                coro = self.llm.next_token_logprobs([self.bos_id, *context])
            else:
                raise ValueError("Invalid backend specification")
            task = asyncio.create_task(coro)
            pending[key] = task

        try:
            result = await task
        finally:
            pending.pop(key, None)

        if backend is _LMBackend.GENLM_ASYNC:
            result = result.detach().to(torch.float16).cpu().numpy()

        self._logp_next_cache[key] = result
        return result

    def precompute_next_pset(self):
        """Prefill powerset transition tables for every singleton and every pair of states.

        For each state, arcs are grouped by input symbol value (``in_syms`` from
        ``self._arcs``) into ``tok -> frozenset(next states)``. That map is
        written into ``self._next_by_tok[ps_id]`` for the singleton powerset.
        Unions of two such maps are written for every unordered state pair.
        Interned ids come from ``self._intern_ps`` (defined elsewhere;
        presumably it also extends ``_next_by_tok`` / ``_next_prefilled`` --
        TODO confirm). The pair loop is quadratic in the number of states.
        """
        _log(True, "Precomputing next psets")
        # Safer token-id set
        # NOTE(review): these are label *ids* (values of _in_sym_to_id), while
        # tok_id below is int(symbol string) from _in_id_to_sym. The filter only
        # behaves as intended if those coincide -- confirm.
        in_tok_ids = None
        if hasattr(self, "_in_sym_to_id") and self._in_sym_to_id is not None:
            try:
                in_tok_ids = set(int(v) for v in self._in_sym_to_id.values())
            except Exception:
                in_tok_ids = set(self._in_sym_to_id.values())

        states = [int(s) for s in self.fst.states()]  # materialize once
        state2tok2ns: dict[int, dict[int, frozenset[int]]] = {}

        # Build per-state: tok -> next-state-set
        for s in states:
            tok2ns: dict[int, frozenset[int]] = {}
            _arcs, _ilabels, in_syms, _olabels, nextstates = self._arcs(s)
            if getattr(in_syms, "size", 0):
                # Stable sort by token, then slice contiguous runs of equal tokens.
                order = in_syms.argsort(kind="mergesort")
                in_sorted = in_syms[order]
                ns_sorted = nextstates[order]
                uniq_tok, idx_first = np.unique(in_sorted, return_index=True)
                idx_first = idx_first.tolist()
                idx_last = idx_first[1:] + [in_sorted.size]
                for tok_id, i0, i1 in zip(uniq_tok.tolist(), idx_first, idx_last):
                    tok_id = int(tok_id)
                    if in_tok_ids is not None and tok_id not in in_tok_ids:
                        continue
                    ns_slice = ns_sorted[i0:i1]
                    if ns_slice.size == 0:
                        continue
                    tok2ns[tok_id] = frozenset(int(x) for x in ns_slice)
            state2tok2ns[s] = tok2ns

            # Prefill singleton (s,)
            # arr is indexed by token value t; assumes t < len(arr) -- TODO confirm sizing.
            ps_id = self._intern_ps(frozenset((s,)))
            arr = self._next_by_tok[ps_id]
            for t, ns in tok2ns.items():
                arr[t] = self._intern_ps(ns)
            self._next_prefilled[ps_id] = True

        # Prefill pairs by unioning per-state maps
        n = len(states)
        for ai in range(n):
            s1 = states[ai]; map1 = state2tok2ns.get(s1, {})
            for bj in range(ai + 1, n):
                s2 = states[bj]; map2 = state2tok2ns.get(s2, {})
                # Pairs where neither state has any outgoing token are not interned.
                if not map1 and not map2:
                    continue

                ps_key = frozenset((s1, s2))
                ps_id = self._intern_ps(ps_key)
                arr = self._next_by_tok[ps_id]

                # Just the tokens that exist on either state
                for t in (map1.keys() | map2.keys()):
                    ns = (map1.get(t, frozenset()) | map2.get(t, frozenset()))
                    if ns:
                        arr[t] = self._intern_ps(ns)
                self._next_prefilled[ps_id] = True
        
    def precompute_universal_set(self, verbose=False):
        """
        Fill ``_universal_set_cache`` for all pairs and triples of non-universal states.

        Pairs are checked directly with ``is_universal_set``. A triple is marked
        universal without a check when any of its pairs is cached as universal.
        Otherwise it is checked directly. Entries already cached are skipped.
        """
        _log(verbose, "Precomputing universal set")
        candidates = [state for state, universal in self.universal_states.items() if not universal]
        cache = self._universal_set_cache

        _log(verbose, "Tuples")
        for pair in tqdm(list(combinations(candidates, 2))):
            key = frozenset(pair)
            if key not in cache:
                cache[key] = self.is_universal_set(key)

        _log(verbose, "Triples")
        for triple in tqdm(list(combinations(candidates, 3))):
            key = frozenset(triple)
            if key in cache:
                continue
            covered = any(cache.get(frozenset(sub), False) for sub in combinations(triple, 2))
            cache[key] = True if covered else self.is_universal_set(key)


    def _record_if_universal(self, S: FrozenSet[int], is_universal: bool) -> None:
        """Index ``S`` by its size in ``_universal_sets_by_size`` when it is universal."""
        if not is_universal:
            return
        self._universal_sets_by_size[len(S)].add(S)

    def is_powerset_universal(self, S: FrozenSet[int]) -> bool:
        """Decide, with memoization, whether the state set ``S`` is universal.

        Shortcuts are tried in this order:
        1. An exact hit in ``_universal_set_cache``.
        2. A singleton defers to ``universal_states``.
        3. Any individually universal member makes ``S`` universal.
        4. Any recorded universal set contained in ``S`` makes it universal;
           this result is cached.

        Otherwise ``is_universal_set`` decides. The verdict is cached and, when
        positive, indexed by size.
        """
        cache = self._universal_set_cache
        if S in cache:
            return cache[S]

        if len(S) == 1:
            (only_state,) = S
            return self.universal_states[only_state]

        if any(self.universal_states[q] for q in S):
            return True

        by_size = self._universal_sets_by_size
        if any(known.issubset(S) for size in range(1, len(S) + 1) for known in by_size.get(size, ())):
            cache[S] = True
            return True

        verdict = self.is_universal_set(S)
        cache[S] = verdict
        self._record_if_universal(S, verdict)
        return verdict
    

    def has_final(self, S: FrozenSet[int]) -> bool:
        """True iff at least one state of S is final."""
        return any(self.is_final(s) for s in S)
    
    def check_fst_properties(self) -> Dict[str, bool]:
        """
        Check if FST has properties required for optimized logp_next.
        """
        properties = {
            'epsilon_free_output': True,
            'has_epsilon_id_zero': self.eps_id == 0,
            'has_routing_table': hasattr(self, '_first')
        }
        
        # Check for epsilon outputs on non-epsilon inputs
        for state in self.fst.states():
            for arc in self.fst.arcs(state):
                if arc.ilabel != 0 and arc.olabel == 0:
                    properties['epsilon_free_output'] = False
                    break
            if not properties['epsilon_free_output']:
                break
        
        return properties
    
    def create_eps_graph(fst):
        fst._eps_graph: list[list[tuple[int, int | None]]] = []
        for s in fst.fst.states():
            eps_edges = []
            for arc in fst.fst.arcs(s):
                if arc.ilabel == 0:
                    eps_edges.append((arc.nextstate, arc.olabel))
            fst._eps_graph.append(eps_edges)
    
    def create_out_graphs(self):
        """Precompute two adjacencies: ε-output arcs and non-ε-output arcs."""
        eps = self.eps_id
        out_eps_graph = []
        out_non_eps_graph = []
        for s in self.fst.states():
            eps_edges = []
            non_edges = []
            for arc in self.fst.arcs(s):
                if arc.olabel == eps:
                    eps_edges.append((arc.nextstate, self._in_id_to_sym[arc.ilabel]))
                else:
                    non_edges.append((arc.nextstate, self._in_id_to_sym[arc.ilabel], arc.olabel))
            out_eps_graph.append(eps_edges)
            out_non_eps_graph.append(non_edges)

        self._out_eps_graph = out_eps_graph
        self._out_non_eps_graph = out_non_eps_graph

    @functools.lru_cache(maxsize=200_000)
    def output_epsilon_frontier_pairs(self, start_state: int) -> tuple[tuple[tuple[int, ...], int], ...]:
        """Frontier for one state, but return ONLY (in_seq, out_lab) pairs (deduped)."""
        eps_id = self.eps_id
        eps_graph = self._out_eps_graph
        non_graph = self._out_non_eps_graph

        # use a set for dedup inside the state
        pairs: set[tuple[tuple[int, ...], int]] = set()
        stack: list[tuple[int, tuple[int, ...]]] = [(start_state, ())]
        seen: set[tuple[int, tuple[int, ...]]] = set()

        while stack:
            s, in_pref = stack.pop()
            key = (s, in_pref)
            if key in seen:
                continue
            seen.add(key)

            # emit frontier (no nxt)
            for _nxt, ilab, olab in non_graph[s]:
                in_seq = in_pref if ilab == eps_id else (in_pref + (ilab,))
                pairs.add((in_seq, olab))

            # continue via ε-output arcs only
            for nxt, ilab in eps_graph[s]:
                new_pref = in_pref if ilab == eps_id else (in_pref + (ilab,))
                stack.append((nxt, new_pref))

        # return an immutable container for lru_cache friendliness
        return tuple(pairs)

    @functools.lru_cache(maxsize=100_000)
    def _pset_frontier_counts(self, ps: frozenset[int]) -> tuple[dict[tuple[tuple[int, ...], int], int], int]:
        """Aggregate frontier counts for a whole power-state set."""
        counts: dict[tuple[tuple[int, ...], int], int] = defaultdict(int)
        for s in ps:
            for pair in self.output_epsilon_frontier_pairs(int(s)):
                counts[pair] += 1
        return counts, max(1, len(ps))

    @functools.lru_cache(maxsize=200000)
    def input_epsilon_closure_track_trg(
        self,
        start_state: int,
        missing_out: tuple[int, ...] | None = None,
    ) -> set[tuple[int, tuple[int, ...]]]:
        """
        ε-closure that also remembers the concatenated output string
        along the ε-only path.  Returned as a set of (state, out_tuple).

        Only arcs listed in ``self._eps_graph`` (input-ε arcs) are followed.
        Output labels equal to ``self.eps_id`` do not extend ``out_tuple``.

        Args:
            start_state: State the closure starts from (included with ``()``).
            missing_out: When None, every output label is appended. Otherwise an
                output symbol is only appended while ``len(out) < len(missing_out) - 1``
                and it equals ``missing_out[len(out)]``. The guard therefore stops
                one symbol short of the full remaining target (see the inline TODO
                about overshooting).

        Notes:
            * The result is memoised by ``lru_cache``; every caller receives the
              same set object, so it must not be mutated.
            * With ``missing_out=None``, an input-ε cycle that emits non-ε output
              would grow ``out`` on every lap and the loop would not terminate.
              Presumably the FST has no such cycles (TODO confirm).
        """
        closure: set[tuple[int, tuple[int, ...]]] = set()

        stack = [(start_state, ())] # (state, out_so_far)
        eps_graph = self._eps_graph # List[(dest, olabel)] per state
        while stack:
            state, out = stack.pop()
            key = (state, out)
            # (state, out) pairs, not states alone, are the visited keys: the same
            # state may be reached with different accumulated outputs.
            if key in closure:
                continue
            closure.add(key)
            for nxt, olabel in eps_graph[state]:
                if olabel == self.eps_id: 
                    # Silent arc: state changes, output unchanged.
                    stack.append((nxt, out))
                    continue
                if missing_out is None:
                    # Unconstrained mode: record whatever is emitted.
                    stack.append((nxt, out + (olabel,)))
                    continue
                # Constrained mode: the emitted symbol must continue the target.
                i = len(out)
                if i >= len(missing_out)-1: # already matched everything # TODO better handling of overshooting
                    continue            # extra output not allowed
                if olabel != missing_out[i]:            
                    continue            # Diverges from the next symbol cover
                stack.append((nxt, out + (olabel,)))
        return closure

    @profile
    async def gather_ctx2logdist(self, uniq_ctx, config, cfg_hash):
        if not uniq_ctx:
            return {}
        cached, to_score = {}, []
        for ctx in uniq_ctx:
            key = (ctx, cfg_hash)
            if key in self._logp_next_cache:
                cached[ctx] = self._logp_next_cache[key]
            else:
                to_score.append(ctx)
        if not to_score:
            return {**cached}
        # Call
        if config.backend == _LMBackend.GENLM_ASYNC:
            in_ids = [[self.bos_id, *ctx] for ctx in to_score]
            scores = await self.llm.batch_next_token_logprobs(in_ids)
            logps = scores.detach().to(torch.float16).cpu().numpy()
            for logp, ctx in zip(logps, to_score):
                self._logp_next_cache[(ctx, cfg_hash)] = logp
            return {
                ctx: self._logp_next_cache[(ctx, cfg_hash)]
                for ctx in uniq_ctx
            }

        elif config.backend == _LMBackend.GENLM_BYTES:
            return dict(
            zip(uniq_ctx,
                await asyncio.gather(
                    *[self._get_cached_log_dist(c, backend=config.backend, cfg_hash=cfg_hash)
                    for c in uniq_ctx]))
            )
    
    def get_cached_beams(self, out_tokens, config):
        tokens = tuple(out_tokens)
        for i in range(len(tokens) - 1, -1, -1):
            key = tokens[:i]
            if key in self._cover_beam_cache:
                return self._cover_beam_cache[key]
        return  {((), 0.0): [(self.fst.start(), np.empty(0, dtype=np.int16),)]} 

    @functools.lru_cache(maxsize=None)
    def _vectorized_arcs(self, state: int):
        arcs, ilabels, in_syms, olabels, nextstates = self._arcs(state)
        n = len(arcs)
        in_special = ~np.isin(ilabels, self.special_tokens, assume_unique=True)
        vec = (np.full(n, state, dtype=np.int32), in_syms, olabels, nextstates, in_special)
        return vec
    
    @functools.lru_cache(maxsize=2000)
    def get_vectorized(self, states, track_powerstates):
        all_states: List[int] = []
        all_in_syms: List[int] = []
        all_olabels: List[int] = []
        all_states_out: List[str] = []
        all_states_out_len: List[int] = []
        all_next_states: List[int] = []
        all_in_sym_not_special: List[bool] = []
        for s, s_out in states:
            vec_states, vec_insyms, vec_olabels, vec_next_states, vec_in_sym_not_special = self._vectorized_arcs(s)
            n_arcs = len(vec_states)
            all_states.append(vec_states)
            all_in_syms.append(vec_insyms)
            all_olabels.append(vec_olabels)
            all_next_states.append(vec_next_states)
            all_in_sym_not_special.append(vec_in_sym_not_special)
            
            if track_powerstates:
                all_states_out.extend([np.asarray(s_out, dtype=np.int16)] * n_arcs)
            all_states_out_len.extend([len(s_out)] * n_arcs)
        # Convert to numpy arrays for vectorized operations
        all_states = np.concatenate(all_states) if len(all_states) != 1 else all_states[0]
        all_insyms = np.concatenate(all_in_syms) if len(all_in_syms) != 1 else all_in_syms[0]
        all_olabels = np.concatenate(all_olabels) if len(all_olabels) != 1 else all_olabels[0]
        all_next_states = np.concatenate(all_next_states) if len(all_next_states) != 1 else all_next_states[0]
        all_in_sym_not_special = np.concatenate(all_in_sym_not_special) if len(all_in_sym_not_special) != 1 else all_in_sym_not_special[0]
        all_advances = (all_olabels != self.eps_id)
        if track_powerstates:
            all_states_out = np.array(all_states_out, dtype=object) 
        else:
            # dummy
            all_states_out = np.zeros_like(all_states, dtype=np.int16)
        all_states_out_len = np.asarray(all_states_out_len, dtype=np.int16)

        res = (
            all_insyms,
            all_olabels,
            all_advances,
            all_next_states,
            all_in_sym_not_special,
            all_states_out,
            all_states_out_len
        )
        return res


    def _register_block(self, blocks: List[Block], all_n_state, all_adv, all_olab, all_states_out, acc_idx, beamout_pref, len_beamout) -> int:
        """Append a ``Block`` built from the given arrays to ``blocks``; return its index."""
        new_block = Block(
            all_n_state, all_adv, all_olab, all_states_out,
            acc_idx, beamout_pref, int(len_beamout),
        )
        blocks.append(new_block)
        return len(blocks) - 1

    def _pack_beam(self, all_insym: np.ndarray, eps_mask: np.ndarray, acc_m: np.ndarray, blocks: List[Block],
                all_n_state, all_adv, all_olab, all_states_out, beamout_pref, len_beamout) -> Tuple[Optional[Packed], Optional[Tuple[int,np.ndarray]], bool]:
        """
        Efficiently packs all candidates
        Returns: (packed_non_eps, eps_chunk, has_non_eps)
        packed_non_eps: (block_id, sym_runs, starts, ends, pos_ord) or None
        eps_chunk: (block_id, eps_take) or None  (positions in acc_idx)
        """
        acc_idx = np.flatnonzero(acc_m)
        if acc_idx.size == 0:
            return None, None, False

        insym_v = all_insym[acc_idx]
        ne_mask = eps_mask[acc_idx]
        ne_pos  = np.flatnonzero(ne_mask)
        has_ne  = ne_pos.size > 0

        blk_id = self._register_block(blocks, all_n_state, all_adv, all_olab, all_states_out, acc_idx, beamout_pref, len_beamout)

        eps_chunk = None
        if ne_pos.size < acc_idx.size:
            eps_pos = np.ones(acc_idx.size, dtype=bool)
            eps_pos[ne_pos] = False
            eps_take = np.flatnonzero(eps_pos)
            if eps_take.size:
                eps_chunk = (blk_id, eps_take)

        # Pack non-eps cands
        packed = None
        if has_ne:
            ins_ne = insym_v[ne_pos]
            order = np.argsort(ins_ne, kind="stable")
            sym_ord = ins_ne[order]
            pos_ord = ne_pos[order]
            starts = np.r_[0, 1 + np.flatnonzero(sym_ord[1:] != sym_ord[:-1])]
            ends = np.r_[starts[1:], sym_ord.size]
            sym_runs = sym_ord[starts].astype(np.int32, copy=False)  # unique, sorted
            packed = (blk_id, sym_runs, starts, ends, pos_ord)

        return packed, eps_chunk, has_ne
    
    @profile
    def _materialize_survivors(
        self,
        survivors: np.ndarray,
        kind: np.ndarray,          # 0=ne, 1=eps
        sym: np.ndarray,
        owner_idx: np.ndarray,
        key_lp: np.ndarray,
        owners: List[Tuple[str, tuple, float, Any]],
        blocks: List[Block],
    ):
        beams: Dict[Tuple[tuple,float], list] = {}
        for i in survivors:
            _, ys, _, ref = owners[int(owner_idx[i])]
            if kind[i] == 0:  # non-eps
                s = sym[i]
                lst = beams.setdefault((ys + (s,), key_lp[i]), [])
                for (blk_id, sym_runs, starts, ends, pos_ord) in ref:
                    j = np.searchsorted(sym_runs, s)
                    if j < sym_runs.size and sym_runs[j] == s:
                        take_pos = pos_ord[starts[j]:ends[j]]
                        blk = blocks[blk_id]
                        abs_idx = blk.acc_idx[take_pos]
                        lst.extend(zip(
                            blk.nstate[abs_idx],
                            repeat(blk.bpref, abs_idx.size),
                            repeat(blk.L, abs_idx.size),
                            blk.adv[abs_idx],
                            blk.olab[abs_idx],
                            blk.sout[abs_idx],
                        ))
            else:  # eps
                lst = beams.setdefault((ys, key_lp[i]) , [])
                for (blk_id, eps_pos) in ref:
                    blk = blocks[blk_id]
                    abs_idx = blk.acc_idx[eps_pos]
                    lst.extend(zip(
                        blk.nstate[abs_idx],
                        repeat(blk.bpref, abs_idx.size),
                        repeat(blk.L, abs_idx.size),
                        blk.adv[abs_idx],
                        blk.olab[abs_idx],
                        blk.sout[abs_idx],
                    ))
        return beams
    
    @profile
    def _build_keys(
        self,
        cand_pack: Dict[Tuple[tuple,float], List[Packed]],
        cand_eps_idx: Dict[Tuple[tuple,float], List[Tuple[int,np.ndarray]]],
        get_ctx,
    ):
        """
        Returns:
        survivors: np.ndarray[int]
        kind:      np.ndarray[uint8]   (0=non-eps, 1=eps)
        sym:       np.ndarray[int32]   (symbol or -1 for eps)
        owner_idx: np.ndarray[int32]   (index into 'owners')
        logps:     np.ndarray[float32] (key scores)
        owners:    List[Tuple[str, tuple, float, Any]]
                    entries are ("ne", ys, base_logp, records) or ("eps", ys, base_logp, eps_chunks)
        """
        owners = []
        kind_chunks, sym_chunks, lp_chunks, owner_chunks = [], [], [], []

        owner_id = 0
        for (ys, logp), records in cand_pack.items():
            dist = get_ctx(ys)
            all_syms = (np.concatenate([r[1] for r in records]) if len(records) > 1 else records[0][1])
            uniq = np.unique(all_syms).astype(np.int32, copy=False)
            if hasattr(dist, "get"):
                dget = dist.get
                new_lp = (float(logp) + np.fromiter((dget(int(s), float("-inf")) for s in uniq),
                    dtype=np.float32, count=uniq.size))
            else:
                dist = np.asarray(dist, dtype=np.float32, order="C")
                new_lp = (float(logp) + dist[uniq]).astype(np.float32, copy=False)

            owners.append(("ne", ys, float(logp), records))
            n = uniq.size
            kind_chunks.append(np.zeros(n, dtype=np.uint8))
            sym_chunks.append(uniq.astype(np.int32, copy=False))
            lp_chunks.append(new_lp)
            owner_chunks.append(np.full(n, owner_id, dtype=np.int32))
            owner_id += 1

        for (ys, logp), eps_chunks in cand_eps_idx.items():
            owners.append(("eps", ys, float(logp), eps_chunks))
            kind_chunks.append(np.ones(1, dtype=np.uint8))
            sym_chunks.append(np.array([-1], dtype=np.int32))
            lp_chunks.append(np.array([float(logp)], dtype=np.float32))
            owner_chunks.append(np.array([owner_id], dtype=np.int32))
            owner_id += 1

        if not lp_chunks:
            empty = np.empty(0, dtype=np.int32)
            return empty, empty, empty, empty, empty.astype(np.float32), owners

        kind = np.concatenate(kind_chunks)
        sym = np.concatenate(sym_chunks)
        logps = np.concatenate(lp_chunks)
        owner_idx = np.concatenate(owner_chunks)
        return kind, sym, owner_idx, logps, owners

    @profile
    async def optimal_decomposition(
        self,
        out_tokens: List[str],
        config: Config,
        cache_result: bool = True,
        context_logp: float = None,
    ) -> Tuple[List[Beam], List[Beam]]:
        """
        Beam search over the FST that splits the paths producing ``out_tokens`` into a
        remainder and a quotient.

        Can be called with a precover or the original FST.

        Each iteration expands every beam in every bucket by one arc, scores the
        candidates with the LM (``gather_ctx2logdist``), prunes them
        (``_prune_by_logweights``) and regroups the survivors into new buckets. A
        bucket is keyed by ``(ys, logp)`` and holds beams ``(POWERSTATE, BEAM_OUT)``.

        Args:
            out_tokens: Target output-symbol ids. Any sequence is accepted; it is
                converted to a tuple where a hashable cache key is needed.
            config: Search configuration (pruning thresholds, beam-cache switches, ...).
            cache_result: When True (and ``config.use_beam_cache``), store the beams
                whose output exactly matches the target length in
                ``self._cover_beam_cache[tuple(out_tokens)]``.
            context_logp: Currently unused.

        Returns:
            ``(remainder, quotient)``: two dicts mapping a bucket key to its beam list.
            A bucket whose covering beams' union power-state is universal goes to the
            quotient and is not expanded further. One whose union contains a final
            state goes to the remainder and keeps being expanded.
        """
        out_tokens_np = np.asarray(out_tokens, dtype=np.int16)
        # Hashable form of the target for the beam cache: a list key would raise
        # TypeError, and get_cached_beams probes tuple keys.
        cache_key = tuple(out_tokens)
        n_out = len(out_tokens)
        eps = int(self.eps)
        # config aliases
        track_powerstates = config.track_powerstates
        base_thr = config.prune_threshold
        n_cand_thr = config.candidate_threshold
        max_prune_mass = config.max_prune_mass
        thr_alpha = config.prune_threshold_alpha
        use_beam_cache = config.use_beam_cache
        cover_opt = config.cover_opt
        expand_threshold = config.expand_threshold
        max_candidates = config.max_candidates

        # Initialize beam search
        quotient, remainder = defaultdict(list), defaultdict(list)
        cfg_hash = _hash_cfg(config)
        if use_beam_cache:
            # Resume from the buckets cached for the longest cached strict prefix.
            buckets = self.get_cached_beams(out_tokens, config)
            beam_cache = defaultdict(list)
        else:
            buckets = {((), 0.0): [(self.fst.start(), np.empty(0, dtype=np.int16))]}

        while buckets:
            blocks: List[Block] = []
            cand_pack: Dict[Tuple[tuple, float], List[Packed]] = {}
            cand_eps_idx: Dict[Tuple[tuple, float], List[Tuple[int, np.ndarray]]] = {}
            pending_ys = set()

            for key, beam_list in buckets.items():
                # Beams that already produced at least the whole target (all beams
                # when cover_opt is set).
                cover_beams = [b for b in beam_list if len(b[BEAM_OUT]) >= n_out or cover_opt]
                if cover_beams:
                    # Closure of the union of power-states of the covering beams
                    if track_powerstates:
                        union_ps: frozenset[int] = frozenset(
                            s
                            for b in cover_beams
                            for s, _ in self._state_closure_output_syms.get(
                                b[POWERSTATE],
                                self.input_epsilon_closure_track_trg(b[POWERSTATE], missing_out=None),
                            )
                        )
                    else:
                        union_ps = frozenset(b[POWERSTATE] for b in cover_beams)
                    # Only beams that match the target length exactly are cached.
                    if use_beam_cache:
                        for beam in cover_beams:
                            if len(beam[BEAM_OUT]) == n_out:
                                beam_cache[key].append(beam)

                    # Decide quotient vs remainder based on the union power-state
                    if self.is_powerset_universal(union_ps):
                        quotient[key] = beam_list
                        continue
                    elif self.has_final(union_ps):
                        remainder[key] = beam_list

                    # Skip expanding this bucket once it has overshot the target by
                    # expand_threshold and some quotient has been found (guards
                    # against loops).
                    if len(cover_beams[0][BEAM_OUT]) >= n_out + expand_threshold and quotient:
                        continue

                # Expand every beam of this bucket; candidates are grouped by key.
                for b in beam_list:
                    beamout_pref = b[BEAM_OUT]
                    len_beamout = len(beamout_pref)
                    covered_target = len_beamout >= n_out or cover_opt

                    if track_powerstates:
                        missing_out = None if covered_target else tuple(out_tokens[len_beamout:])
                        states = tuple(
                            (t, o) for t, o in self.input_epsilon_closure_track_trg(b[POWERSTATE], missing_out=missing_out)
                        )
                    else:
                        states = ((b[POWERSTATE], ()),)

                    # Get vectorized cache
                    (
                        all_insym,
                        all_olab,
                        all_adv,
                        all_n_state,
                        in_sym_not_special_mask,
                        all_states_out,
                        state_out_len,
                    ) = self.get_vectorized(states, track_powerstates)

                    eps_mask = all_insym != eps
                    # Early reject: special input symbols are never accepted. Once the
                    # target is covered, only non-eps-input arcs are kept. Before that,
                    # an output-producing arc must emit the next target symbol at its
                    # position.
                    acc_m = in_sym_not_special_mask.copy()
                    if covered_target:
                        acc_m &= eps_mask
                    else:
                        pos = len_beamout + state_out_len
                        adv_mask = acc_m & all_adv
                        in_range_m = adv_mask & (pos < n_out)
                        if in_range_m.any():
                            acc_m[in_range_m] = all_olab[in_range_m] == np.take(out_tokens_np, pos[in_range_m])

                    # Collect all candidates of this beam
                    packed, eps_chunk, has_ne = self._pack_beam(
                        all_insym, eps_mask, acc_m,
                        blocks, all_n_state, all_adv, all_olab, all_states_out,
                        beamout_pref, len_beamout
                    )
                    if has_ne:
                        pending_ys.add(key[YS])
                    if packed is not None:
                        cand_pack.setdefault(key, []).append(packed)
                    if eps_chunk is not None:
                        cand_eps_idx.setdefault(key, []).append(eps_chunk)

            # Pruning
            if cand_pack or cand_eps_idx:
                # Score: LM next-token distributions for every ys that needs one
                ctx2logdist = await self.gather_ctx2logdist(pending_ys, config, cfg_hash)
                # Aggregate
                kind, sym, owner_idx, key_lp, owners = self._build_keys(
                    cand_pack, cand_eps_idx, ctx2logdist.__getitem__,
                )
                # Prune
                survivors = _prune_by_logweights(
                    key_lp, base_thr, n_cand_thr,
                    thr_alpha, max_prune_mass, max_candidates
                )
                if survivors.size:
                    beams = self._materialize_survivors(
                        survivors, kind, sym, owner_idx, key_lp, owners, blocks
                    )
                    buckets = {b: self.materialize_out(outs, track_powerstates) for b, outs in beams.items()}
                else:
                    buckets = {}
            else:
                buckets = {}
        if use_beam_cache and beam_cache and cache_result:
            self._cover_beam_cache[cache_key] = beam_cache
        return remainder, quotient
    
    @profile
    def materialize_out(self, candidate_outs, track_powerstates):
        outs = []
        for (tgt, beam_out_prefix,len_beam_out, adv, olabel, state_out) in candidate_outs:
            if track_powerstates:
                slen = state_out.shape[0]
                total = len_beam_out + slen + adv
                new = np.empty(total, dtype=np.int16)
                new[:len_beam_out] = beam_out_prefix
                new[len_beam_out:len_beam_out+slen] = state_out
                if adv:
                    new[-1] = olabel
            else:
                new = (*beam_out_prefix, olabel) if adv else beam_out_prefix
            outs.append((tgt, new))
        return outs


    async def eos_prob(
        self, 
        config: Config,
        context: Optional[List[str]] = [],
        cover_fct: Optional[Callable]= None,
    ):
        """
        Compute the eos probability

        Looks up the beams cached for ``context`` in ``self._cover_beam_cache``,
        running ``cover_fct`` (default: ``self.optimal_decomposition``) first on a
        cache miss. Returns the log-sum-exp, over cached beams whose ``pos`` equals
        ``len(context)``, of ``beam.logp`` plus the LM log-probability of
        ``self.eos_out`` after the beam's ``ys``.

        NOTE(review): several details look inconsistent with the rest of the class
        and should be confirmed before relying on this method:
          * the lookup key is ``(" ".join(context), cfg_hash)``, but
            ``optimal_decomposition`` stores beams under the ``out_tokens`` sequence
            itself (no cfg_hash), so the read after populating the cache may raise
            KeyError;
          * ``" ".join`` requires string tokens, whereas other callers pass int ids;
          * ``optimal_decomposition`` caches a dict keyed by ``(ys, logp)``, so
            iterating it yields key tuples, which have no ``.pos`` / ``.logp``;
          * ``_get_cached_log_dist`` receives ``cfg_hash`` positionally here, while
            ``gather_ctx2logdist`` passes ``backend=`` and ``cfg_hash=`` by keyword.
        The mutable default ``context=[]`` is never mutated here.
        """  
        cover_fct = self.optimal_decomposition if not cover_fct else cover_fct
        # We reconstruct the perfect matching beams from the cache
        cfg_hash = _hash_cfg(config)
        key = (" ".join(context), cfg_hash)
        if key not in self._cover_beam_cache:
            _log(config.verbose, f"{key} cache miss, computing cover")
            # Populate cache
            await cover_fct(
                context,
                config=config,
        )
        eos_ps = []
        perfect_beams = self._cover_beam_cache[key]
        for beam in perfect_beams:
            # Only beams whose output ends exactly at the end of the context count.
            if beam.pos == len(context):
                log_probs = await self._get_cached_log_dist(beam[YS], cfg_hash)
                eos_logp = log_probs.get(self.eos_out, float("-inf"))
                eos_ps.append(beam.logp+eos_logp)
        return logsumexp(eos_ps)

    def get_or_build_ngram(
        self,
        config,
        path: str | pathlib.Path,
        n_data: int,
        max_order: int = 4,
        discount: float = 0.75,
    ) -> dict:
        """
        Load the n-gram model at ``path`` or, if absent, train and save one.

        Training uses ``n_data`` WikiText training paragraphs (joined) and the
        transducer's output vocabulary. The result is always passed through
        ``prepare_ngram_arrays`` before being returned.
        """
        ngram_path = pathlib.Path(path)

        if not ngram_path.exists():
            _log(config.verbose, "Training n-gram")
            corpus, _, _ = load_wikitext_paragraphs_bytes(
                self,
                "train",
                n=n_data,
                verbose=False,
                join_paragraphs=True,
                transducer_name=self.name)
            _log(config.verbose, f"len data {len(corpus)}")
            model = train_ngram(
                corpus=corpus,
                sym_to_id=self._out_sym_to_id,
                vocab_size=self.fst.output_symbols().num_symbols(),
                max_order=max_order,
                discount=discount,
            )
            _save_ngram(model, ngram_path)
        else:
            _log(config.verbose, "Loading n_gram")
            model = _load_ngram(ngram_path)

        return prepare_ngram_arrays(model)
    
    async def sequence_logp_next(self, config: Config, sequence: List[str]):
        stats = {
            "byte_level_log_prob": [],
            "byte_level_log_distribution": [],
            "times": []
        }
        self.backtracking_stats = defaultdict(float)
        use_no_symloop = config.use_no_symloop
        ngram_top_p = config.ngram_top_p
        if ngram_top_p < 1.0: # Use ngram proposal
            _log(config.verbose, f"Using ngram with top-p={ngram_top_p}")
            n_data=10000
            max_order=3
            discount=0.75
            self._ngram = self.get_or_build_ngram(
                config=config,
                path=f"{max_order}gram_{n_data}bytes_{self.name}",
                n_data=n_data, max_order=max_order, discount=discount,
            )      

        out_tokens = [self._out_sym_to_id[t] for t in sequence]
        running = ()
        
        logp_context = 0

        logp_func = self.logp_next_symloop
        if use_no_symloop:
            _log(config.verbose, "Using logp_next_no_symloop")
            logp_func = self.logp_next_no_symloop

        for i, (c, o) in enumerate(zip(sequence, out_tokens)):
            _start = time.time()
            logp = await logp_func(
                config, 
                context=running,
                cover_fct=self.optimal_decomposition,
                logp_context=logp_context,
                top_p=ngram_top_p,
                next_char = o,
                verbose=False
            )
            running = running + (o,)
            _end = time.time()
            stats["times"].append(_end-_start)
            logp = {self._out_id_to_sym[sid]:logp for sid, logp in logp.items()}
            logp_context += logp[str(c)]
            stats["byte_level_log_prob"].append(logp[str(c)])
            stats["byte_level_log_distribution"].append(logp)
            _log(config.verbose and i%100==0, f"Dist for before {c} {logp[str(c)]}: {stats['times'][-1]}")
            #_log(config.verbose, f"Dist for before {c} {logp[str(c)]}: {stats['times'][-1]}")
        return stats

    async def backtrack_if_nan(self, context, config):
        """
        Recompute ``self.logp(context)`` with progressively looser pruning until finite.

        Each retry:
          * redirects the beam-cache entry for ``context[:-1]`` to the cached beams of
            one token shorter prefix (when such an entry exists), so the search
            restarts further back;
          * multiplies ``prune_threshold``, ``prune_threshold_alpha`` and
            ``max_prune_mass`` of a shallow copy of ``config`` by 0.8.

        Updates ``self.backtracking_stats["calls"]`` and ``["time"]`` (creating the
        stats dict if needed). Returns the first finite log-probability.

        NOTE(review): there is no iteration cap; if no setting yields a finite value
        this loops forever.
        """
        _log(config.verbose, "Backtracking")
        stats = getattr(self, "backtracking_stats", None)
        if stats is None:
            stats = self.backtracking_stats = defaultdict(float)
        stats["calls"] += 1.0
        _start = time.time()
        new_cfg = copy.copy(config)
        cache_key = tuple(context[:-1])
        running_key = cache_key
        while True:
            # Walk the beam cache back one token per retry to find an earlier start.
            if running_key != ():
                running_key = running_key[:-1]
                # Only redirect when the shorter prefix is actually cached. Writing an
                # empty placeholder instead would make get_cached_beams hand back no
                # start beams for ``context`` at all, forcing -inf on every retry.
                if running_key in self._cover_beam_cache:
                    self._cover_beam_cache[cache_key] = self._cover_beam_cache[running_key]
            new_cfg.prune_threshold *= 0.8
            new_cfg.prune_threshold_alpha *= 0.8
            new_cfg.max_prune_mass *= 0.8
            logp_context = await self.logp(
                context=context,
                config=new_cfg,
            )
            # NaN is as unusable as -inf, so require a finite value to stop.
            if np.isfinite(logp_context):
                _log(config.verbose, "Finished Backtracking")
                stats["time"] += time.time() - _start
                break

        return logp_context

    #### HELPERS FOR LOGP_NEXT ####

    def _ensure_ctx_store(self):
        """Caching for context tuples"""
        if getattr(self, "_arr_ready", False):
            return
        self._ctx_cap  = 1 << 16
        self._ctx_size = 1
        self._parent_arr  = np.full(self._ctx_cap, -1, dtype=np.int64)
        self._last_tok_arr= np.full(self._ctx_cap, -1, dtype=np.int64)
        self._clen_arr    = np.zeros(self._ctx_cap, dtype=np.int64)
        self._parent_arr[0]   = -1
        self._last_tok_arr[0] = -1
        self._clen_arr[0]     = 0

        if not hasattr(self, "_id2ctx"):
            self._id2ctx = [()]
        elif not self._id2ctx:
            self._id2ctx.append(())

        self._edge_keys = np.empty(0, dtype=np.int64)
        self._edge_vals = np.empty(0, dtype=np.int64)
        self._buf_keys = np.empty(0, dtype=np.int64)
        self._buf_vals = np.empty(0, dtype=np.int64)

        self._MASK32 = np.int64((1 << 32) - 1)
        self._edge_flush_thresh = 200_000
        self._arr_ready = True

    def _ctx_reserve(self, add: int):
        need = self._ctx_size + int(add)
        if need <= self._ctx_cap:
            return
        cap = self._ctx_cap
        while cap < need:
            cap *= 2
        # reallocate & copy
        pa = np.empty(cap, dtype=np.int64); pa[:self._ctx_size] = self._parent_arr[:self._ctx_size]
        la = np.empty(cap, dtype=np.int64); la[:self._ctx_size] = self._last_tok_arr[:self._ctx_size]
        ca = np.empty(cap, dtype=np.int64); ca[:self._ctx_size] = self._clen_arr[:self._ctx_size]
        self._parent_arr, self._last_tok_arr, self._clen_arr = pa, la, ca
        self._ctx_cap = cap
    
    def _ctx_append_batch(self, parents: np.ndarray, toks: np.ndarray) -> np.ndarray:
        parents = parents.astype(np.int64, copy=False)
        toks = toks.astype(np.int64, copy=False)
        m = int(parents.size)
        self._ctx_reserve(m)
        start = self._ctx_size
        end   = start + m
        self._parent_arr[start:end]   = parents
        self._last_tok_arr[start:end] = toks
        self._clen_arr[start:end]     = self._clen_arr[parents] + 1
        self._ctx_size = end
        need = end - len(self._id2ctx)
        if need > 0:
            self._id2ctx.extend([None] * need)
        return np.arange(start, end, dtype=np.int64)
    
    def intern_many_ctx(self, ctxs: list[tuple[int, ...]]) -> np.ndarray:
        """Return ids for ctxs."""
        self._ensure_ctx_store()
        out = np.empty(len(ctxs), dtype=np.int64)
        ctx2id = getattr(self, "_ctx2id", None)
        if ctx2id is None:
            self._ctx2id = ctx2id = {(): 0}
        by_len = {}
        miss = []
        for i, t in enumerate(ctxs):
            cid = ctx2id.get(t)
            if cid is not None:
                out[i] = cid
            else:
                out[i] = -1
                miss.append(i)
                by_len.setdefault(len(t), []).append(i)
        if not miss:
            return out

        for L in sorted(by_len.keys()):
            idxs = by_len[L]
            if L == 0:
                out[idxs] = 0
                continue
            # parents and last tokens in batch
            parents = np.fromiter((ctx2id[t[:-1]] if t[:-1] in ctx2id else -1 for t in (ctxs[i] for i in idxs)), dtype=np.int64, count=len(idxs))
            # build any missing parents recursively (rare if lengths sorted)
            need_parent = np.flatnonzero(parents == -1)
            if need_parent.size:
                par_ctxs = [ctxs[idxs[j]][:-1] for j in need_parent]
                parents[need_parent] = self.intern_many_ctx(par_ctxs)

            toks = np.fromiter((int(ctxs[i][-1]) for i in idxs), dtype=np.int64, count=len(idxs))
            new_ids = self._ctx_append_batch(parents, toks)
            for i, nid in zip(idxs, new_ids):
                ctx2id[ctxs[i]] = int(nid)
            out[idxs] = new_ids
        return out

    def _intern_ctx_by_parent(self, parent_cid: int, tok: int) -> int:
        """Advance context without ever building a tuple."""
        ch = self._children.get(parent_cid)
        if ch is None:
            ch = self._children[parent_cid] = {}
        nxt = ch.get(tok)
        if nxt is not None:
            return nxt
        nid = len(self._ctx_parent)
        self._ctx_parent.append(parent_cid)
        self._ctx_last_tok.append(tok)
        self._ctx_len.append(self._ctx_len[parent_cid] + 1)
        self._children[nid] = {}
        ch[tok] = nid
        self._id2ctx.append(None)
        return nid

    def _intern_ctx(self, ctx: tuple[int, ...]) -> int:
        self._ensure_ctx_store()
        cid = 0
        for tok in ctx:
            cid = self._intern_ctx_by_parent(cid, tok)
        if self._id2ctx[cid] is None:
            self._id2ctx[cid] = tuple(ctx)
        self._ctx2id[tuple(ctx)] = cid
        return cid

    def _take_many_from_dist(self, dist, toks: np.ndarray) -> np.ndarray:
        if isinstance(dist, np.ndarray):
            return dist[toks].astype(np.float32, copy=False)
        if hasattr(dist, "_p"):  # LazyProb
            arr = np.asarray(dist._p, dtype=np.float32)
            return arr[toks]
        if hasattr(dist, "get"):
            dget = dist.get
            out = np.empty(toks.size, dtype=np.float32)
            for j, t in enumerate(toks):
                out[j] = float(dget(int(t), NEG_INF))
            return out
        return np.fromiter((float(dist.get(int(t), NEG_INF)) if hasattr(dist,"get") else float(dist[int(t)])), dtype=np.float32, count=toks.size)
    

    @profile
    def _advance_ctx_pairs_grouped(self, ctx_ids: np.ndarray, toks: np.ndarray) -> np.ndarray:
        self._ensure_ctx_store()
        ctx_ids = np.asarray(ctx_ids, dtype=np.int64, order="C")
        toks = np.asarray(toks,    dtype=np.int64, order="C")
        keys = (ctx_ids << 32) ^ (toks & self._MASK32)
        uniq_keys, inv = np.unique(keys, return_inverse=True)
        nxt = np.full(uniq_keys.size, -1, dtype=np.int64)
        base_keys = self._edge_keys
        base_vals = self._edge_vals
        if base_keys.size:
            posb = np.searchsorted(base_keys, uniq_keys)
            maskb = (posb < base_keys.size)
            if maskb.any():
                eqb = base_keys[posb[maskb]] == uniq_keys[maskb]
                if np.any(eqb):
                    m = np.zeros_like(maskb)
                    m[maskb] = eqb
                    nxt[m] = base_vals[posb[m]]

        buf_keys = self._buf_keys
        buf_vals = self._buf_vals
        if buf_keys.size:
            posu = np.searchsorted(buf_keys, uniq_keys)
            masku = (posu < buf_keys.size)
            if masku.any():
                equ = buf_keys[posu[masku]] == uniq_keys[masku]
                if np.any(equ):
                    m = np.zeros_like(masku)
                    m[masku] = equ
                    m &= (nxt == -1)
                    nxt[m] = buf_vals[posu[m]]
        miss = (nxt == -1)
        if miss.any():
            mk   = uniq_keys[miss]
            mpar = (mk >> 32).astype(np.int64, copy=False)
            mtok = (mk & self._MASK32).astype(np.int64, copy=False)
            new_ids = self._ctx_append_batch(mpar, mtok)
            nxt[miss] = new_ids
            if buf_keys.size == 0:
                self._buf_keys = mk
                self._buf_vals = new_ids
            else:
                bk = np.concatenate([buf_keys, mk], axis=0)
                bv = np.concatenate([buf_vals, new_ids], axis=0)
                order = np.argsort(bk, kind="mergesort")
                self._buf_keys = bk[order]
                self._buf_vals = bv[order]
            if self._buf_keys.size >= self._edge_flush_thresh:
                allk = np.concatenate([self._edge_keys, self._buf_keys], axis=0)
                allv = np.concatenate([self._edge_vals, self._buf_vals], axis=0)
                order = np.argsort(allk, kind="mergesort")
                self._edge_keys = allk[order]
                self._edge_vals = allv[order]
                self._buf_keys  = np.empty(0, dtype=np.int64)
                self._buf_vals  = np.empty(0, dtype=np.int64)

        return nxt[inv]
    
    @profile
    def _materialize_ctx_tuples_batch(self, ctx_ids: np.ndarray) -> None:
        self._ensure_ctx_store()
        ids = np.asarray(ctx_ids, dtype=np.int64, order="C")
        if ids.size == 0:
            return
        # Only those still missing
        miss = [cid for cid in ids if self._id2ctx[cid] is None]
        if not miss:
            return
        miss = np.asarray(miss, dtype=np.int64)
        lens = self._clen_arr[miss].astype(np.int32, copy=False)
        maxL = lens.max()
        pad = np.full((miss.size, maxL), -1, dtype=np.int32)
        cur = miss.copy()
        for t in range(maxL - 1, -1, -1):
            active = (lens > t)
            if not active.any():
                continue
            pad[active, t] = self._last_tok_arr[cur[active]].astype(np.int32, copy=False)
            cur[active] = self._parent_arr[cur[active]]
        ctx2id = getattr(self, "_ctx2id", None)
        if ctx2id is None:
            self._ctx2id = ctx2id = {(): 0}
        for row, cid in enumerate(miss):
            t = tuple(pad[row, :lens[row]].tolist())
            self._id2ctx[cid] = t
            ctx2id[t] = cid    
    
    def _dist_to_numpy(self, log_dist):
        """Return (tok_ids:int32, tok_lps:float32) with finite entries only."""
        if isinstance(log_dist, np.ndarray):
            arr = log_dist
            ids = np.flatnonzero(np.isfinite(arr)).astype(np.int32)
            return ids, arr[ids]
        if hasattr(log_dist, "_p"):  # LazyProb fast path
            arr = np.asarray(log_dist._p, dtype=np.float32)
            ids = np.flatnonzero(np.isfinite(arr)).astype(np.int32)
            return ids, arr[ids]
        # generic mapping
        ids_, lps_ = [], []
        for tok, lp in log_dist.items():
            if np.isfinite(lp):
                ids_.append(int(tok))
                lps_.append(float(lp))
        if not ids_:
            return np.empty(0, np.int32), np.empty(0, np.float32)
        return np.asarray(ids_, np.int32), np.asarray(lps_, np.float32)
    
    def _intern_ps(self, ps: frozenset[int]) -> int:
        i = self._ps2id.get(ps)
        if i is not None: 
            return i
        i = len(self._id2ps)
        self._ps2id[ps] = i
        self._id2ps.append(ps)
        self._next_by_tok.append(np.zeros(self.num_in_syms, np.int32))
        self._next_prefilled.append(False)
        self._univ_arr = np.append(self._univ_arr, -1)
        return i

    @functools.lru_cache(maxsize=200_000)
    def _prefill_next_for_ps(self, ps_id: int):
        if self._next_prefilled[ps_id]:
            return
        ps = self._id2ps[ps_id]
        next_ids = self._next_by_tok[ps_id]
        # Only fill when needed; default zeros already there.
        for t in range(self.num_in_syms):
            nxt = self._next_pset_after(ps, t)
            if nxt:
                next_ids[t] = self._intern_ps(nxt)
        self._next_prefilled[ps_id] = True

    def _ensure_univ_known(self, ps_id: int):
        if self._univ_arr[ps_id] != -1:
            return
        ps = self._id2ps[ps_id]
        self._univ_arr[ps_id] = 1 if self.is_powerset_universal(ps) else 0

    @profile
    def group_lse_by_id(self, ids: np.ndarray, offs: np.ndarray):
        ids = np.asarray(ids)
        x   = np.asarray(offs, dtype=np.float32)
        if ids.size == 0:
            return ids[:0], x[:0]
        b = np.empty(ids.size, dtype=bool)
        b[0] = True
        b[1:] = ids[1:] != ids[:-1]
        idx  = np.flatnonzero(b)
        uids = ids[idx]
        lse = np.logaddexp.reduceat(x, idx).astype(np.float32, copy=False)
        return uids, lse
    
    def _next_pset_after(self, ps: frozenset[int], tok_id: int) -> frozenset[int]: 
        nxt = set() 
        for s in ps: 
            _arcs, _ilabels, in_syms, _olabels, nextstates = self._arcs(s) 
            if in_syms.size: 
                take = (in_syms == tok_id) 
            if np.any(take): 
                ns = nextstates[take] 
                if ns.size: 
                    nxt.update(int(x) for x in ns) 
        return frozenset(nxt) 
        
    @profile
    async def logp_next_no_symloop(
        self,
        config: Config,
        *,
        context: Optional[List[str]] = None,
        quotient: Optional[List["Beam"]] = None,
        remainder: Optional[List["Beam"]] = None,
        logp_context: Optional[float] = None,
        cover_fct: Optional[Callable] = None,
        next_char: str | None = None,
        top_p=None, # Unused here
        verbose: bool = False
    ) -> Dict[str, float]:
        """
        Optimised logp_next evaluation.

        Handles
        1. Deterministic output produced on input-eps arcs, and
        2. LM-driven output produced after reading the next token.

        Requires:   
            1. Functional WFST
            2. push_labels + rmepsilon already applied
            3. symbol table contains every LM token id as input label
        """
        # Config
        track_powerstates = config.track_powerstates
        expand_threshold = config.expand_threshold
        prune_kwargs = dict(
            thld=config.prune_threshold,
            cand_thld=config.candidate_threshold,
            alpha=config.prune_threshold_alpha,
            max_prune_mass=config.max_prune_mass,
            max_cand=config.max_candidates,
        )
        log_eps_rel = np.float32(np.log(config.stop_epsilon_mass))


        # Get quotient / remainder beams
        context = () if context is None else context
        cover_fct = cover_fct or self.optimal_decomposition
        cfg_hash = _hash_cfg(config)
        if quotient is None or remainder is None:
            remainder, quotient = await cover_fct(context, config=config, cache_result=True)

        # Get context probability
        beam_logp_context = logsumexp([b[LOGP] for b in quotient] + [b[LOGP] for b in remainder]) if quotient or remainder else 0.0

        if verbose:
            _log(True, f"{'='*72}\nOPT logp_next ctx={context}  "
                    f"|Q|={len(quotient)} |R|={len(remainder)}  "
                    f"logP(ctx)={beam_logp_context:.4f}")

        # Accumulate probability mass
        mass = np.full(self.num_out_syms, NEG_INF, dtype=np.float32)

        if self._has_eps_out:
            q_beams, q_psets = [], []
            for beam, out_list in quotient.items():
                if beam[YS] and beam[YS][-1] == self.eos_out:
                    continue
                q_beams.append(beam)
                q_psets.append(frozenset(b[POWERSTATE] for b in out_list))

            if not q_beams:
                return await self.logp_next_symloop(
                    config, context=context, cover_fct=self.optimal_decomposition,
                    logp_context=logp_context, next_char=next_char, verbose=verbose
                )
            q_ctx_ids = self.intern_many_ctx([tuple(b[YS]) for b in q_beams])

            q_cond_lp = np.asarray([b[LOGP] for b in q_beams], dtype=np.float32) - float(beam_logp_context)
            per_beam_counts: list[Dict[tuple[tuple[int,...], int], int]] = []
            per_beam_pssizes: list[int] = []
            for ps in q_psets:
                item = self._local_cache.get(ps)
                if item is None:
                    item = self._pset_frontier_counts(ps)
                    self._local_cache[ps] = item
                counts, pssize = item
                per_beam_counts.append(counts)
                per_beam_pssizes.append(pssize)

            pairs_ctx_ids, pairs_seq, pairs_meta = [], [], []
            zero_len_items = []
            for i, counts in enumerate(per_beam_counts):
                offset = q_cond_lp[i]
                for (in_seq, out_lab), c in counts.items():
                    if len(in_seq) == 0:
                        zero_len_items.append((out_lab, offset))
                    else:
                        pairs_ctx_ids.append(q_ctx_ids[i])
                        pairs_seq.append(in_seq)
                        pairs_meta.append((out_lab, offset))

            for out_lab, off in zero_len_items:
                np.logaddexp.at(mass, out_lab, off)

            if pairs_seq:
                arr = np.asarray(pairs_meta, dtype=[('lab', np.int32), ('off', np.float32)])
                offsets = arr['off'].astype(np.float32, copy=False)

                # Pruning step before scoring
                survivors = _prune_by_logweights(offsets, **prune_kwargs)

                if survivors.size == 0:
                    return {self._in_id_to_sym[i]: float(mass[i]) for i in np.flatnonzero(np.isfinite(mass))}

                # Apply pruning consistently to everything
                arr = arr[survivors]
                pairs_ctx_ids = np.asarray(pairs_ctx_ids, dtype=np.int64)[survivors]
                pairs_seq = [pairs_seq[i] for i in survivors.tolist()]

                # Build flat views for vectorized advance
                cur_ctx_ids = pairs_ctx_ids.astype(np.int64, copy=False)
                N = len(pairs_seq)
                lens = np.fromiter((len(s) for s in pairs_seq), dtype=np.int64, count=N)
                if lens.any():
                    seq_lps = np.zeros(N, dtype=np.float32)
                    tot = lens.sum()
                    tok_flat = np.fromiter((tok for s in pairs_seq for tok in s), dtype=np.int32, count=tot)
                    row_flat = np.repeat(np.arange(N, dtype=np.int64), lens)

                    # depth per flat position
                    depth_flat = np.empty(tot, dtype=np.int64)
                    off = 0
                    for L, _ in zip(*np.unique(lens, return_counts=True)):
                        if L == 0: 
                            continue
                        rows_L = np.flatnonzero(lens == L)
                        blk = slice(off, off + L * rows_L.size)
                        depth_flat[blk] = np.tile(np.arange(L, dtype=np.int64), rows_L.size)
                        off = blk.stop

                    # group by depth once
                    order = np.argsort(depth_flat, kind="mergesort")
                    depth_sorted = depth_flat[order]
                    rows_sorted  = row_flat[order]
                    toks_sorted  = tok_flat[order]

                    chg = np.r_[True, depth_sorted[1:] != depth_sorted[:-1]]
                    starts = np.flatnonzero(chg)
                    ends   = np.r_[starts[1:], depth_sorted.size]

                    # contexts BEFORE each token (flat)
                    ctx_before_flat = np.empty_like(rows_sorted, dtype=np.int64)
                    p = 0
                    for s, e in zip(starts, ends):
                        rows_t = rows_sorted[s:e]
                        toks_t = toks_sorted[s:e]
                        ctx_t = cur_ctx_ids[rows_t]
                        ctx_before_flat[p:p+(e-s)] = ctx_t
                        cur_ctx_ids[rows_t] = self._advance_ctx_pairs_grouped(ctx_t, toks_t)
                        p += (e - s)

                    uniq_ctx_ids, inv = np.unique(ctx_before_flat, return_inverse=True)
                    self._materialize_ctx_tuples_batch(uniq_ctx_ids)
                    uniq_ctx = [self._id2ctx[int(cid)] for cid in uniq_ctx_ids]
                    ctx2dist = await self.gather_ctx2logdist(uniq_ctx, config, cfg_hash)

                    # score by context runs
                    ord2 = np.argsort(inv, kind="mergesort")
                    inv2 = inv[ord2]
                    rows2 = rows_sorted[ord2]
                    toks2 = toks_sorted[ord2]
                    chg2 = np.r_[True, inv2[1:] != inv2[:-1]]
                    s2 = np.flatnonzero(chg2)
                    e2 = np.r_[s2[1:], inv2.size]

                    dists_aligned = tuple(ctx2dist[ctx] for ctx in uniq_ctx)
                    vals_all = np.empty(rows2.shape[0], dtype=np.float32)
                    take = self._take_many_from_dist
                    for (s, e), g in zip(zip(s2, e2), inv2[s2]):
                        dist = dists_aligned[int(g)]
                        vals_all[s:e] = take(dist, toks2[s:e])
                    seq_lps += np.bincount(rows2, weights=vals_all, minlength=N).astype(np.float32)

                    labels = arr['lab']
                    vals   = (arr['off'] + seq_lps).astype(np.float32, copy=False)
                    ord3 = np.argsort(labels, kind='stable')
                    labels, vals = labels[ord3], vals[ord3]
                    uniq, idx = np.unique(labels, return_index=True)
                    ends = np.r_[idx[1:], labels.size]
                else:
                    # no sequences (all zero length)—only offsets contributed
                    labels = arr['lab']
                    vals = arr['off'].astype(np.float32, copy=False)
                    ord3 = np.argsort(labels, kind='stable')
                    labels, vals = labels[ord3], vals[ord3]
                    uniq, idx = np.unique(labels, return_index=True)
                    ends = np.r_[idx[1:], labels.size]

                # grouped stable log-sumexp into mass
                for u, s, e in zip(uniq, idx, ends):
                    m = float(np.max(vals[s:e]))
                    if np.isfinite(m):
                        mass[u] = np.logaddexp(mass[u], m + float(np.log(np.exp(vals[s:e] - m).sum())))
        else:
            # flatten beams, outlists, powerstate sets
            q_beams, q_outlists, q_psets = [], [], []
            for beam, out_list in quotient.items():
                q_beams.append(beam)
                q_outlists.append(out_list)
                q_psets.append(frozenset().union(b[POWERSTATE] for b in out_list))

            if q_beams:
                # Fast path for realpha
                if not track_powerstates:
                    # normalize beam logps once
                    q_cond_lp = np.asarray([b[LOGP] for b in q_beams], dtype=np.float32) - beam_logp_context

                    # map unique powerstate -> index
                    ps2idx, uniq_ps = {}, []
                    q_ps_idx = np.empty(len(q_beams), dtype=np.int32)
                    for i, ps in enumerate(q_psets):
                        j = ps2idx.get(ps)
                        if j is None:
                            j = len(uniq_ps)
                            ps2idx[ps] = j
                            uniq_ps.append(ps)
                        q_ps_idx[i] = j

                    y_eps_unique = np.empty(len(uniq_ps), dtype=np.int32)
                    for j, ps in enumerate(uniq_ps):
                        try:
                            y_eps_unique[j] = self.first_symbol_epsin_set(ps)
                        except ValueError:
                            y_eps_unique[j] = -1
                    y_eps = y_eps_unique[q_ps_idx]

                    # accumulate all ε-in at once
                    eps_mask = (y_eps >= 0)
                    if np.any(eps_mask):
                        np.logaddexp.at(mass, y_eps[eps_mask], q_cond_lp[eps_mask])
                        if verbose:
                            i0 = int(np.flatnonzero(eps_mask)[0])
                            sym = self._out_id_to_sym.get(int(y_eps[i0]), int(y_eps[i0]))
                            _log(True, f"  ε-in beam → '{sym}'  lp={q_cond_lp[i0]:.4f}")

                    # remaining beams need LM distributions
                    need_idx = np.flatnonzero(~eps_mask)
                    if need_idx.size:
                        # batch LM calls
                        dists = await asyncio.gather(
                            *(self._get_cached_log_dist(q_beams[i][YS], backend=config.backend, cfg_hash=cfg_hash) for i in need_idx)
                        )
                        # collect per-beam arrays
                        tok_ids_list, tok_lps_list, ps_idx_list, condlp_list = [], [], [], []
                        for i, dist in zip(need_idx, dists):
                            ids, lps = self._dist_to_numpy(dist)
                            if ids.size == 0:
                                continue
                            tok_ids_list.append(ids)
                            tok_lps_list.append(lps)
                            ps_idx_list.append(q_ps_idx[i])
                            condlp_list.append(q_cond_lp[i])

                        if tok_ids_list:
                            ps_idx_arr  = np.asarray(ps_idx_list, dtype=np.int32)
                            condlp_arr  = np.asarray(condlp_list, dtype=np.float32)

                            # group by powerstate id so we call first_symbol_vectorized once per group
                            order = np.argsort(ps_idx_arr, kind="stable")
                            ps_idx_arr = ps_idx_arr[order]
                            condlp_arr = condlp_arr[order]
                            tok_ids_list = [tok_ids_list[i] for i in order]
                            tok_lps_list = [tok_lps_list[i] for i in order]

                            split = np.flatnonzero(np.diff(ps_idx_arr)) + 1
                            starts = np.r_[0, split]
                            ends = np.r_[split, ps_idx_arr.size]
                            for s, e in zip(starts, ends):
                                g_idx = int(ps_idx_arr[s])
                                g_ps  = uniq_ps[g_idx]

                                ids_concat = np.concatenate(tok_ids_list[s:e])
                                lps_concat = np.concatenate(tok_lps_list[s:e])

                                # repeat beam cond_lps per-token
                                counts = np.fromiter((a.size for a in tok_ids_list[s:e]), dtype=np.int32)
                                if counts.size == 0 or ids_concat.size == 0:
                                    continue
                                cond_rep = np.repeat(condlp_arr[s:e], counts)
                                y_vec = self.first_symbol_vectorized(g_ps, ids_concat)  # vectorized per powerstate
                                good = (y_vec >= 0)
                                if not np.any(good):
                                    continue

                                np.logaddexp.at(mass, y_vec[good], cond_rep[good] + lps_concat[good])
                            if verbose:
                                kept0 = int(np.sum(self.first_symbol_vectorized(uniq_ps[int(ps_idx_arr[0])], tok_ids_list[0]) >= 0))
                                tot0  = int(tok_ids_list[0].size)
                                _log(True, f"  routed {kept0}/{tot0} tokens ({kept0/max(1,tot0):.1%}) for first LM beam")
                
                else:
                    # normalize beam logps once
                    q_cond_lp = np.asarray([b[LOGP] for b in q_beams], dtype=np.float32) - beam_logp_context
                    ctx_len = len(context) if context is not None else 0

                    exceed_mask = np.zeros(len(q_beams), dtype=bool)
                    ex_syms, ex_lps = [], []

                    # We now check for exceeded-context outputs, where we can directly aggregate the probability
                    for i, out_list in enumerate(q_outlists):
                        if not out_list:
                            continue
                        uids = set()
                        any_exceeded = False
                        for b in out_list:
                            bo = b[BEAM_OUT]
                            if len(bo) > ctx_len:
                                any_exceeded = True
                                y = bo[ctx_len]
                                if isinstance(y, (int, np.integer)):
                                    y_id = int(y)
                                else:
                                    y_id = int(self._out_sym_to_id.get(y, -1))
                                if y_id >= 0:
                                    uids.add(y_id)

                        if any_exceeded and uids:
                            exceed_mask[i] = True
                            add_lp = q_cond_lp[i]# - np.log(len(uids))
                            for y_id in uids:
                                ex_syms.append(y_id)
                                ex_lps.append(add_lp)

                    # Accumulate exceeded-context contributions
                    if ex_syms:
                        np.logaddexp.at(
                            mass,
                            np.asarray(ex_syms, dtype=np.int32),
                            np.asarray(ex_lps, dtype=np.float32),
                        )
                        if verbose:
                            sym0 = int(ex_syms[0])
                            sym_repr = self._out_id_to_sym.get(sym0, sym0)
                            _log(True, f"  exceeded→ '{sym_repr}'  beams={int(exceed_mask.sum())} (excluded from ε/LM)")

                    # Keep only beams that did NOT exceed context for ε-in / LM routing
                    keep_idx = np.flatnonzero(~exceed_mask)
                    if keep_idx.size == 0:
                        # Nothing else to do this step
                        pass
                    else:
                        # Map through the "kept" arrays
                        q_cond_lp_keep = q_cond_lp[keep_idx]
                        q_psets_keep = [q_psets[i] for i in keep_idx]
                        q_beams_keep = [q_beams[i] for i in keep_idx]
                        ps2idx, uniq_ps = {}, []
                        q_ps_idx = np.empty(len(q_beams_keep), dtype=np.int32)
                        for i, ps in enumerate(q_psets_keep):
                            j = ps2idx.get(ps)
                            if j is None:
                                j = len(uniq_ps)
                                ps2idx[ps] = j
                                uniq_ps.append(ps)
                            q_ps_idx[i] = j

                        y_eps_unique = np.empty(len(uniq_ps), dtype=np.int32)
                        for j, ps in enumerate(uniq_ps):
                            try:
                                y_eps_unique[j] = self.first_symbol_epsin_set(ps)
                            except ValueError:
                                y_eps_unique[j] = -1
                        y_eps = y_eps_unique[q_ps_idx]

                        # accumulate all ε-in at once
                        eps_mask = (y_eps >= 0)
                        if np.any(eps_mask):
                            np.logaddexp.at(mass, y_eps[eps_mask], q_cond_lp_keep[eps_mask])
                            if verbose:
                                i0 = int(np.flatnonzero(eps_mask)[0])
                                sym = self._out_id_to_sym.get(int(y_eps[i0]), int(y_eps[i0]))
                                _log(True, f"  ε-in beam → '{sym}'  lp={q_cond_lp_keep[i0]:.4f}")

                        # remaining kept beams need LM distributions
                        need_idx = np.flatnonzero(~eps_mask)
                        if need_idx.size:
                            dists = await self.gather_ctx2logdist([q_beams_keep[i][YS] for i in need_idx], config=config, cfg_hash=cfg_hash)
                            
                            # Frontier items need to be extended until the powerstate is universal.
                            frontier: list[tuple[frozenset[int], tuple[int, ...], float, int]] = []

                            # Seed frontier from the FIRST token
                            for i in need_idx:
                                base_lp = q_cond_lp_keep[i]
                                ps0 = uniq_ps[q_ps_idx[i]]
                                ctx0 = q_beams_keep[i][YS]
                                ids = np.flatnonzero(np.isfinite(dists[ctx0]))
                                if ids.size == 0:
                                    continue
                                lps =  dists[ctx0][ids]

                                # Only tokens that actually produce a FIRST output from the initial powerstate 
                                y_vec = self.first_symbol_vectorized(ps0, ids)
                                good  = (y_vec >= 0)
                                if not np.any(good):
                                    continue

                                ids, lps, y_vec = ids[good], lps[good], y_vec[good]
                                states_arr = self._states_arr(ps0)
                                uniq_tok, inv = np.unique(ids, return_inverse=True)

                                # Get unique ys
                                y_uniq = self.first_symbol_vectorized(ps0, uniq_tok)
                                M = self._first[np.ix_(states_arr, uniq_tok)]
                                valid  = M >= 0
                                match  = (M == y_uniq[np.newaxis, :])
                                frac_uniq = (match.sum(0) / np.maximum(valid.sum(0), 1)).astype(np.float32)
                                frac = frac_uniq[inv]

                                # Apply support fraction as a log penalty
                                offs = (base_lp + lps + np.log(frac)).astype(np.float32, copy=False)
                                uniq_tok = np.unique(ids)
                                # Precompute once per token for this powersate:
                                y_for_tok = self.first_symbol_vectorized(ps0, uniq_tok)
                                nxt_map = {t: self._next_pset_after_with_y(ps0, t, y) for t, y in zip(uniq_tok, y_for_tok)}

                                for t, lp, y in zip(ids, offs, y_vec):
                                    ps1 = nxt_map[t]
                                    if not ps1:
                                        continue
                                    if self.is_powerset_universal(ps1):
                                        np.logaddexp.at(mass, y, lp)
                                    else:
                                        frontier.append((ps1, ctx0 + (t,), lp, y))

                            if frontier:
                                best = {}
                                dedup = []
                                for ps, ctx, acc_lp, y0 in frontier:
                                    ps_id = self._intern_ps(ps)
                                    key = (ps_id, ctx)
                                    prev = best.get(key)
                                    if (prev is None) or (acc_lp > prev[0]):
                                        best[key] = (acc_lp, y0, ps)
                                dedup = [(ps, ctx, lp, y0) for (ps_id, ctx), (lp, y0, ps) in best.items()]
                                lps0 = np.fromiter((lp for _, _, lp, _ in dedup), dtype=np.float32)
                                keep_idx = _prune_by_logweights(lps0, **prune_kwargs)
                                frontier = [dedup[i] for i in keep_idx]

                            def lse_1d(x):
                                if not x:
                                    return np.float32(NEG_INF)
                                return np.logaddexp.reduce(np.asarray(x, dtype=np.float32))

                            steps = 0
                            M_total = np.float32(NEG_INF) # running logsum of collected universal mass

                            while frontier and steps < expand_threshold:
                                steps += 1

                                # Early stop bofore expanding based on mass
                                frontier_lps = [it[2] for it in frontier]
                                R_bound = lse_1d(frontier_lps) # upper bound on leftover mass
                                if R_bound - np.logaddexp(R_bound, M_total) <= log_eps_rel:
                                    break

                                by_ctx = defaultdict(list)
                                for ps, ctx, acc_lp, y0 in frontier:
                                    by_ctx[ctx].append((ps, acc_lp, y0))
                                uniq_ctx = list(by_ctx.keys())

                                ctx2dist = await self.gather_ctx2logdist(uniq_ctx, config, cfg_hash)

                                new_frontier = []
                                new_frontier_lps = []
                                added_univ_lps = []  # track universal mass increments to update M_total

                                for ctx, items in by_ctx.items():
                                    ids = np.flatnonzero(np.isfinite(ctx2dist[ctx]))
                                    if ids.size == 0:
                                        continue
                                    lps = ctx2dist[ctx][ids]

                                    active_ctx, scores_keep_base = self.group_lse_by_id(ids, lps)
                                    if active_ctx.size == 0:
                                        continue

                                    # keep_rel_idx_ctx = _prune_by_logweights(scores_ctx_base, **prune_kwargs)
                                    # if keep_rel_idx_ctx.size == 0:
                                    #     continue
                                    # active_ctx = active_all_ctx[keep_rel_idx_ctx]
                                    # scores_keep_base = scores_ctx_base[keep_rel_idx_ctx]

                                    for ps, acc_lp, y0 in items:
                                        scores_keep = scores_keep_base + np.float32(acc_lp)
                                        ps_id = self._intern_ps(ps)
                                        if self._next_prefilled[ps_id]:
                                            next_ids = self._next_by_tok[ps_id][active_ctx]
                                        else:
                                            u_toks, inv = np.unique(active_ctx, return_inverse=True)
                                            nids_u = np.empty(u_toks.size, dtype=np.int32)
                                            for i, t in enumerate(u_toks):
                                                nxt = self._next_pset_after(ps, t)
                                                nids_u[i] = 0 if not nxt else self._intern_ps(nxt)
                                            self._next_by_tok[ps_id][u_toks] = nids_u
                                            self._next_prefilled[ps_id] = True
                                            next_ids = nids_u[inv]  # map back to active_ctx order

                                        valid = (next_ids > 0)
                                        if not valid.any():
                                            continue

                                        # safer vector indexing
                                        labels = np.take(self._univ_arr, next_ids)
                                        unknown_mask = valid & (labels == -1)
                                        if unknown_mask.any():
                                            for nid in np.unique(next_ids[unknown_mask]).tolist():
                                                self._ensure_univ_known(int(nid))
                                            labels = np.take(self._univ_arr, next_ids)

                                        univ = valid & (labels == 1)
                                        nonu = valid & ~univ

                                        if univ.any():
                                            total = np.logaddexp.reduce(scores_keep[univ])
                                            mass[y0] = np.logaddexp(mass[y0], total)
                                            added_univ_lps.append(total) # collect for global M update

                                        if nonu.any():
                                            toks_keep = active_ctx[nonu]
                                            nids_keep = next_ids[nonu]
                                            lps_keep  = scores_keep[nonu]
                                            new_frontier.extend(
                                                (self._id2ps[nid], ctx + (t,), lp_t, y0)
                                                for t, nid, lp_t in zip(toks_keep, nids_keep, lps_keep)
                                            )
                                            new_frontier_lps.extend(lps_keep.tolist())

                                if not new_frontier:
                                    break

                                # Update global collected mass
                                if added_univ_lps:
                                    M_total = np.logaddexp(M_total, lse_1d(added_univ_lps))

                                R_tight = lse_1d(new_frontier_lps)
                                if (R_tight - np.logaddexp(R_tight, M_total) <= log_eps_rel):
                                    break

                                frontier = new_frontier

        # Unused and not tested
        if remainder and not config.ignore_remainder:
            eos_id = self._out_sym_to_id.get(self.eos_out)
            if eos_id is not None:
                rem_cond = np.asarray([b[LOGP] for b in remainder], dtype=np.float32) - beam_logp_context
                rem_dists = await asyncio.gather(
                    *(self._get_cached_log_dist(b[YS], backend=config.backend, cfg_hash=cfg_hash) for b in remainder)
                )

                # try to read EOS from all dist types
                def _eos_logp(d):
                    if isinstance(d, torch.Tensor):
                        # if LM is over input ids, and eos maps to that space, map to its id:
                        # fall back to None if unknown
                        try:
                            return float(d.detach().cpu().numpy()[eos_id])
                        except Exception:
                            return None
                    if hasattr(d, "_p"):
                        try:
                            return float(np.asarray(d._p, dtype=np.float32)[eos_id])
                        except Exception:
                            return None
                    return float(d[self.eos_out]) if (hasattr(d, "get") and self.eos_out in d) else None

                lp_eos = [c + v for c, d in zip(rem_cond, rem_dists) if (v := _eos_logp(d)) is not None]
                if lp_eos:
                    np.logaddexp.at(mass, eos_id, logsumexp(lp_eos))
        
        valid = mass
        finite = np.isfinite(valid)

        # Fallback to original logp_next if no valid mass
        if not np.any(finite):
            _log(True, "No valid mass, falling back to original logp_next")
            return await self.logp_next_symloop(
                    config, context=context, logp_context=logp_context, next_char=next_char,
                )
        # Renormalize distrbution
        m  = float(valid[finite].max())
        logZ = m + np.log(np.exp(valid[finite] - m).sum())
        return {
            i: (
                float(mass[i] - logZ) if np.isfinite(mass[i]) else float("-inf")
            )
            for i in range(1, self.num_out_syms)
        }

    async def logp_next_symloop(
        self,
        config: "Config",
        context: Optional[List[str]] = None,
        quotient: Optional[List["Beam"]] = None,
        remainder: Optional[List["Beam"]] = None,
        logp_context: Optional[float] = None,
        cover_fct: Optional[Callable] = None,
        next_char: str = None,
        top_p: float = 1.0,  # e.g. 0.99 or 0.999 to score only the n-gram "head" exactly
        ngram_alpha: float = 1e-5,
        verbose: bool = False
    ) -> Dict[str, float]:
        """
        Return {symbol_id: logp(symbol | context)} by scoring next symbols one by one.

        ORIGINAL (reference) VERSION: for every candidate output symbol ``sid`` it
        calls ``self.logp`` on ``context + (sid,)`` and subtracts ``logp_context``.

        Args:
            config: run configuration forwarded to ``self.logp``. ``config.batched``
                selects sequential scoring (True) vs. concurrent ``asyncio.gather``
                scoring (False).
            context: symbols emitted so far; ``None`` means empty. It is concatenated
                with ``(sid,)``, so it must be a tuple.
            quotient, remainder, cover_fct: forwarded to ``self.logp`` only when
                ``logp_context`` has to be computed here.
            logp_context: log-probability of ``context``. It is computed if ``None``,
                and recomputed via ``self.backtrack_if_nan`` if it is infinite.
            next_char: the symbol whose ``self.logp`` result is cached
                (``cache_result=(sid == next_char)``).
            top_p: if < 1.0, only the smallest prefix of symbols, ranked by n-gram
                probability, whose cumulative n-gram mass reaches ``top_p`` is scored
                exactly. The remaining symbols keep their n-gram probabilities as a tail.
            ngram_alpha: smoothing value passed to ``ngram_probs``.
            verbose: enable ``_log`` tracing.

        Returns:
            Dict mapping symbol id -> normalized log-probability. When ``top_p >= 1.0``
            (or every symbol is a candidate), the keys are the scored candidate ids.
            Otherwise the keys are all ids in ``self._all_valid_ids``.
        """
        context = () if context is None else context
        cover_fct = cover_fct or self.optimal_decomposition

        _log(verbose, "="*80)
        _log(verbose, f"ORIGINAL logp_next: context={context}, next_char={next_char}")

        if logp_context is None:
            _log(verbose, "Computing logp_context...")
            logp_context = await self.logp(
                context=context,
                quotient=quotient,
                remainder=remainder,
                config=config,
                cover_fct=cover_fct
            )
            _log(verbose, f"logp_context = {logp_context:.6f}")
        elif np.isinf(logp_context):
            _log(verbose, "logp_context is inf, backtracking...")
            logp_context = await self.backtrack_if_nan(config=config, context=context)

        all_ids = self._all_valid_ids
        num_out_syms = self.num_out_syms

        # Choose which symbols are scored exactly.
        q_full = None
        if top_p >= 1.0:
            cand_ids = all_ids
        else:
            q_full = ngram_probs(context, self._ngram, alpha=ngram_alpha)
            sorted_ids = np.argsort(q_full)[::-1]
            cumsum = np.cumsum(q_full[sorted_ids])
            # Smallest prefix whose cumulative n-gram mass reaches top_p.
            cutoff = bisect.bisect_left(cumsum, top_p) + 1
            cand_ids = sorted_ids[:cutoff].tolist()

        scores: Dict[Any, float] = {}
        if config.batched:
            for sid in cand_ids:
                lp = await self.logp(config=config, context=context + (sid,), cache_result=(sid == next_char))
                scores[sid] = lp - logp_context
        else:
            async def _worker(sid):
                lp = await self.logp(config=config, context=context + (sid,), cache_result=(sid == next_char))
                return sid, lp - logp_context
            # gather preserves input order, so scores follows cand_ids order
            pairs = await asyncio.gather(*(_worker(s) for s in cand_ids))
            scores = dict(pairs)

        # Every symbol was scored exactly: plain renormalization over the candidates.
        if len(cand_ids) == num_out_syms or top_p >= 1.0:
            vals = np.fromiter(scores.values(), dtype=np.float32)
            m = float(vals.max())
            logZ = m + np.log(np.exp(vals - m).sum())
            return {sid: lp - logZ for sid, lp in scores.items()}

        # Partial scoring: exact scores for the head and n-gram probabilities for the
        # tail, normalized jointly by log_total.
        log_u = np.fromiter(scores.values(), dtype=np.float32)
        log_s_head = logsumexp(log_u)
        head_mass = float(q_full[cand_ids].sum())
        # Float rounding can push the head mass marginally above 1. log1p(-x) for
        # x > 1 is NaN, which would poison every output, so clamp to 1 (tail = -inf).
        log_tail_mass = np.log1p(-min(head_mass, 1.0))
        log_total = np.logaddexp(log_s_head, log_tail_mass)

        cand_ids_a = np.asarray(cand_ids, dtype=np.int32)
        cand_scores = np.fromiter(
            (scores[sid] for sid in cand_ids),
            dtype=np.float32,
            count=cand_ids_a.size,
        )
        # Work in float64: the 1e-300 floor underflows to 0 in float32, which
        # turned zero-probability symbols into log(0) = -inf.
        out_vals = np.log(np.asarray(q_full, dtype=np.float64) + 1e-300) - log_total
        out_vals[cand_ids_a] = cand_scores - log_total
        return {sid: float(out_vals[sid]) for sid in all_ids}
    

    async def logp(
        self,
        config: Config,
        context: Optional[Tuple[int]] = (),
        quotient: Optional[List[Beam]] = None,
        remainder: Optional[List[Beam]] = None,
        cover_fct: Optional[Callable] = None,
        cache_result: bool = True,
    ):
        """
        Return the log-probability mass of ``context`` from its decomposition.

        The decomposition is a (remainder, quotient) pair of beam lists. If either
        is not supplied, both are computed with ``cover_fct`` (default
        ``self.optimal_decomposition``). The result is:

        * ``logsumexp(quotient logps)`` when the remainder is empty or
          ``config.ignore_remainder`` is set;
        * ``logaddexp(logsumexp(quotient logps), logsumexp(remainder logps) + eos_logp)``
          otherwise, where ``eos_logp`` comes from ``self.eos_prob``.

        An empty quotient contributes no mass (``-inf``).

        Results are memoized in ``self._cover_logp_cache`` under
        ``(context, _hash_cfg(config))``. A cached value is returned whenever one is
        present. New values are stored only when ``cache_result`` is true.
        NOTE(review): the cache key ignores caller-supplied quotient/remainder.

        Args:
            config: configuration; its hash is part of the cache key, and
                ``ignore_remainder`` / ``verbose`` are read here.
            context: the symbol sequence to score.
            quotient: optional precomputed quotient beams (``beam[LOGP]`` is read).
            remainder: optional precomputed remainder beams (``beam[LOGP]`` is read).
            cover_fct: decomposition function used when quotient/remainder are missing.
            cache_result: forwarded to ``cover_fct`` and controls storing into the cache.

        Returns:
            The log-probability (possibly ``-inf``).
        """
        key = (context, _hash_cfg(config))
        if key in self._cover_logp_cache:
            return self._cover_logp_cache[key]
        if quotient is None or remainder is None:
            cover_fct = cover_fct or self.optimal_decomposition
            remainder, quotient = await cover_fct(
                context,
                config=config,
                cache_result=cache_result,
            )
        # An empty beam list carries no probability mass. Guard explicitly, because
        # logsumexp over an empty sequence may raise instead of returning -inf.
        quotient_logp = logsumexp([beam[LOGP] for beam in quotient]) if quotient else NEG_INF
        if not remainder or config.ignore_remainder:
            if cache_result:
                self._cover_logp_cache[key] = quotient_logp
            return quotient_logp
        _log(config.verbose, "Not ignoring remainder")
        eos_logp = await self.eos_prob(
            context=context,
            config=config,
            cover_fct=cover_fct,
        )
        _log(config.verbose, f"EOS logp {eos_logp}")
        # Remainder beams only count if the sequence terminates, so weight them by EOS.
        remainder_logp = logsumexp([beam[LOGP] for beam in remainder])
        reweighted_remainder = remainder_logp + eos_logp
        final_logp = logsumexp([quotient_logp, reweighted_remainder])
        if cache_result:
            self._cover_logp_cache[key] = final_logp
        return final_logp