import pynini
from transformers import AutoTokenizer


from genlm.backend import load_model_by_name
from genlm.backend.tokenization.bytes import get_byte_vocab
from tokenizer_conversion.machines.transducedLM import TransducedLM


def build_hf_fst_bytes(hf_tokenizer: str | AutoTokenizer, llm=None, llm_name: str | None = None) -> TransducedLM:
    """Build a token-id -> byte transducer for an LLM's tokenizer and wrap it in a TransducedLM.

    The FST has a single hub state that is both start and final. Each
    non-special token contributes a chain of arcs that leaves the hub reading
    the token id (as a string symbol) on its first arc. The chain emits the
    token's bytes one per arc and returns to the hub. A self-loop on the hub
    maps the tokenizer's EOS token id to a dedicated EOS output symbol.

    Args:
        hf_tokenizer: Currently unused; kept for backward compatibility. The
            tokenizer is always taken from ``llm.tokenizer``.
        llm: Model object exposing ``.tokenizer``. Loaded via
            ``load_model_by_name(llm_name)`` when omitted.
        llm_name: Model name used only when ``llm`` is None.

    Returns:
        A ``TransducedLM`` wrapping the optimized FST, with ``eos_out``,
        ``universal_states`` and ``name`` attributes set.

    Raises:
        ValueError: If neither ``llm`` nor ``llm_name`` is given, or if the
            tokenizer has no EOS token id.
    """
    if llm is None:
        if llm_name is None:
            raise ValueError("Either 'llm' or 'llm_name' must be provided.")
        llm = load_model_by_name(llm_name)

    tokenizer = llm.tokenizer
    if tokenizer.eos_token_id is None:
        raise ValueError("The tokenizer has no EOS token id; cannot build the EOS arc.")

    input_symtab = pynini.SymbolTable(name="input_symtab")
    output_symtab = pynini.SymbolTable(name="output_symtab")

    byte_vocab = get_byte_vocab(tokenizer)
    tokens = sorted(tokenizer.vocab.items(), key=lambda x: x[1])

    # Use the largest id rather than len-1. If ids are not contiguous, len-1
    # may equal a real token id, and that token would be treated as epsilon.
    max_token_id = max((token_id for _, token_id in tokens), default=-1)

    # Special symbols are named with ids just past the vocabulary, so they
    # cannot collide with any token-id string on the input side.
    # NOTE(review): for vocabularies of <= 255 tokens these names can coincide
    # with byte names "0".."255" on the output side; confirm such vocabularies
    # are never used here.
    EPS = str(max_token_id + 1)
    input_symtab.add_symbol(EPS, 0)
    output_symtab.add_symbol(EPS, 0)

    EOS_IN = str(tokenizer.eos_token_id)
    EOS_OUT = str(max_token_id + 2)

    # Reserve all 256 byte values on the output side at keys 1..256
    # (key 0 is epsilon).
    for b in range(256):
        output_symtab.add_symbol(str(b), b + 1)

    # Register EOS_OUT only after the byte keys are taken. Adding it first
    # (with no explicit key) gave it key 1. The byte loop then also put
    # byte 0 on key 1, because OpenFst does not check for key collisions, so
    # EOS and 0x00 shared an output label.
    output_symtab.add_symbol(EOS_OUT)

    T = pynini.Fst()
    start_state = T.add_state()
    T.set_start(start_state)
    T.set_final(start_state)

    one = pynini.Weight.one("tropical")

    def add_arc(src: int, ilabel: str, olabel: str, dst: int) -> None:
        """Add a unit-weight arc from `src` to `dst`, mapping EPS/"" to label 0.

        New input symbols are interned on first use. ``add_symbol`` returns the
        existing key when the symbol is already present. Output symbols are all
        pre-registered above, so they are only looked up.
        """
        i_id = 0 if ilabel in (EPS, "") else input_symtab.add_symbol(ilabel)
        o_id = 0 if olabel in (EPS, "") else output_symtab.find(olabel)
        T.add_arc(src, pynini.Arc(i_id, o_id, one, dst))

    add_arc(start_state, EOS_IN, EOS_OUT, start_state)

    # Hoisted: `all_special_tokens` is recomputed on each access, and a set
    # makes the per-token membership test O(1).
    special_tokens = set(tokenizer.all_special_tokens)

    # NOTE(review): zip pairs byte_vocab[i] with the i-th token in id order.
    # This assumes get_byte_vocab is ordered by token id and that ids are
    # contiguous; confirm against genlm.backend.
    for bytes_vals, (token_str, token_id) in zip(byte_vocab, tokens):
        if token_str in special_tokens:
            continue

        # A token with an empty byte sequence produces no arcs.
        last_idx = len(bytes_vals) - 1
        current_state = start_state
        for idx, byte_val in enumerate(bytes_vals):
            # The final byte arc loops back to the hub; intermediate arcs use fresh states.
            next_state = start_state if idx == last_idx else T.add_state()
            # Only the first arc consumes the token id (as a string symbol);
            # the rest read epsilon.
            add_arc(current_state,
                    str(token_id) if idx == 0 else EPS,
                    str(byte_val),
                    next_state)
            current_state = next_state

    T.set_input_symbols(input_symtab)
    T.set_output_symbols(output_symtab)
    T.optimize()

    fst = TransducedLM(T, llm=llm)
    fst.eos_out = int(EOS_OUT)
    fst.universal_states = {s: True for s in T.states()}
    fst.name = "hf_realpha"
    return fst