import pynini
import operator
from functools import reduce
from tokenizer_conversion.machines.transducedLM import TransducedLM

def b(ch: str, symbols) -> pynini.Fst:
    """Build an acceptor for the UTF-8 byte sequence of ``ch``.

    Each byte of ``ch.encode("utf-8")`` is turned into its own acceptor
    whose string is the byte's decimal value (e.g. ``"98"`` for ``"b"``),
    compiled with ``symbols`` as the ``token_type``. The per-byte acceptors
    are then concatenated left to right with ``+``.

    Args:
        ch: Text to encode. A multi-byte character yields several
            concatenated acceptors.
        symbols: Passed to ``pynini.accep`` as ``token_type``.

    Returns:
        The concatenation of the per-byte acceptors.

    Raises:
        TypeError: If ``ch`` is empty, because ``reduce`` is called without
            an initial value.
    """
    def byte_acceptor(value: int) -> pynini.Fst:
        # Symbols are the decimal spelling of the byte value.
        return pynini.accep(str(value), token_type=symbols)

    return reduce(operator.add, map(byte_acceptor, ch.encode("utf-8")))


def bs(text: str, symbols) -> pynini.Fst:
    """Build an acceptor for the UTF-8 bytes of every character in ``text``.

    Starts from ``pynini.accep("")`` and appends ``b(c, symbols)`` for each
    character ``c`` of ``text`` in order, so an empty ``text`` returns the
    empty-string acceptor unchanged.

    Args:
        text: Characters to convert, one ``b`` call per character.
        symbols: Forwarded to ``b`` for every character.

    Returns:
        The concatenation of the seed acceptor and all per-character FSTs.
    """
    result = pynini.accep("")
    for char in text:
        result = result + b(char, symbols)
    return result


def construct_bad_ungood(model_name):
    """Build the byte-level "bad" -> "ungood" rewriter wrapped in a TransducedLM.

    Args:
        model_name: Forwarded to ``TransducedLM`` as ``llm_name``.

    Returns:
        The ``TransducedLM`` instance, after ``compute_universal_states`` has
        been called on it and its ``name`` set to ``"bad_ungood"``.
    """
    # Build the fst on the right in figure 2
    byte_values = range(256)

    # Label 0 carries the name "257" (presumably a placeholder for epsilon,
    # which is label 0 in OpenFst -- confirm). Byte value v is stored under
    # the symbol str(v) at label v + 1.
    symtab = pynini.SymbolTable()
    symtab.add_symbol("257", 0)
    for value in byte_values:
        symtab.add_symbol(str(value), value + 1)

    # Sigma* over the 256 byte symbols.
    byte_acceptors = [
        pynini.accep(str(value), token_type=symtab) for value in byte_values
    ]
    sigma_star = pynini.union(*byte_acceptors).optimize().closure()

    # "bad" is UTF-8 bytes 0x62 0x61 0x64, i.e. the symbols "98" "97" "100".
    sub = pynini.cross(bs("bad", symtab), bs("ungood", symtab))

    # Rewrite with empty left and right contexts.
    replace_bad = pynini.cdrewrite(sub, "", "", sigma_star).optimize()
    replace_bad.set_input_symbols(symtab)
    replace_bad.set_output_symbols(symtab)
    # | Machine constructed

    # Build the Transduced LM
    wrapped = TransducedLM(replace_bad, llm_name=model_name, use_genlm=True)
    wrapped.compute_universal_states(verbose=False)
    wrapped.name = "bad_ungood"
    return wrapped