from __future__ import annotations

from tokenizer_conversion.machines.transducedLM import TransducedLM
from tokenizer_conversion.machines.utils.symbols import build_label_acceptor_bytes, build_label_acceptor
from tokenizer_conversion.machines.utils.utils import _log

import pynini


from functools import reduce
from typing import List, Sequence

def _build_input_acceptor(
    symbols: Sequence[int | str],
    symtable: pynini.SymbolTable,
    use_bytes: bool = True,
    verbose: bool = False,
) -> pynini.Fst:
    """Return an acceptor for the in-order concatenation of ``symbols``.

    Each element of ``symbols`` is looked up in ``symtable`` with
    ``SymbolTable.find``. The result is turned into a single-label fragment by
    ``build_label_acceptor_bytes`` (``use_bytes=True``) or
    ``build_label_acceptor`` (``use_bytes=False``). The fragments are then
    concatenated in order.

    Args:
        symbols: Symbols to accept, in order. ``find`` maps a symbol string to
            its index and an index to its symbol string. Whichever side is
            returned is what gets passed to the fragment builder.
        symtable: Table the symbols are looked up in.
        use_bytes: Selects the fragment builder (see above).
        verbose: Log progress via ``_log``.

    Returns:
        The concatenated acceptor. If ``symbols`` is empty, a state-less
        ``pynini.Fst()`` is returned, which accepts nothing.

    Raises:
        ValueError: If an element of ``symbols`` is not in ``symtable``.
    """
    build = build_label_acceptor_bytes if use_bytes else build_label_acceptor

    acceptors: List[pynini.Fst] = []
    for s in symbols:
        label = symtable.find(s)
        # ``find`` signals a miss with -1 (NO_LABEL) for symbol-string keys and
        # with "" for integer keys; treat both as "not found".
        if label == -1 or label == "":
            raise ValueError(f"Character {s!r} not found in symbol table.")
        acceptors.append(build(label))

    if not acceptors:
        # Empty sequence -> state-less FST (empty language). ``precover``
        # checks ``num_states()`` to detect this and falls back to sigma*.
        return pynini.Fst()

    _log(verbose, "Concatenating", len(acceptors), "acceptor fragments …")
    # Chaining ``a + b`` via ``reduce`` copies the growing prefix at every step,
    # which is quadratic in the number of fragments. Instead, concatenate in
    # place into a copy of the first fragment. The copy also avoids mutating an
    # FST the builder might share or cache.
    result = acceptors[0].copy()
    for fragment in acceptors[1:]:
        result.concat(fragment)
    return result

def _build_sigma_star(symtable: pynini.SymbolTable) -> pynini.Fst:
    """Build a one-state acceptor for sigma*, i.e. every string over ``symtable``.

    The single state is both start and final (weight one). It carries a
    weight-one self-loop for every non-zero label in ``symtable``. Both symbol
    tables of the result are set to ``symtable``.
    """
    sigma_star = pynini.Fst()
    hub = sigma_star.add_state()
    sigma_star.set_start(hub)
    one = pynini.Weight.one(sigma_star.weight_type())
    sigma_star.set_final(hub, one)

    for key, _symbol in symtable:
        if not key:
            # Skip label 0 (epsilon by OpenFst convention).
            continue
        sigma_star.add_arc(hub, pynini.Arc(key, key, one, hub))

    sigma_star.set_input_symbols(symtable)
    sigma_star.set_output_symbols(symtable)
    return sigma_star

def precover(
    fst: "TransducedLM",
    sequence: str,
    llm,
    sigma: Sequence[int] | None = None,
    *,
    determinize: bool = True,
    comp_universal: bool = True,
    use_bytes: bool = True,
    rmrps: bool = True,
    use_llm: bool = False,
    verbose: bool = False,
    visualize: bool = False,
) -> "TransducedLM":
    """Construct the precover FST ``fst @ (sigma + sigma*)``, projected onto its input side.

    Params
    ----------
    fst
        ``TransducedLM`` wrapper. Its underlying transducer is read from
        ``fst.fst``.
    sequence
        Input sequence used to derive ``sigma`` when ``sigma`` is None. It is
        also used to name the debug visualization files.
    llm
        Language model handle forwarded to every ``TransducedLM`` built here.
    sigma
        Optional sequence of output symbols, looked up in
        ``fst.fst.output_symbols()``. When None, it is computed via
        ``fst.apply``.
    determinize
        If True, remove epsilons from the projected FST and determinize it.
    comp_universal
        If True, call ``compute_universal_states()`` on the result. Otherwise
        every state is marked universal (debugging aid).
    use_bytes
        Only ``True`` is supported. ``False`` raises ``NotImplementedError``.
    rmrps
        Used only when ``determinize`` is False. If True, remove epsilons from
        the projected FST; if False, return it as is.
    use_llm
        Forwarded to ``fst.apply`` when ``sigma`` is None.
    verbose, visualize
        Debugging. ``verbose`` logs progress. ``visualize`` writes SVGs under
        ``precover/``.

    Returns
    -------
    TransducedLM
        Wrapper around the precover FST. It carries ``fst.eos_out`` and, when
        present, ``fst.special_tokens``.

    Raises
    ------
    NotImplementedError
        If ``use_bytes`` is False.
    """
    if not use_bytes:
        raise NotImplementedError("precover currently supports only use_bytes=True")

    if sigma is None:
        sigma = fst.apply(sequence, input_tokens=None, use_llm=use_llm, bytes_input=use_bytes, join_char=False)

    def _debug_path(stage: str) -> str:
        # ``sequence`` is arbitrary text. Replace path separators so the SVG
        # always lands directly inside ``precover/`` and the path stays valid.
        safe = f"{sequence}".replace("/", "_").replace("\\", "_")
        return f"precover/{safe}_{stage}.svg"

    if visualize:
        import os

        os.makedirs("precover", exist_ok=True)
        fst.visualize(_debug_path("original"))

    out_syms = fst.fst.output_symbols()

    _log(verbose, "Constructing input acceptor …")
    input_acceptor = _build_input_acceptor(sigma, out_syms, verbose=verbose, use_bytes=use_bytes)
    input_acceptor.set_input_symbols(out_syms)
    input_acceptor.set_output_symbols(out_syms)

    if visualize and input_acceptor.num_states():
        TransducedLM(input_acceptor, llm=llm).visualize(_debug_path("input"))

    _log(verbose, "Constructing sigma* …")
    sigma_star = _build_sigma_star(out_syms)
    if visualize:
        TransducedLM(sigma_star, llm=llm).visualize(_debug_path("sigma_star"))

    _log(verbose, "Concatenating input acceptor with sigma* …")
    # An empty ``sigma`` yields a state-less (empty-language) acceptor.
    # Concatenating with it would give the empty language, so use sigma* alone.
    composed = (input_acceptor + sigma_star) if input_acceptor.num_states() else sigma_star
    composed.set_input_symbols(out_syms)
    composed.set_output_symbols(out_syms)

    if visualize:
        TransducedLM(composed, llm=llm).visualize(_debug_path("composed"))

    _log(verbose, "Composing with original FST …")
    pre_p = fst.fst @ composed
    pre_p.set_input_symbols(fst.fst.input_symbols())
    pre_p.set_output_symbols(out_syms)

    # Keep only the input side of the composition. ``pre_p`` is not used
    # afterwards, so it can be projected in place without a copy.
    pre_p.project("input")

    if determinize:
        p = pynini.determinize(pynini.rmepsilon(pre_p))
    elif rmrps:
        p = pynini.rmepsilon(pre_p)
    else:
        p = pre_p

    wrapped_p = TransducedLM(p, llm=llm)
    if comp_universal:
        wrapped_p.compute_universal_states()
    else:
        # Debugging aid: mark every state universal instead of computing it.
        wrapped_p.universal_states = {s: True for s in p.states()}
    wrapped_p.eos_out = fst.eos_out
    if hasattr(fst, "special_tokens"):
        wrapped_p.special_tokens = fst.special_tokens

    if visualize:
        wrapped_p.visualize(_debug_path("P"))

    return wrapped_p
