import pynini
from genlm.backend import load_model_by_name
from tokenizer_conversion.machines.transducedLM import TransducedLM

# Codon -> one-letter amino-acid table ('*' marks the stop codons TAA, TAG
# and TGA).  The table is generated instead of spelled out: codons are
# enumerated with the SECOND base varying slowest, then the first base, then
# the third, each running over T, C, A, G.  The amino-acid string below lists
# the matching residues in that same order, four per codon prefix.
genetic_code = {
    first + second + third: amino_acid
    for (second, first, third), amino_acid in zip(
        ((s, f, t) for s in "TCAG" for f in "TCAG" for t in "TCAG"),
        (
            "FFLL" "LLLL" "IIIM" "VVVV"  # second base T: TT., CT., AT., GT.
            "SSSS" "PPPP" "TTTT" "AAAA"  # second base C: TC., CC., AC., GC.
            "YY**" "HHQQ" "NNKK" "DDEE"  # second base A: TA., CA., AA., GA.
            "CC*W" "RRRR" "SSRR" "GGGG"  # second base G: TG., CG., AG., GG.
        ),
    )
}

# Reverse lookup: amino acid -> one representative codon.  Several codons map
# to the same amino acid, so the codon that appears last in ``genetic_code``
# is the one that is kept.
aa2dna = dict(zip(genetic_code.values(), genetic_code.keys()))

def build_dna_to_aa_fst_bytes(llm=None, llm_name=None):
    """Build a DNA-codon to amino-acid transducer wrapped in a ``TransducedLM``.

    The transducer reads the bases A/C/G/T (input labels 1-4).  The first two
    bases of a codon are consumed with epsilon output; the third base emits
    the amino-acid label looked up through ``genetic_code`` and returns to the
    start state.  Every state is final, so a path may end in the middle of a
    codon.

    Args:
        llm: Model handed to ``TransducedLM``.  When ``None`` it is obtained
            from ``load_model_by_name(llm_name)``.
        llm_name: Name passed to ``load_model_by_name``; only consulted when
            ``llm`` is ``None``.

    Returns:
        The ``TransducedLM`` built around the FST, with ``eos_out``, ``name``,
        ``in_symtab``, ``out_symtab``, ``aa_map`` and ``codon_map`` attached
        and ``universal_states`` marking every state as ``True``.
    """
    # NOTE(review): if both arguments are None, load_model_by_name receives
    # None -- confirm that this is a supported call.
    if llm is None:
        llm = load_model_by_name(llm_name)

    epsilon_label = 0
    eos_symbol = "0"
    bases = "ACGT"
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY") + ["*"]

    # Input symbol table.  Symbols are decimal strings: "1".."4" stand for
    # A, C, G, T.  Note the crossed naming of the two special entries: label 0
    # (epsilon) carries the text "5", while the text "0" (EOS) has label 5.
    in_sym = pynini.SymbolTable("bytes_in")
    eos_in_label = len(bases) + 1
    in_sym.add_symbol(str(eos_in_label), epsilon_label)
    codon_map = {}
    for label, base in enumerate(bases, start=1):
        in_sym.add_symbol(str(label), label)
        codon_map[base] = str(label)
    in_sym.add_symbol(eos_symbol, eos_in_label)

    # Output symbol table, laid out the same way: "1".."21" are the amino
    # acids (with '*' last), label 0 is named "22" and EOS "0" has label 22.
    out_sym = pynini.SymbolTable("aa_out")
    eos_out_label = len(amino_acids) + 1
    out_sym.add_symbol(str(eos_out_label), epsilon_label)
    aa_map = {}
    for label, amino_acid in enumerate(amino_acids, start=1):
        out_sym.add_symbol(str(label), label)
        aa_map[amino_acid] = str(label)
    out_sym.add_symbol(eos_symbol, eos_out_label)  # EOS as a single byte

    # Output label emitted when each codon is completed.
    codon_out_label = {
        codon: int(aa_map[amino_acid])
        for codon, amino_acid in genetic_code.items()
    }

    fst = pynini.Fst()
    fst.set_input_symbols(in_sym)
    fst.set_output_symbols(out_sym)
    one = pynini.Weight.one(fst.weight_type())

    start = fst.add_state()
    fst.set_start(start)
    fst.set_final(start, one)

    # (base, input label) pairs shared by all three levels of the codon trie.
    base_labels = [(base, in_sym.find(codon_map[base])) for base in bases]

    # Level 1: one state per first base, reached without producing output.
    after_first = {}
    for base, label in base_labels:
        state = fst.add_state()
        fst.set_final(state, one)
        fst.add_arc(start, pynini.Arc(label, epsilon_label, one, state))
        after_first[base] = state

    # Level 2: one state per two-base prefix, still without output.
    after_second = {}
    for prefix, source in after_first.items():
        for base, label in base_labels:
            state = fst.add_state()
            fst.set_final(state, one)
            fst.add_arc(source, pynini.Arc(label, epsilon_label, one, state))
            after_second[prefix + base] = state

    # Level 3: the third base emits the amino-acid label and loops back to
    # the start state, ready for the next codon.
    for prefix, source in after_second.items():
        for base, label in base_labels:
            emitted = codon_out_label[prefix + base]
            fst.add_arc(source, pynini.Arc(label, emitted, one, start))

    # No arcs are added for the EOS symbol: only A/C/G/T advance the machine.

    fst.rmepsilon()
    fst.optimize()

    T = TransducedLM(fst, llm=llm)
    T.compute_universal_states()
    # NOTE(review): whatever compute_universal_states() stored in
    # ``universal_states`` is replaced here by "every state is universal" --
    # confirm the preceding call is still needed.
    T.universal_states = {s: True for s in T.fst.states()}
    # NOTE(review): this evaluates to 0, which is also the epsilon label; the
    # EOS symbol's own key in ``out_sym`` is 22 -- confirm which one
    # TransducedLM expects.
    T.eos_out = int(eos_symbol)
    T.name = "hf_dna2aa_bytes"
    T.in_symtab = in_sym
    T.out_symtab = out_sym
    T.aa_map = aa_map
    T.codon_map = codon_map
    return T