import operator
import os
from typing import List, Dict, Any
from functools import reduce

import pynini
from pynini import cross, cdrewrite, union

from tokenizer_conversion.machines.transducedLM import TransducedLM
from tokenizer_conversion.machines.utils.utils import _log

def aggregate_words(
    chars: List[str],
    log_probs: List[float],
    bow_token: str = "258",
    eow_token: str = "259",
) -> List[Dict[str, Any]]:
    """Group a symbol stream into words and total each word's log-probability.

    A word is closed either by an explicit ``eow_token`` or by a ``bow_token``
    that arrives after the current word already holds a non-space symbol.
    Whitespace and repeated ``bow_token`` symbols seen before any content are
    kept as a prefix of the word that follows them.

    Args:
        chars: Symbols in stream order (boundary tokens included).
        log_probs: One log-probability per entry of ``chars``.
        bow_token: Symbol marking the beginning of a word.
        eow_token: Symbol marking the end of a word.

    Returns:
        One dict per word, in order, with keys ``word`` (symbols joined with
        boundary tokens and whitespace removed), ``logprob`` (sum over the
        word's symbols), ``chars`` and ``char_logprobs`` (the raw slices).

    Raises:
        ValueError: If the inputs differ in length, if ``eow_token`` appears
            while no word is open, or if the per-word totals differ from the
            stream total by more than 1e-10.
    """
    if len(chars) != len(log_probs):
        raise ValueError("chars and log_probs must be the same length")

    markers = (bow_token, eow_token)
    finished: List[Dict[str, Any]] = []
    pending_syms: List[str] = []
    pending_lps: List[float] = []
    word_open = False
    seen_content = False

    def close_word() -> None:
        # Emit the buffered symbols as one word and reset the per-word state.
        nonlocal word_open, seen_content
        if pending_syms:
            surface = "".join(
                sym for sym in pending_syms
                if not (sym in markers or sym.isspace())
            )
            finished.append(
                {
                    "word": surface,
                    "logprob": sum(pending_lps),
                    "chars": list(pending_syms),
                    "char_logprobs": list(pending_lps),
                }
            )
            del pending_syms[:]
            del pending_lps[:]
            word_open = False
            seen_content = False

    for sym, lp in zip(chars, log_probs):
        if sym == bow_token:
            # A BOW after real content starts the next word; otherwise it is
            # just (more of) the prefix of the word being assembled.
            if word_open and seen_content:
                close_word()
            word_open = True
            pending_syms.append(sym)
            pending_lps.append(lp)
            continue

        if sym == eow_token:
            if not word_open:
                raise ValueError("EOW encountered without an open word")
            pending_syms.append(sym)
            pending_lps.append(lp)
            close_word()
            continue

        # Ordinary symbol or whitespace.
        word_open = True
        pending_syms.append(sym)
        pending_lps.append(lp)
        if not sym.isspace():
            seen_content = True

    # Whatever is still buffered forms the final word.
    close_word()

    # Sanity check: grouping must not lose or invent probability mass.
    char_total = sum(log_probs)
    word_total = sum(entry["logprob"] for entry in finished)
    if abs(char_total - word_total) > 1e-10:
        raise ValueError(
            f"Probability leakage: {char_total} vs. {word_total}"
        )
    return finished

def detokenise_ptb(
    stream: list[str],
    sep_begin: str = "258",
    sep_end: str   = "259",
    eos: str       = "256",
    keep_space: bool = True,
) -> str:
    """Turn a stream of byte-id tokens back into text, keeping separator tokens.

    Every token other than ``sep_begin``, ``sep_end`` and ``eos`` is parsed
    with ``int`` and treated as one byte; runs of bytes between separators are
    decoded as UTF-8. Separator tokens are copied into the output verbatim
    (as their string form), and processing stops at the first ``eos``.

    Args:
        stream: Token strings in order.
        sep_begin: Token marking the start of a PTB token.
        sep_end: Token marking the end of a PTB token.
        eos: End-of-sequence token; it and everything after it is dropped.
        keep_space: When False, spaces that follow ``sep_begin + " "`` (and
            spaces at the very start of the text) are stripped.

    Returns:
        The reconstructed string.
    """
    separators = (sep_begin, sep_end)
    pieces = []
    pending = []

    def emit_pending():
        # Decode the buffered byte values as one UTF-8 run.
        if not pending:
            return
        pieces.append(bytes(pending).decode("utf-8"))
        del pending[:]

    for token in stream:
        if token == eos:
            break
        if token in separators:
            emit_pending()
            pieces.append(token)
        else:
            pending.append(int(token))
    emit_pending()

    text = "".join(pieces)
    if keep_space:
        return text

    marker = sep_begin + " "
    return marker.join(chunk.lstrip(" ") for chunk in text.split(marker))


def load_ptb(
    directory, 
    llm=None,
    model_name=None,
    verbose=True,
    EOS_SYM="256", 
    EPS="257", 
    BEGIN_SEP_CHAR="258",
    END_SEP_CHAR="259"
) -> TransducedLM:
    """Return the PTB ``TransducedLM``, building and saving it first if needed.

    If ``directory`` does not exist, the FST is compiled with
    ``build_ptb_fst_bytes``, wrapped, and saved to ``directory``. The machine
    is then always loaded from ``directory`` and re-wrapped around ``llm``,
    so a freshly built machine and a cached one go through the same
    configuration steps with the same arguments.

    Args:
        directory: Path the saved machine is loaded from / saved to.
        llm: Language model handed to the returned ``TransducedLM``.
        model_name: Passed through as ``llm_name``.
        verbose: Forwarded to ``_log`` for the loading message.
        EOS_SYM: End-of-sequence symbol; also stored as ``eos_out``.
        EPS: Epsilon symbol used when building the FST.
        BEGIN_SEP_CHAR: Symbol marking the start of a PTB token.
        END_SEP_CHAR: Symbol marking the end of a PTB token.

    Returns:
        The configured ``TransducedLM`` with ``name`` set to ``"ptb"``.
    """
    if not os.path.isdir(directory):
        # Nothing saved yet: compile the FST once and persist it.
        ptb_fst = build_ptb_fst_bytes(
            EOS_SYM=EOS_SYM,
            EPS=EPS,
            BEGIN_SEP_CHAR=BEGIN_SEP_CHAR,
            END_SEP_CHAR=END_SEP_CHAR,
        )
        built = TransducedLM(ptb_fst, llm_name=model_name, use_genlm=True)
        built.compute_universal_states()
        built.eos_out = EOS_SYM
        built.save(directory)
        # NOTE(review): assumes save() creates `directory`; if it does not,
        # the load below fails instead of this function recursing forever.

    # Single load path shared by the cached and the freshly built case, so the
    # caller's symbols and `verbose` flag always apply and the next-symbol
    # precomputation runs exactly once.
    _log(verbose, "Loading PTB")
    ptb_wrapped = TransducedLM.load(directory, use_genlm=True, llm_name=model_name)
    ptb_wrapped.compute_universal_states()

    # Re-wrap the loaded FST around the caller's llm, reusing the universal-set
    # cache of the loaded machine.
    ptb = TransducedLM(ptb_wrapped.fst, llm=llm, use_genlm=True, llm_name=model_name)
    ptb._universal_set_cache = ptb_wrapped._universal_set_cache
    ptb.compute_universal_states()
    ptb.eos_out = EOS_SYM

    input_symbols = ptb.fst.input_symbols()
    ptb.special_tokens = [
        input_symbols.find(EOS_SYM),
        input_symbols.find(BEGIN_SEP_CHAR),
        input_symbols.find(END_SEP_CHAR),
    ]

    ptb.precompute_next_pset()
    ptb.name = "ptb"
    return ptb


def build_ptb_fst_bytes(
        EOS_SYM: str = "256",
        EPS: str = "257",
        BEGIN_SEP_CHAR: str = "258",
        END_SEP_CHAR: str = "259",
        convert_parentheses: bool = False,
        trailing_space: bool = False,
    ) -> pynini.Fst:
    """
    Builds an FST that mimics the Penn Treebank tokenizer: 
    https://www.nltk.org/_modules/nltk/tokenize/treebank.html#TreebankWordTokenizer

    The alphabet is a symbol table in which ``EPS`` is registered first and
    every byte value ``v`` is registered as the decimal string ``str(v)``
    under key ``v + 1``; ``EOS_SYM`` and the two separator symbols are added
    afterwards if missing. Where the NLTK rules insert spaces around a token,
    the rules here insert ``BEGIN_SEP_CHAR`` / ``END_SEP_CHAR`` instead.

    The rules are composed in this order: space boundary, starting quotes,
    punctuation, brackets, (optional) bracket conversion, double dashes,
    ending quotes / clitics, contractions.

    Args:
        EOS_SYM: End-of-sequence symbol; several rules use it as the
            "end of string" right context.
        EPS: Symbol registered first in the table.
        BEGIN_SEP_CHAR: Symbol inserted before a split-off token.
        END_SEP_CHAR: Symbol inserted after a split-off token.
        convert_parentheses: If True, brackets are replaced by the labels
            -LRB-/-RRB-/-LSB-/-RSB-/-LCB-/-RCB- (added as single symbols).
            NOTE(review): off by default and not exercised by the default
            path; verify the output before relying on it.
        trailing_space: If True a space is rewritten to
            ``space + END_SEP_CHAR``; otherwise to ``BEGIN_SEP_CHAR + space``.

    Returns:
        The optimized transducer, with the symbol table attached as both the
        input and the output symbols.
    """

    symbols = pynini.SymbolTable()
    symbols.add_symbol(EPS) 

    for byte_value in range(256):
        symbols.add_symbol(str(byte_value), byte_value + 1)

    def b(ch: str) -> pynini.Fst:
        # Acceptor for the UTF-8 byte sequence of `ch`, one symbol per byte.
        pieces = (pynini.accep(str(bt), token_type=symbols) for bt in ch.encode("utf-8"))
        return reduce(operator.add, pieces)

    def bs(text: str) -> pynini.Fst:
        # Acceptor for a whole string: concatenation of b(c) for each character.
        return reduce(operator.add, (b(c) for c in text), pynini.accep(""))

    def ensure_symbol(sym: str, kind: str) -> None:
        # Register `sym` unless the table already contains it.
        if not symbols.member(sym):
            print(f"Adding {kind} '{sym}' to symbol table")
            symbols.add_symbol(sym)

    ensure_symbol(EOS_SYM, "eos")
    # Add separators to symbol table
    ensure_symbol(BEGIN_SEP_CHAR, "separator")
    ensure_symbol(END_SEP_CHAR, "separator")

    # Add special bracket symbols to symbol table
    if convert_parentheses:
        symbols.add_symbol("-LRB-")
        symbols.add_symbol("-RRB-")
        symbols.add_symbol("-LSB-")
        symbols.add_symbol("-RSB-")
        symbols.add_symbol("-LCB-")
        symbols.add_symbol("-RCB-")

    SPACE = b(" ")

    # sigma: every symbol of the table except the one at key 0 (EPS).
    # Every entry is a symbol taken from the table itself, so compiling it is
    # not guarded: a failure here would mean the alphabet is malformed.
    sigma_chars = [
        pynini.accep(symbols.find(idx), token_type=symbols)
        for idx in range(1, symbols.num_symbols())
    ]

    sigma = pynini.union(*sigma_chars)
    sigma_star = pynini.closure(sigma)
    
    # DEFINE SPECIAL CHARACTERS
    EOS = pynini.union(pynini.accep(EOS_SYM, token_type=symbols))
    BOS = "[BOS]"
    SEP_BEGIN = pynini.accep(BEGIN_SEP_CHAR, token_type=symbols)
    SEP_END = pynini.accep(END_SEP_CHAR, token_type=symbols)
    
    ALL_SEP = pynini.union(SEP_BEGIN, SEP_END, SPACE)
    SEP_OR_BOS = pynini.union(ALL_SEP, BOS)
    SEP_OR_EOS = pynini.union(ALL_SEP, EOS)

    APOS = b("'")
    DIGIT = pynini.union(*[b(str(i)) for i in range(10)])
    DOT = b(".")

    NON_DOT = pynini.difference(sigma, DOT).optimize()
    NON_DIGIT = pynini.difference(sigma, DIGIT).optimize()
    NON_APOS = pynini.difference(sigma, APOS).optimize()
    NON_APOS_OR_SEP =  pynini.difference(sigma, pynini.union(APOS, SEP_BEGIN, SEP_END, SPACE)).optimize()

    # ^\" -> `` checked
    QUOTE = b('"')
    BACKTICK = b("`")
    DOUBLE_BACKTICK = BACKTICK + BACKTICK

    start_quotes_1 = cdrewrite(
        cross(QUOTE, DOUBLE_BACKTICK),
        BOS,
        "",
        sigma_star,
    )

    # space: attach a separator to every space (after it or before it)
    if trailing_space:
        boundary_space = pynini.cdrewrite(
            pynini.cross(SPACE, SPACE + SEP_END),
            "", "", sigma_star
        )
    else:
        boundary_space = pynini.cdrewrite(
            pynini.cross(SPACE, SEP_BEGIN + SPACE),
            "", "", sigma_star
        )

    # (``) -> " `` " -> Checked
    start_quotes_2 = cdrewrite(cross(DOUBLE_BACKTICK, SEP_BEGIN+DOUBLE_BACKTICK+SEP_END), "", "", sigma_star)

    # ([ \(\[{<])(\"|\'{2}) -> \1 `` -> checked
    R_BRACKET_L = b("(")
    BRACKET_L = b("[")
    BRACE_L = b("{")
    ANGLE_L = b("<")
    
    start_quotes_3 = cdrewrite(
        cross(QUOTE, SEP_BEGIN+DOUBLE_BACKTICK+SEP_END),
        union(R_BRACKET_L, BRACKET_L, BRACE_L, ANGLE_L, SEP_BEGIN, SEP_END, SPACE).plus, # TODO
        "",
        sigma_star,
    )

    starting_quotes_fst = start_quotes_1 @ start_quotes_2 @ start_quotes_3

    # ([:,])([^\d]) -> " \1 \2" -> checked
    COMMA = b(",")
    COLON = b(":")
    
    punct_1 = cdrewrite(cross(COMMA, SEP_BEGIN+COMMA+SEP_END) | cross(COLON, SEP_BEGIN+COLON+SEP_END), 
                        "", NON_DIGIT, sigma_star
                        )

    # ([:,])$ -> r" \1 " checked
    punct_2 = cdrewrite(cross(COMMA, SEP_BEGIN+COMMA+SEP_END) |
                        cross(COLON, SEP_BEGIN+COLON+SEP_END), "", EOS, sigma_star # TODO check
                        )

    # \.\.\. -> " ... " -> checked
    ellipsis_rule = cdrewrite(cross(DOT + DOT + DOT, SEP_BEGIN+DOT+DOT+DOT+SEP_END), "", "", sigma_star)

    # [;@#$%&] -> " \g<0> " -> checked
    SEMICOLON = b(";")
    AT = b("@")
    PERCENT = b("%")
    AMPERSAND = b("&")
    DOLLAR = b("$")
    special_punct = [
        SEMICOLON,
        AT,
        PERCENT,
        AMPERSAND,
        DOLLAR,
    ]
    spaced_punct = pynini.union(
        *(pynini.cross(sym, SEP_BEGIN + sym + SEP_END) for sym in special_punct)
    )
    punct_4 = cdrewrite(spaced_punct, "", "", sigma_star)

    # ([^\.])(\.)([\]\)}>"']*)\s*$ -> \1 \2\3 TODO check
    R_BRACKET_R = b(")")
    punct_5 = cdrewrite(
        cross(DOT, SEP_BEGIN+DOT),
        NON_DOT,
        union(EOS, APOS+APOS+EOS, QUOTE+EOS, APOS+EOS, R_BRACKET_R+EOS), #TODO
        sigma_star,
    )
    
    # [?!] -> " ? " or " ! "
    QUESTION = b("?")
    EXCLAMATION = b("!")
    punct_6 = cdrewrite(
        cross(QUESTION, SEP_BEGIN + QUESTION + SEP_END) | cross(EXCLAMATION, SEP_BEGIN + EXCLAMATION + SEP_END),
        "",
        "",
        sigma_star,
    )
    
    # ([^'])'  -> \1 '
    punct_7 = cdrewrite(
        cross(APOS, SEP_BEGIN + APOS),
        NON_APOS,
        pynini.union(SEP_BEGIN, EOS),
        sigma_star,
    )

    punct_fst = (
        punct_1
        @ punct_2
        @ ellipsis_rule
        @ punct_4
        @ punct_5
        @ punct_6
        @ punct_7
    )

    # r"[\]\[\(\)\{\}\<\>]"), r" \g<0> " -> checked
    BRACKET_R = b("]")
    BRACE_R = b("}")
    ANGLE_R = b(">")
    parens_chars = [
        BRACKET_R,
        BRACKET_L,
        R_BRACKET_R,
        R_BRACKET_L,
        BRACE_R,
        BRACE_L,
        ANGLE_R,
        ANGLE_L,
    ]   

    spaced_parens = pynini.union(
        *(pynini.cross(sym, SEP_BEGIN + sym + SEP_END) for sym in parens_chars)
    )
    parens_brackets_fst = cdrewrite(
        spaced_parens,
        "",
        "",
        sigma_star,
    )

    # ( -> -LRB-, ) -> -RRB-, [ -> -LSB-, ] -> -RSB-, { -> -LCB-, } -> -RCB-
    if convert_parentheses:
        LRB = pynini.accep("-LRB-", token_type=symbols)
        RRB = pynini.accep("-RRB-", token_type=symbols)
        LSB = pynini.accep("-LSB-", token_type=symbols)
        RSB = pynini.accep("-RSB-", token_type=symbols)
        LCB = pynini.accep("-LCB-", token_type=symbols)
        RCB = pynini.accep("-RCB-", token_type=symbols)

        # This stage is composed after parens_brackets_fst, which has already
        # wrapped every bracket in separators, so only the bracket symbol is
        # replaced here. It has to be a rewrite over sigma_star: composing the
        # pipeline with a bare union of crosses would only accept strings that
        # consist of exactly one bracket.
        convert_paren_fst = cdrewrite(
            pynini.union(
                pynini.cross(R_BRACKET_L, LRB),
                pynini.cross(R_BRACKET_R, RRB),
                pynini.cross(BRACKET_L, LSB),
                pynini.cross(BRACKET_R, RSB),
                pynini.cross(BRACE_L, LCB),
                pynini.cross(BRACE_R, RCB),
            ),
            "",
            "",
            sigma_star,
        ).optimize()

    # (r"--"), r" -- " -> checked
    DASH = b("-")
    DASH_DASH = DASH + DASH
    double_dashes_fst = cdrewrite(
        cross(DASH_DASH, SEP_BEGIN + DASH_DASH + SEP_END),
        "",
        "",
        sigma_star,
    )

    # (r"''"), " '' " -> checked
    endq_1 = cdrewrite(cross(APOS+APOS, SEP_BEGIN+APOS+APOS+SEP_END), "", "", sigma_star)

    # (r'"'), " '' " -> checked
    endq_2 = cdrewrite(cross(QUOTE, SEP_BEGIN+APOS+APOS+SEP_END), "", "", sigma_star)

    # accept clitics 
    
    SMALL_S = b("s")
    SMALL_M = b("m")
    SMALL_D = b("d")
    CAPS_S = b("S")
    CAPS_M = b("M")
    CAPS_D = b("D")
    
    clitics_1 = [
        APOS+SMALL_S,
        APOS+SMALL_M,
        APOS+SMALL_D,
        APOS+CAPS_S,
        APOS+CAPS_M,
        APOS+CAPS_D,
    ]

    # approx (r"([^' ])('[sS]|'[mM]|'[dD]|') "), r"\1 \2 " checked
    endq_3 = cdrewrite(
        pynini.union(*(pynini.cross(clit, SEP_BEGIN+clit) for clit in clitics_1)),
        NON_APOS_OR_SEP, # TODO check
        SEP_OR_EOS,
        sigma_star,
    )
    apos = cdrewrite(
        pynini.cross(APOS, SEP_BEGIN+APOS),
        NON_APOS_OR_SEP, # TODO check
        SEP_OR_EOS,
        sigma_star,
    )
    # Check capitalized
    # approx (r"([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T) "), r"\1 \2 " -> checked
    SMALL_L = b("l")
    CAPS_L = b("L")
    SMALL_R = b("r")
    CAPS_R = b("R")
    SMALL_E = b("e")
    CAPS_E = b("E")
    SMALL_V = b("v")
    CAPS_V = b("V")
    SMALL_N = b("n")
    CAPS_N = b("N")
    SMALL_T = b("t")
    CAPS_T = b("T")
    clitics_2 = [
        APOS+ SMALL_L+SMALL_L,
        APOS+CAPS_L+CAPS_L,
        APOS + SMALL_R+SMALL_E,
        APOS +CAPS_R+CAPS_E,
        APOS + SMALL_V + SMALL_E,
        APOS + CAPS_V + CAPS_E,
        SMALL_N + APOS + SMALL_T,
        CAPS_N + APOS + CAPS_T
    ]

    endq_4 = cdrewrite(
        pynini.union(*(pynini.cross(clit, SEP_BEGIN + clit) for clit in clitics_2)),
        NON_APOS_OR_SEP,
        SEP_OR_EOS,
        sigma_star,
    )

    ending_quotes_fst = endq_1 @ endq_2 @ endq_3 @ apos @ endq_4

    # Use MacIntyreContractions, CONTRACTIONS2 and CONTRACTIONS3
    # Each entry: (surface form, (first piece, second piece)).
    contractions_raw_patterns = [
        ("cannot", ("can", "not")),
        ("d'ye", ("d", "'ye")),
        ("more'n", ("more", "'n")),
        ("'tis", ("'t", "is")),
        ("'twas", ("'t", "was")),
        ("gonna", ("gon", "na")),
        ("gotta", ("got", "ta")),
        ("lemme", ("lem", "me")),
        ("wanna", ("wan", "na")),
        ("gimme", ("gim", "me")),
        # Upper case
        ("Cannot", ("Can", "not")),
        ("D'ye", ("D", "'ye")),
        ("More'n", ("More", "'n")),
        ("Gonna", ("Gon", "na")),
        ("Gotta", ("Got", "ta")),
        ("Lemme", ("Lem", "me")),
        ("Wanna", ("Wan", "na")),
        ("Gimme", ("Gim", "me")),
    ]
    contractions_patterns = [
        (bs(whole), (bs(first), bs(second)))
        for whole, (first, second) in contractions_raw_patterns
    ]

    # A contraction is only split when it stands between separators / spaces
    # (or the string start on the left, EOS on the right).
    contractions2_fsts = [
        cdrewrite(cross(orig, SEP_BEGIN+pieces[0]+SEP_END+SEP_BEGIN+pieces[1]+SEP_END), SEP_OR_BOS, SEP_OR_EOS, sigma_star)
        for orig, pieces in contractions_patterns
    ]
    # The pattern list above is a non-empty literal, so the rules can be
    # chained directly by composition.
    contractions_fst = reduce(operator.matmul, contractions2_fsts)
    
    fst = boundary_space
    fst @= starting_quotes_fst.optimize()
    fst @= punct_fst.optimize()
    fst @= parens_brackets_fst.optimize()
    if convert_parentheses:
        fst @= convert_paren_fst.optimize()
    fst @= double_dashes_fst.optimize()
    fst @= ending_quotes_fst.optimize()
    fst @= contractions_fst.optimize()

    fst.set_input_symbols(symbols)
    fst.set_output_symbols(symbols)

    print("INPUT")
    for sym in fst.input_symbols():
        print(sym)

    print("OUTPUT")
    for sym in fst.output_symbols():
        print(sym)
    # Debugging: count arcs whose input label is a separator or EOS
    sep_begin_id = symbols.find(BEGIN_SEP_CHAR)
    sep_end_id = symbols.find(END_SEP_CHAR)
    eos_id = symbols.find(EOS_SYM)

    num_sep_in_arcs = 0
    num_sep_out_arcs = 0
    num_eos_arcs = 0

    for state in fst.states():
        for arc in fst.arcs(state):
            if arc.ilabel == sep_begin_id:
                num_sep_in_arcs += 1
            elif arc.ilabel == sep_end_id:
                num_sep_out_arcs += 1
            elif arc.ilabel == eos_id:
                num_eos_arcs += 1
    print(f"PTB FST: {num_sep_in_arcs} sep-in arcs, "
      f"{num_sep_out_arcs} sep-out arcs, {num_eos_arcs} eos arcs")
    return fst.optimize()
