import pynini
from tokenizer_conversion.machines.transducedLM import TransducedLM
from tokenizer_conversion.machines.hf_realpha import build_hf_fst_bytes

def construct_lowercase_bytes(model_name, EOS_OUT="257", EPS="258"):
    """Build a byte-level ASCII lowercasing transducer wrapped in a TransducedLM.

    The transducer works over symbols named ``"0"`` .. ``"255"`` (one per byte
    value). Bytes 65..90 (ASCII ``A``-``Z``) are rewritten to the byte 32
    higher (``a``-``z``); every other byte maps to itself.

    Args:
        model_name: Forwarded to ``TransducedLM`` as ``llm_name``.
        EOS_OUT: Name of the end-of-sequence symbol. When truthy it is added
            to the symbol table and accepted unchanged through a self-loop on
            the start state. When falsy no EOS symbol or arc is created.
        EPS: Name of the symbol registered first in the symbol table, before
            any byte symbol (presumably the epsilon label; it is added without
            an explicit key).

    Returns:
        The ``TransducedLM`` wrapping the lowercasing FST, with universal
        states computed and ``name`` set to ``"lowercase"``.
    """
    symbols = pynini.SymbolTable()
    # EPS goes in first with no explicit key; the byte symbols that follow
    # are keyed explicitly as byte value + 1.
    symbols.add_symbol(EPS)
    for byte in range(256):
        symbols.add_symbol(str(byte), byte + 1)
    if EOS_OUT:
        symbols.add_symbol(EOS_OUT)
        eos_label = symbols.find(EOS_OUT)

    def byte_acceptor(value):
        # Single-symbol acceptor for one byte, resolved through `symbols`.
        return pynini.accep(str(value), token_type=symbols)

    # Closure over the union of all 256 byte symbols: the alphabet that
    # cdrewrite operates on.
    any_bytes = pynini.union(
        *[byte_acceptor(byte) for byte in range(256)]
    ).optimize().closure()

    # One input:output pair per byte; only A-Z get a different output.
    byte_pairs = [
        pynini.cross(
            byte_acceptor(byte),
            byte_acceptor(byte + 32 if 65 <= byte <= 90 else byte),
        )
        for byte in range(256)
    ]
    lowering_rule = pynini.union(*byte_pairs).optimize()

    # Empty left/right contexts: the rule applies at every position.
    lowercaser = pynini.cdrewrite(lowering_rule, "", "", any_bytes).optimize()

    if EOS_OUT:
        # Let EOS pass through unchanged via a self-loop on the start state.
        origin = lowercaser.start()
        eos_loop = pynini.Arc(
            eos_label, eos_label, pynini.Weight.one("tropical"), origin
        )
        lowercaser.add_arc(origin, eos_loop)

    lowercaser.set_input_symbols(symbols)
    lowercaser.set_output_symbols(symbols)

    # Wrap the FST together with the language model.
    transduced = TransducedLM(lowercaser, llm_name=model_name, use_genlm=True)
    transduced.compute_universal_states(verbose=False)
    transduced.name = "lowercase"
    return transduced


def construct_lowercase_bytes_composed(model_name, EOS_OUT="⭑"):
    """Compose the HF tokenizer FST with the byte lowercasing FST.

    Builds the tokenizer machine via ``build_hf_fst_bytes``, builds the
    lowercasing machine via ``construct_lowercase_bytes`` using the tokenizer
    machine's ``eps`` and ``eos_out`` symbols, composes the two FSTs and wraps
    the result in a ``TransducedLM`` that reuses the tokenizer machine's
    ``llm``.

    Args:
        model_name: Passed to ``build_hf_fst_bytes`` as both ``hf_tokenizer``
            and ``llm_name``, and to ``construct_lowercase_bytes``.
        EOS_OUT: Ignored. The EOS symbol is always taken from the tokenizer
            machine's ``eos_out``; the parameter is kept only so existing
            call sites keep working.

    Returns:
        A ``TransducedLM`` over the optimized composition, with ``eos_out``
        set to the tokenizer machine's EOS symbol and every state marked
        universal.
    """
    tokenizer = build_hf_fst_bytes(hf_tokenizer=model_name, llm_name=model_name)
    epsilon_symbol = tokenizer.eps
    eos_symbol = tokenizer.eos_out

    lowercaser = construct_lowercase_bytes(
        model_name, EOS_OUT=eos_symbol, EPS=epsilon_symbol
    )

    # Tokenizer output feeds the lowercaser input; the composition keeps the
    # tokenizer's input symbols and the lowercaser's output symbols.
    pipeline = tokenizer.fst @ lowercaser.fst
    pipeline.set_input_symbols(tokenizer.fst.input_symbols())
    pipeline.set_output_symbols(lowercaser.fst.output_symbols())
    pipeline.optimize()

    transduced = TransducedLM(pipeline, llm=tokenizer.llm)
    # Drop the local reference to the tokenizer machine; its llm is now held
    # by `transduced`.
    del tokenizer

    transduced.eos_out = eos_symbol
    # Every state is flagged universal directly instead of calling
    # compute_universal_states().
    # NOTE(review): confirm this holds for the composed machine.
    transduced.universal_states = dict.fromkeys(transduced.fst.states(), True)
    return transduced
