import pynini
from pynini import SymbolTable

def _bind_label(table: SymbolTable, label: int, side: str) -> None:
    """Add ``chr(label)`` to *table* under key *label*, reporting failures.

    *side* ("input" or "output") is used only in the diagnostic messages.
    """
    try:
        char = chr(label)
    except (ValueError, OverflowError):
        # chr() raises ValueError outside 0..0x10FFFF and OverflowError for
        # values that do not fit in a C int (e.g. very large or negative).
        print(f"Invalid character for {side} label {label}")
        return
    bound_key = table.add_symbol(char, label)
    if bound_key != label:
        # The symbol table keeps the existing key when the symbol string is
        # already present, e.g. label 949 (chr -> "ε") collides with the
        # epsilon entry; surface that instead of losing the label silently.
        print(
            f"Symbol {char!r} for {side} label {label} "
            f"is already bound to key {bound_key}"
        )


def attach_symbol_table(fst: pynini.Fst) -> pynini.Fst:
    """Attach character-based input and output symbol tables to *fst*.

    Label 0 is mapped to "ε" in both tables. Every other arc label is mapped
    to ``chr(label)``: input labels go into the input table only, output
    labels into the output table only. Labels that cannot be converted with
    chr() are reported on stdout and left without a symbol.

    Args:
        fst: The FST to annotate. It is modified in place.

    Returns:
        The same *fst* object, with its input and output symbol tables set.
    """
    # Collect each distinct label once so it is bound (and any error is
    # reported) a single time rather than once per arc.
    in_labels = set()
    out_labels = set()
    for state in fst.states():
        for arc in fst.arcs(state):
            in_labels.add(arc.ilabel)
            out_labels.add(arc.olabel)
    # Epsilon (0) is bound explicitly below rather than through chr().
    in_labels.discard(0)
    out_labels.discard(0)

    symtab_in = SymbolTable()
    symtab_out = SymbolTable()
    symtab_in.add_symbol("ε", 0)
    symtab_out.add_symbol("ε", 0)
    # Sorted so the tables are built in a deterministic key order.
    for label in sorted(in_labels):
        _bind_label(symtab_in, label, "input")
    for label in sorted(out_labels):
        _bind_label(symtab_out, label, "output")

    fst.set_input_symbols(symtab_in)
    fst.set_output_symbols(symtab_out)
    return fst


def build_label_acceptor(label: int) -> pynini.Fst:
    """Return a two-state FST with a single arc labelled *label* on both tapes.

    The arc weight is the semiring one of the new FST's own weight type, and
    the arc's destination is the only final state.
    """
    acceptor = pynini.Fst()
    src, dst = acceptor.add_state(), acceptor.add_state()
    acceptor.set_start(src)
    one = pynini.Weight.one(acceptor.weight_type())
    acceptor.add_arc(src, pynini.Arc(label, label, one, dst))
    acceptor.set_final(dst)
    return acceptor

def build_label_acceptor_bytes(label: int) -> pynini.Fst:
    """Return a two-state FST with a single arc labelled *label* on both tapes.

    The arc weight is the semiring one of the FST's own weight type (derived
    from the FST rather than hard-coded, so it cannot drift from the arc type
    the FST is constructed with). The arc's destination is the only final
    state.

    NOTE(review): despite the name, no byte-range check is performed; *label*
    is used exactly as given.
    """
    label_fst = pynini.Fst()
    s0 = label_fst.add_state()
    label_fst.set_start(s0)
    s1 = label_fst.add_state()
    label_fst.add_arc(
        s0,
        pynini.Arc(label, label, pynini.Weight.one(label_fst.weight_type()), s1),
    )
    label_fst.set_final(s1)
    return label_fst