from infer import InferModule, ClauseInferModule
from tensor_encoder import TensorEncoder, TensorEncoder_mi
from fol.logic import *
from fol.data_utils import DataUtils
from fol.language import DataType
import numpy as np

# Unary predicate named '.' over the 'spec' datatype; it is only used to build
# the two special atoms below.
p_ = Predicate('.', 1, [DataType('spec')])
# Special atom over the '__F__' constant (presumably logical falsity - confirm).
# The atom generators in this module put it first in every atom list.
false = Atom(p_, [Const('__F__', dtype=DataType('spec'))])
# Special atom over the '__T__' constant (presumably logical truth - confirm).
# The atom generators in this module put it second in every atom list.
true = Atom(p_, [Const('__T__', dtype=DataType('spec'))])


def get_lang(lark_path, lang_base_path, dataset_type, dataset):
    """Load a first-order-logic language and its accompanying files.

    The language is loaded through ``DataUtils``; clauses, background
    clauses and the two background-atom files are then read from the
    loader's base path, and the ground atoms are enumerated from the
    language.

    Returns:
        Tuple ``(lang, clauses, bk_clauses, bk, bk2, atoms)``.
    """
    loader = DataUtils(
        lark_path=lark_path,
        lang_base_path=lang_base_path,
        dataset_type=dataset_type,
        dataset=dataset,
    )
    lang = loader.load_language()

    def read_clauses(filename):
        return loader.load_clauses(loader.base_path + filename, lang)

    def read_atoms(filename):
        return loader.load_atoms(loader.base_path + filename, lang)

    clauses = read_clauses('clauses.txt')
    bk_clauses = read_clauses('bk_clauses.txt')
    bk = read_atoms('bk.txt')
    bk2 = read_atoms('bk2.txt')
    atoms = generate_atoms(lang)
    return lang, clauses, bk_clauses, bk, bk2, atoms

def get_lang_mi(lark_path, lang_base_path, dataset_type, dataset, atoms, V_T, meta_arg, bk=None):
    """Load the meta-level language, its clauses and its ground atoms.

    Args:
        lark_path: Forwarded to ``DataUtils``.
        lang_base_path: Forwarded to ``DataUtils``.
        dataset_type: Forwarded to ``DataUtils``.
        dataset: Forwarded to ``DataUtils``.
        atoms: Object-level atoms handed to ``generate_const_for_mi``.
        V_T: Forwarded to ``generate_const_for_mi`` together with ``atoms``.
        meta_arg: Forwarded to ``load_language_mi``, ``get_clauses_mi`` and
            ``generate_atoms_mi``.
        bk: Optional background atoms for ``generate_atoms_mi``.  When
            omitted, an empty background is used.

    Returns:
        Tuple ``(lang_mi, clauses_mi, atoms_mi)``.
    """
    du = DataUtils(lark_path=lark_path, lang_base_path=lang_base_path,
                   dataset_type=dataset_type, dataset=dataset)
    du.generate_const_for_mi(atoms, V_T)
    lang_mi = du.load_language_mi(meta_arg)
    clauses_mi = du.get_clauses_mi(lang_mi, meta_arg)
    # generate_atoms_mi requires the background atoms and returns an
    # (atoms, terms) pair; the previous one-argument call raised TypeError.
    # NOTE(review): forwarding meta_arg mirrors get_lang_mi_train - confirm
    # this is the intended selection of hard-coded terms for this entry point.
    background = [] if bk is None else bk
    atoms_mi, _terms = generate_atoms_mi(lang_mi, background, meta_arg=meta_arg)

    return lang_mi, clauses_mi, atoms_mi

def get_lang_mi_train(lark_path, lang_base_path, dataset_type, dataset, atoms, V_T, bk, meta_arg = 0):
    """Load the meta-level language, clauses, atoms and terms for training.

    Constants for the meta language are generated from ``atoms`` and
    ``V_T`` first; the language and clauses are then loaded and the ground
    atoms are enumerated with a maximum term depth of 2.

    Returns:
        Tuple ``(lang_mi, clauses_mi, atoms_mi, terms)``.
    """
    loader = DataUtils(
        lark_path=lark_path,
        lang_base_path=lang_base_path,
        dataset_type=dataset_type,
        dataset=dataset,
    )
    loader.generate_const_for_mi_train(atoms, V_T, meta_arg)
    lang_mi = loader.load_language_mi_train(meta_arg)
    clauses_mi = loader.get_clauses_mi_train(lang_mi)
    generated = generate_atoms_mi(lang_mi, bk, max_term_depth=2, meta_arg=meta_arg)
    atoms_mi, terms = generated
    return lang_mi, clauses_mi, atoms_mi, terms


def get_args_list(terms_list):
    """Pair every first-slot term with every compatible second-slot term.

    A pair ``(x, y)`` is produced for each ``x`` in ``terms_list[0]`` and
    ``y`` in ``terms_list[1]`` unless ``str(x)`` occurs as a substring of
    ``str(y)``.  Pairs are ordered by ``x`` first, then ``y``.
    """
    return [
        (first, second)
        for first in terms_list[0]
        for second in terms_list[1]
        if str(first) not in str(second)
    ]


def get_args_list_faster(terms_list):
    """Return the filtered Cartesian product of the first two term pools.

    Every ``(x, y)`` with ``x`` from ``terms_list[0]`` and ``y`` from
    ``terms_list[1]`` is kept unless ``str(x)`` is a substring of ``str(y)``.
    """
    # Both pools are materialised up front, as a Cartesian product would do.
    left_pool = tuple(terms_list[0])
    right_pool = tuple(terms_list[1])

    pairs = []
    for left in left_pool:
        for right in right_pool:
            if str(left) in str(right):
                continue
            pairs.append((left, right))
    return pairs


def generate_terms(lang, max_term_depth):
    """Enumerate constants and function terms of the language.

    Starting from ``lang.consts``, each function symbol in ``lang.funcs`` is
    applied repeatedly to the terms collected so far.  Functions named
    'h', 'g' and 'r' use a fixed number of rounds (2, 4 and 4); every other
    function uses ``max_term_depth`` rounds.

    Returns:
        A new list holding all collected terms (duplicates are possible,
        since terms built in earlier rounds are appended again).
    """
    # NOTE(review): this is the very list object stored on the language, so
    # the extend() below also grows lang.consts if that attribute is a plain
    # list - confirm that callers rely on this.
    collected = lang.consts
    fixed_rounds = {'h': 2, 'g': 4, 'r': 4}

    for func in lang.funcs:
        rounds = fixed_rounds.get(func.name, max_term_depth)
        built = []
        for _ in range(rounds):
            # One candidate pool per input datatype; empty pools are skipped.
            pools = []
            for in_dtype in func.in_dtypes:
                pool = [term for term in collected if term.dtype == in_dtype]
                if pool:
                    pools.append(pool)

            if func.name == 'k':
                # 'k' never takes the first candidate of its first argument.
                pools[0].pop(0)

            if func.name in ('f', 'p', 'c'):
                if len(pools) == 1:
                    pools.append(pools[0])
                arg_tuples = list(itertools.product(*pools))
            elif func.name == 'g':
                arg_tuples = get_args_list_faster(pools)
            else:
                arg_tuples = list(set(itertools.product(*pools)))

            built.extend(FuncTerm(func, args) for args in arg_tuples)
            built = list(set(built))
            collected.extend(built)
    return list(collected)

def build_graph_from_pairs(pairs_list):
    """Build an adjacency-list graph from ``(start, end)`` pairs.

    Returns a dict mapping each start node to the list of its distinct end
    nodes, in first-seen order.  Nodes that only ever appear as an end get
    no key of their own.
    """
    adjacency = {}
    for source, target in pairs_list:
        # Create the neighbour list on first sight of the source node.
        neighbours = adjacency.setdefault(source, [])
        if target in neighbours:
            # Edge already recorded; keep the list free of duplicates.
            continue
        neighbours.append(target)
    return adjacency

def build_new_paths(graph, path):
    """Extend ``path`` at its front by every unvisited neighbour of its head.

    For each neighbour of ``path[0]`` that does not already occur in
    ``path``, the neighbour followed by the whole old path is appended to a
    single flat result list (the extended paths are concatenated, not
    nested).  If no extension exists, the original ``path`` object is
    returned unchanged.
    """
    head = path[0]
    neighbours = graph[head] if head in graph else []

    flattened = []
    for neighbour in neighbours:
        if neighbour in path:
            # Never revisit a node that is already on the path.
            continue
        flattened += [neighbour] + path

    if not flattened:
        return path
    return flattened


def is_valid_path(path, graph):
    """Return True when every consecutive pair of nodes in ``path`` is an edge.

    ``graph`` maps a node to the list of its successors; a node without an
    entry has no outgoing edges.  Paths with fewer than two nodes are valid.
    """
    return all(
        successor in graph.get(node, [])
        for node, successor in zip(path, path[1:])
    )

def _select_terms_by_consts(terms, dtype_name, wanted):
    """Restrict the terms of one datatype to those listed in ``wanted``.

    Each entry of ``wanted`` is a list of constant names.  For every entry
    the first term of datatype ``dtype_name`` whose ``all_consts()``
    stringify to exactly that list is kept.  Terms of all other datatypes
    follow the selected ones unchanged.  ``list.index`` raises ValueError
    when an entry has no matching term.
    """
    typed = [term for term in terms if term.dtype == dtype_name]
    others = [term for term in terms if term.dtype != dtype_name]
    spelled = [[str(const) for const in term.all_consts()] for term in typed]
    chosen = [typed[spelled.index(consts)] for consts in wanted]
    return chosen + others


def _filter_concat_args(args_list, joins):
    """Keep the triples ``(x, y, z)`` in which ``z`` is ``x`` joined with ``y``.

    '*' acts as the empty operand: with ``x == '*'`` the triple is kept when
    ``y`` and ``z`` print identically, and symmetrically for ``y == '*'``.
    Otherwise ``joins(arg)`` decides.  Input order is preserved.
    """
    kept = []
    for arg in args_list:
        if str(arg[0]) == '*':
            keep = str(arg[1]) == str(arg[2])
        elif str(arg[1]) == '*':
            keep = str(arg[0]) == str(arg[2])
        else:
            keep = joins(arg)
        if keep:
            kept.append(arg)
    return kept


def _findall_holds(arg, graph):
    """Check a ``findall`` argument tuple against the background graph.

    The constants of ``arg[0]`` are extended with ``build_new_paths``; the
    tuple holds when the result is a permutation of the constants of
    ``arg[1]`` without their last element.  Tuples whose first constant is
    named '*' never hold.
    """
    start_consts = arg[0].all_consts()
    if start_consts[0].name == '*':
        return False
    listed = arg[1].all_consts()
    # Drop the last constant (presumably the closing '*' marker - confirm).
    listed.pop()
    expected = build_new_paths(graph, start_consts)
    return sorted(expected) == sorted(listed)


def generate_atoms_mi(lang, bk, max_term_depth=2, meta_arg=0):
    """Enumerate the ground atoms of the meta-level language.

    Terms are generated from ``lang`` and then the 'path', 'paths' and
    'stack' terms are narrowed to hard-coded whitelists (the 'path' and
    'stack' whitelists depend on whether ``meta_arg`` is 0).  For each
    predicate the argument tuples are the Cartesian product of the terms of
    its datatypes, with extra filtering for 'proofs_head' arguments and for
    the predicates 'append', 'paths' and 'findall'.

    Args:
        lang: Language providing ``preds`` plus whatever ``generate_terms``
            reads.
        bk: Background atoms; each must yield exactly two constants from
            ``all_consts()``, which become one edge of the graph used by the
            'findall' filter.
        max_term_depth: Forwarded to ``generate_terms``.
        meta_arg: Selects the hard-coded 'path'/'stack' whitelists.

    Returns:
        Tuple ``(atoms, terms)`` where ``atoms`` is the two special atoms
        followed by the sorted generated atoms, and ``terms`` is the
        narrowed term list.
    """
    spec_atoms = [false, true]
    atoms = []
    terms = generate_terms(lang, max_term_depth=max_term_depth)

    # Build the graph: every background atom contributes one (start, end) edge.
    graph = build_graph_from_pairs([atom.all_consts() for atom in bk])

    if meta_arg == 0:
        bk_path = [['a', '*'], ['c', 'a', '*'], ['b', 'a', '*'],
                   ['e', 'c', 'a', '*'], ['d', 'b', 'a', '*']]
        bk_stack = [['a', '*'], ['b', 'a', '*'],
                    ['d', 'b', 'a', '*'], ['f', 'd', 'b', 'a', '*']]
    else:
        bk_path = [['a', '*'], ['c', 'a', '*'], ['b', 'a', '*'], ['d', 'a', '*'],
                   ['e', 'b', 'a', '*'], ['f', 'c', 'a', '*'],
                   ['h', 'e', 'b', 'a', '*']]
        bk_stack = [['a', '*'], ['c', 'a', '*'], ['b', 'a', '*'], ['d', 'a', '*'],
                    ['e', 'b', 'a', '*'], ['f', 'c', 'a', '*'],
                    ['h', 'e', 'b', 'a', '*']]
    # The 'paths' whitelist is the same for every meta_arg.
    bk_paths = [['*'], ['a', '*', '*'], ['c', 'a', '*', '*'],
                ['e', 'c', 'a', '*', '*'], ['b', 'a', '*', '*'],
                ['d', 'b', 'a', '*', '*'],
                ['b', 'a', '*', 'c', 'a', '*', '*'],
                ['c', 'a', '*', 'b', 'a', '*', '*'],
                ['c', 'a', '*', 'd', 'b', 'a', '*', '*'],
                ['b', 'a', '*', 'e', 'c', 'a', '*', '*']]

    # The order of the three narrowing passes fixes the order of `terms`.
    terms = _select_terms_by_consts(terms, 'path', bk_path)
    terms = _select_terms_by_consts(terms, 'paths', bk_paths)
    terms = _select_terms_by_consts(terms, 'stack', bk_stack)

    # Predicates whose argument tuples may contain repeated terms.
    repeat_allowed = ('append', 'paths', 'dfs', 'dfsf', 'equal')

    for pred in lang.preds:
        dtypes = pred.dtypes
        terms_list = [[term for term in terms if term.dtype == dtype] for dtype in dtypes]

        # The datatype checks come before the predicate-name checks on purpose.
        if len(dtypes) == 1 or str(dtypes[1]) == 'atoms':
            args_list = list(set(itertools.product(*terms_list)))
        elif str(dtypes[1]) == 'proofs_head':
            # Pair each head only with the proofs whose term 1 prints like it.
            proof_head = np.array([str(item.get_ith_term(1)) for item in terms_list[1]])
            args_list = []
            for head in terms_list[0]:
                correspond_index = np.where(proof_head == str(head))
                proof_list = np.array(terms_list[1])[correspond_index]
                args_list += list(set(itertools.product([head], proof_list)))
        elif pred.name == 'append':
            args_list = _filter_concat_args(
                list(itertools.product(*terms_list)),
                lambda arg: arg[0].to_list()[:-1] + arg[1].to_list() == arg[2].to_list(),
            )
        elif pred.name == 'paths':
            args_list = _filter_concat_args(
                list(itertools.product(*terms_list)),
                lambda arg: arg[0].all_consts() + arg[1].all_consts() == arg[2].all_consts(),
            )
        elif pred.name == 'findall':
            candidates = list(set(itertools.product(*terms_list)))
            args_list = [arg for arg in candidates if _findall_holds(arg, graph)]
        else:
            args_list = list(set(itertools.product(*terms_list)))

        for args in args_list:
            # Other predicates only accept tuples of pairwise distinct terms.
            if (pred.name in repeat_allowed or len(args) == 1
                    or len(set(args)) == len(args)):
                atoms.append(Atom(pred, args))
    return spec_atoms + sorted(atoms), terms


def get_searched_clauses(lark_path, lang_base_path, dataset_type, dataset):
    """Load the clauses stored in the dataset's 'beam_searched.txt' file.

    The language is loaded through ``DataUtils`` and the clause file is read
    from ``<base_path><dataset>/beam_searched.txt``.

    Returns:
        The clauses returned by ``DataUtils.load_clauses``.
    """
    loader = DataUtils(
        lark_path=lark_path,
        lang_base_path=lang_base_path,
        dataset_type=dataset_type,
        dataset=dataset,
    )
    lang = loader.load_language()
    searched_file = loader.base_path + dataset + '/beam_searched.txt'
    return loader.load_clauses(searched_file, lang)


def get_clauses_mi(self, lang):
    """Load the clauses in 'mi_clauses.txt' under ``self.base_path``.

    ``self`` must provide ``base_path`` and ``load_clauses``.
    NOTE(review): this is a module-level function that takes ``self`` - it
    looks like a method of the data loader that was pasted here; confirm.
    """
    load = self.load_clauses
    clause_file = self.base_path + 'mi_clauses.txt'
    return load(clause_file, lang)
def _get_lang(lark_path, lang_base_path, dataset_type, dataset):
    """Load a language with its clauses, background knowledge and atoms.

    Unlike ``get_lang`` this variant asks the loader for the clauses and
    the background knowledge via ``get_clauses`` / ``get_bk`` instead of
    reading named files.

    Returns:
        Tuple ``(lang, clauses, bk, atoms)``.
    """
    loader = DataUtils(
        lark_path=lark_path,
        lang_base_path=lang_base_path,
        dataset_type=dataset_type,
        dataset=dataset,
    )
    lang = loader.load_language()
    clauses = loader.get_clauses(lang)
    bk = loader.get_bk(lang)
    return lang, clauses, bk, generate_atoms(lang)

def build_infer_module(clauses, bk_clauses, atoms, lang, device, m, infer_step=3, train=False):
    """Encode the clauses and wrap them in an ``InferModule``.

    ``clauses`` are always encoded with ``TensorEncoder``.  ``bk_clauses``
    are encoded the same way only when non-empty; otherwise ``I_bk=None``
    is passed to the module.
    """
    I = TensorEncoder(lang, atoms, clauses, device=device).encode()

    I_bk = None
    if len(bk_clauses) > 0:
        I_bk = TensorEncoder(lang, atoms, bk_clauses, device=device).encode()

    return InferModule(
        I, m=m, infer_step=infer_step, device=device, train=train, I_bk=I_bk
    )
def build_infer_module_mi(clauses, bk_clauses, atoms, lang, terms, device, m=3, infer_step=3, train=False):
    """Encode meta-level clauses and wrap them in an ``InferModule``.

    ``clauses`` are encoded with ``TensorEncoder_mi`` (which also receives
    ``terms``).  Non-empty ``bk_clauses`` are encoded with the plain
    ``TensorEncoder``; otherwise ``I_bk=None`` is passed to the module.
    """
    I = TensorEncoder_mi(lang, atoms, clauses, terms, device=device).encode()

    I_bk = None
    if len(bk_clauses) > 0:
        I_bk = TensorEncoder(lang, atoms, bk_clauses, device=device).encode()

    return InferModule(
        I, m=m, infer_step=infer_step, device=device, train=train, I_bk=I_bk
    )

def build_clause_infer_module(clauses, bk_clauses, atoms, lang, device, m=3, infer_step=3, train=False):
    """Encode the clauses and wrap them in a ``ClauseInferModule``.

    ``clauses`` are always encoded with ``TensorEncoder``.  ``bk_clauses``
    are encoded the same way only when non-empty; otherwise ``I_bk=None``
    is passed to the module.
    """
    I = TensorEncoder(lang, atoms, clauses, device=device).encode()

    I_bk = None
    if len(bk_clauses) > 0:
        I_bk = TensorEncoder(lang, atoms, bk_clauses, device=device).encode()

    return ClauseInferModule(
        I, m=m, infer_step=infer_step, device=device, train=train, I_bk=I_bk
    )

def generate_atoms(lang):
    """Enumerate the ground atoms of the language.

    For every predicate the argument tuples are the Cartesian product of
    the constants of its datatypes.  Unary tuples are always kept; longer
    tuples are kept only when all their arguments are pairwise distinct.

    Returns:
        The two special atoms followed by the sorted generated atoms.
    """
    generated = []
    for pred in lang.preds:
        domains = [lang.get_by_dtype(dtype) for dtype in pred.dtypes]
        for args in set(itertools.product(*domains)):
            has_repeat = len(set(args)) != len(args)
            if len(args) != 1 and has_repeat:
                continue
            generated.append(Atom(pred, args))
    return [false, true] + sorted(generated)


def generate_bk(lang):
    """Generate background atoms for 'diff_color' and 'diff_shape'.

    For those two predicates every combination of constants of the
    predicate's datatypes is considered.  Unary tuples are always kept; other
    tuples are kept when the first two arguments differ and share the same
    ``mode``.  All other predicates are ignored.
    """
    target_names = ('diff_color', 'diff_shape')
    bk_atoms = []
    for pred in lang.preds:
        if pred.name not in target_names:
            continue
        domains = [lang.get_by_dtype(dtype) for dtype in pred.dtypes]
        for args in itertools.product(*domains):
            if len(args) == 1:
                bk_atoms.append(Atom(pred, args))
            elif args[0] != args[1] and args[0].mode == args[1].mode:
                bk_atoms.append(Atom(pred, args))
    return bk_atoms


def get_index_by_predname(pred_str, atoms):
    """Return the index of the first atom whose predicate is named ``pred_str``.

    Args:
        pred_str: Predicate name to look for.
        atoms: Sequence of atoms, each exposing ``pred.name``.

    Returns:
        The position of the first matching atom in ``atoms``.

    Raises:
        ValueError: If no atom uses a predicate with that name.  The former
            ``assert 1, ...`` could never fail, so a missing predicate
            silently produced ``None``.
    """
    for i, atom in enumerate(atoms):
        if atom.pred.name == pred_str:
            return i
    raise ValueError(pred_str + ' not found.')


def parse_clauses(lang, clause_strs):
    """Parse each clause string with ``DataUtils.parse_clause``.

    Returns the parsed clauses in the order of ``clause_strs``.
    """
    # NOTE(review): every other call in this module builds DataUtils with
    # lark_path/lang_base_path/dataset_type/dataset keywords; here ``lang`` is
    # passed as the sole positional argument - confirm the constructor
    # supports this.
    du = DataUtils(lang)
    parsed = []
    for clause_str in clause_strs:
        parsed.append(du.parse_clause(clause_str))
    return parsed
