import collections
import re
import logging

from nltk import tree
from lark.load_grammar import  _TERMINAL_NAMES

from prompt_compiler.data_structs.simple_rule import SimpleRule
from prompt_compiler.earley_parser.tree import Tree

# Module-wide logger, looked up under the fixed name "global_logger".
logger = logging.getLogger("global_logger")

# Terminal types whose matched text is written into extracted rules as a
# quoted literal (see treenode2rule) instead of being referenced by the
# terminal's name.  Entries are grouped by the dataset they were added for.
inline_terminal_names = {
        # for SMC dataset
        "WORD", "NUMBER", "ESCAPED_STRING", "L",
        # for regex dataset
        "STRING", "INT", "CHARACTER_CLASS", "CONST",
        # for overnight
        # "PROPERTY", "SINGLETON_VALUE", "ENTITY_VALUE", "NUMBER_VALUE",
        # for molecule
        "N", "C", "O", "F", "c",
        # for fol
        "PNAME", "CNAME", "LCASE_LETTER"
}
# Also inline every terminal name listed as a value of lark's _TERMINAL_NAMES
# table (presumably the names lark assigns to punctuation literals -- confirm
# against lark.load_grammar).  Only the values are used; the keys are ignored.
for k, v in _TERMINAL_NAMES.items():
    inline_terminal_names.add(v)

# Rule origins that check_grammar_validity never reports as invalid.
# One entry per line with a trailing comma, so that re-enabling a
# commented-out name cannot silently concatenate two adjacent strings.
skipped_nonterminal_names = (
    # for smc and regex
    "string",
    "number",
    "literal",
    "delimiter",
    # for DSL
    "object",
    # "VALUE",  # for mtop
    # "property", "value",  # for overnight
)

def split_rule(rule):
    """
    Split a textual rule "lhs : rhs" at its first colon.

    Returns the (lhs, rhs) pair with surrounding whitespace removed from
    both halves, or (None, None) when the input contains no colon or is
    not something that can be split this way.
    """
    try:
        head, colon, tail = rule.partition(":")
    except Exception:
        # best effort: anything that is not string-like is "not a rule"
        return None, None

    if not colon:
        # no separator present
        return None, None
    return head.strip(), tail.strip()

def collect_rules_from_larkfile(lark_file):
    """
    Parse bnf file (.lark) to extract rules

    The file is read as text and its content is handed to
    collect_rules_from_larkstr, so grammar files and in-memory grammar
    strings go through one and the same line-based parser instead of two
    copies of it.

    Args:
        lark_file: path of the grammar file to read.

    Returns:
        A pair (rule_set, aux_rules): rule_set is the list of distinct
        SimpleRule objects in first-seen order, aux_rules is the list of
        lines starting with "%" in file order.

    Raises:
        ValueError: if a line is not a directive, a comment, a blank line,
            a "|" continuation or a "lhs : rhs" rule.
    """
    with open(lark_file, "r") as f:
        lark_str = f.read()
    # Splitting the text on "\n" visits the same lines as iterating the file
    # object; the extra empty chunk after a final newline is skipped as blank.
    return collect_rules_from_larkstr(lark_str)

def collect_rules_from_larkstr(lark_str):
    """
    Extract rules from the text of a lark grammar, one line at a time.

    Each stripped line is handled as one of:
      * "%..."       -> kept verbatim as an auxiliary rule
      * "" or "//.." -> skipped
      * "| a | b"    -> more alternatives for the most recent left-hand side
      * "lhs : a | b" (and no '":' anywhere in the line) -> starts a new
        left-hand side and adds its alternatives
    Any other line raises ValueError.

    Returns (rule_set, aux_rules): the distinct SimpleRule objects in
    first-seen order, and the list of "%" lines.
    """
    seen_rules = collections.OrderedDict()  # used as ordered set
    aux_rules = []
    cur_nonterminal = None

    def add_alternatives(lhs, alternatives):
        # one SimpleRule per non-empty "|"-separated alternative,
        # with the alternative split into symbols on whitespace
        for alternative in alternatives.split("|"):
            alternative = alternative.strip()
            if alternative == "":
                continue
            assert lhs is not None
            seen_rules[SimpleRule(lhs, tuple(alternative.split()))] = 1

    for raw_line in lark_str.split("\n"):
        line = raw_line.strip()

        if line.startswith("%"):
            aux_rules.append(line)
            continue
        if line == "" or line.startswith("//"):
            continue
        if line.startswith("|"):
            add_alternatives(cur_nonterminal, line[1:])
            continue
        # lines containing '":' are rejected here (for rules like :duration)
        if ":" not in line or "\":" in line:
            raise ValueError(f"Unknown line: {line}")

        cur_nonterminal, rhs = split_rule(line)
        add_alternatives(cur_nonterminal, rhs)

    return list(seen_rules), aux_rules

def rulelist2larkstr(rule_stat):
    """
    Render rules as lark grammar text, one line per left-hand side.

    Iterating rule_stat must yield objects with an `origin` and an
    `expansion` (a sequence of strings).  Left-hand sides are emitted in
    sorted order; the alternatives of each are sorted too and joined as
    "lhs : a b | c".  The final text has outer whitespace stripped.
    """
    alternatives_by_lhs = {}
    for rule in rule_stat:
        alternatives_by_lhs.setdefault(rule.origin, []).append(rule.expansion)

    rendered_lines = []
    for lhs in sorted(alternatives_by_lhs):
        alternatives = sorted(alternatives_by_lhs[lhs])
        joined = " | ".join(" ".join(symbols) for symbols in alternatives)
        rendered_lines.append(f"{lhs} : {joined}")

    return "\n".join(rendered_lines).strip()

def lark2bnf(grammar):
    """
    Turn lark grammar text into BNF-style text, which is easier for GPT to
    generate.

    This is a plain textual substitution: every ":" in the input, wherever
    it occurs, becomes "::=".
    """
    bnf_separator = "::="
    return grammar.replace(":", bnf_separator)

def linearize_tree(tree):
    """
    Serialize a parse tree into a single bracketed string.

    A node whose `children` attribute is missing or None is a leaf and is
    rendered as "{value}".  Any other node is rendered as
    "[name child child ]", where name is `node.data.value` and every child
    is followed by one space.
    """
    def render(node):
        kids = getattr(node, "children", None)
        if kids is None:
            return "{" + f"{node.value}" + "}"

        opening = f"[{node.data.value} "
        rendered_kids = [render(kid) + " " for kid in kids]
        return opening + "".join(rendered_kids) + "]"

    return render(tree)

def linearized_tree_to_program(linearized_tree, delimiter=""):
    """
    Recover the program text from a linearized tree.

    Collects, left to right, the text of every "{...}" group (matched
    non-greedily) and joins the pieces with `delimiter`.
    """
    leaf_matches = re.finditer(r'{(.*?)}', linearized_tree)
    return delimiter.join(match.group(1) for match in leaf_matches)

def to_lisp_like_string(node):
    """
    Render an nltk tree as "(label child child ...)" with single spaces.

    Anything that is not an nltk Tree is returned unchanged; sub-trees are
    rendered recursively.
    """
    if not isinstance(node, tree.Tree):
        return node

    label = node.label()
    rendered_children = " ".join([to_lisp_like_string(child) for child in node])
    return f"({label} {rendered_children})"

def remove_lf_space(raw_lf: str):
    """
    Normalize the spacing of a bracketed logical form.

    The string is parsed with nltk's Tree.fromstring and re-rendered with
    to_lisp_like_string.  If parsing or rendering fails, the input is
    returned unchanged (best effort).

    See run_sempar_icl.py for usage
    """
    try:
        lf_tree = tree.Tree.fromstring(raw_lf)
        return to_lisp_like_string(lf_tree)
    except Exception:
        # Deliberate fallback for inputs that are not well-formed trees.
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt / SystemExit still propagate.
        return raw_lf

def counter2pred(counter):
    """
    Return the most common key of a Counter, or None if the Counter is empty.
    """
    if len(counter) > 0:
        top_item, _count = counter.most_common(1)[0]
        return top_item
    return None


def _wrap_string(s):
    if s.startswith("\"") and s.endswith("\""):
        # a bit complex to preserve the quotation marks
        s = f"\"\\{s[:-1]}\\\"{s[-1]}"
    else:
        s = f"\"{s}\""

        # escape unicode characters
    if "\\u" in s:
        s = s.replace("\\u", "\\\\u")

    return s

def treenode2rule(treenode):
    """
    Build the SimpleRule that a single parse-tree node stands for.

    * None -> None.
    * A Tree -> rule from the node's name to one symbol per non-None child:
      a sub-tree contributes its name, an inlined token its quoted text and
      any other token its terminal type.
    * A token -> None if it is inlined, otherwise a rule from the token's
      type to its quoted text.

    A token is inlined when its type starts with "__" or is listed in
    inline_terminal_names.
    """
    if treenode is None:
        return None

    def is_inlined(token):
        return token.type.startswith("__") or token.type in inline_terminal_names

    if not isinstance(treenode, Tree):  # terminal
        if is_inlined(treenode):
            return None
        return SimpleRule(treenode.type, (_wrap_string(treenode.value),))

    origin = f"{treenode.data.value}"
    symbols = []
    for child in treenode.children:
        if child is None:
            continue
        if isinstance(child, Tree):
            symbols.append(child.data.value)
        elif is_inlined(child):
            symbols.append(_wrap_string(child.value))
        else:
            symbols.append(child.type)
    return SimpleRule(origin, tuple(symbols))

def extract_rule_stat(tree, rule_stat):
    """
    Count the occurrence of each rule in a parse tree.

    rule_stat is updated in place.  A node for which treenode2rule yields
    no rule is not counted, and nothing below it is visited.
    """
    rule = treenode2rule(tree)
    if rule is None:
        return

    rule_stat[rule] = rule_stat.get(rule, 0) + 1

    # nodes without children (missing or empty) end the recursion
    for child in getattr(tree, "children", None) or ():
        extract_rule_stat(child, rule_stat)

def extract_min_grammar_from_trees(trees, return_rules=False):
    """
    Extract the minimal grammar needed to reconstruct the given trees.

    Returns the grammar as lark text; with return_rules=True returns
    (grammar, rules), where rules is the list of distinct rules in the
    order they were first met.
    """
    rule_counts = collections.OrderedDict()
    for parse_tree in trees:
        extract_rule_stat(parse_tree, rule_counts)

    grammar = rulelist2larkstr(rule_counts)
    if not return_rules:
        return grammar
    return grammar, list(rule_counts)

def gen_min_lark(program, parser):
    """
    Obtain the minimal grammar from a program.

    The program is parsed one line at a time with parser.parse, and the
    grammar is extracted from the resulting parse trees.
    """
    # a program without "\n" splits into a single line, i.e. itself
    parse_trees = [parser.parse(line) for line in program.split("\n")]
    return extract_min_grammar_from_trees(parse_trees)

def bnf2lark(grammar):
    """
    Opposite of lark2bnf: every "::=" in the text becomes ":".
    """
    lark_separator = ":"
    return grammar.replace("::=", lark_separator)

def larkstr2rulelist(lark_str, rhs_sep=None):
    """
    Convert lark grammar string to a stream of rules.
    TODO: use load_grammar function from lark

    Each line is split at its first ":"; lines without one are skipped.
    Every "|"-separated alternative on the right-hand side yields one
    SimpleRule.

    Args:
        lark_str: grammar text, one "lhs : alt | alt" rule per line.
        rhs_sep: when given, each alternative is split on this separator
            into its symbols; when None the whole alternative is kept as a
            single symbol.

    Yields:
        SimpleRule(lhs, expansion), with expansion always a tuple, as it
        is everywhere else in this module.
    """
    for raw_rule in lark_str.split("\n"):
        lhs, rhs_text = split_rule(raw_rule)
        if lhs is None and rhs_text is None:
            # not a "lhs : rhs" line
            continue
        for alternative in rhs_text.split("|"):
            alternative = alternative.strip()
            if rhs_sep is not None:
                # tuple, not list: keeps the expansion hashable and of the
                # same type as rules built by the other helpers
                expansion = tuple(alternative.split(rhs_sep))
            else:
                # treat rhs as a single token, which is enough for checking
                # grammar validity because the resulting string is the same
                expansion = (alternative,)
            yield SimpleRule(lhs, expansion)

def check_grammar_validity(valid_rules, pred_lark_str):
    """
    Check if the grammar (i.e., bnf_str produced by model) is valid.

    Every rule parsed from pred_lark_str must be contained in valid_rules,
    unless its origin is one of skipped_nonterminal_names.  The first
    offending rule is logged and the check stops with False.
    """
    for rule in larkstr2rulelist(pred_lark_str):
        if rule.origin in skipped_nonterminal_names:
            continue
        if rule in valid_rules:
            continue
        logger.info(f"Found invalid rule {rule}")
        return False
    return True

def rulelist2bnfstr(rule_list):
    """
    Convert list of rules to a BNF grammar string: the rules are rendered
    as lark text first, which is then passed through lark2bnf.
    """
    return lark2bnf(rulelist2larkstr(rule_list))

def decorate_grammar(grammar):
    """
    Add auxiliary rules to the grammar.

    Appends, each on its own new line, the imports of common.DIGIT,
    common.LCASE_LETTER, common.UCASE_LETTER and common.WS, followed by
    the directive that ignores WS.
    """
    aux_suffixes = (
        "\n%import common.DIGIT",
        "\n%import common.LCASE_LETTER",
        "\n%import common.UCASE_LETTER",
        "\n%import common.WS",
        "\n%ignore WS",
    )
    return grammar + "".join(aux_suffixes)