import torch
import torch.nn as nn

from fol.logic import NeuralPredicate, Predicate
from tqdm import tqdm
import numpy as np

class Check_Existence_module(nn.Module):
    """Score ground atoms with hand-written existence/consistency rules.

    Despite the ``FactsConverter`` name printed by ``__str__``/``__repr__``,
    this module does not call the perception module. :meth:`convert` builds a
    ``(1, len(G))`` float32 valuation in which each ground atom of ``G`` is
    scored by a rule selected by its predicate name. Depending on the
    predicate, the score is one of:

    * a constant;
    * a membership test against the background knowledge ``bk``;
    * a structural check against ``object_clause[0]``;
    * a value copied from the object-level valuation passed to :meth:`convert`.

    Atoms that no rule accepts keep the value 0. :meth:`forward` is an alias
    for :meth:`convert`.
    """

    # Predicates whose ground atoms are always valued 1.
    _ALWAYS_TRUE = frozenset(
        {'change_state', 'starth', 'startv', 'ndo', 'findall', 'append'})

    # The only ``plan`` atoms that are valued 1.
    _PLAN_TARGETS = frozenset({
        'plan(pos_h(obj0,f(0)),g(pos_h(obj0,f(0))),pos_h(obj0,f(f(f(f(f(0)))))),*)',
        'plan(pos_v(obj0,f(0)),g(pos_v(obj0,f(0))),pos_v(obj0,f(f(f(f(f(0)))))),*)',
    })

    # (position prefix, action) -> required change in the number of 'f'
    # symbols from the source position term to the destination term.
    _MOVE_DELTAS = {
        ('pos_h', 'moveleft'): -1,
        ('pos_h', 'moveright'): 1,
        ('pos_v', 'movedown'): -1,
        ('pos_v', 'moveup'): 1,
    }

    # meta_arg -> {search predicate: the only ground atom valued 1}.
    _SEARCH_TARGETS = {
        0: {'bfs': 'bfs(k(g(a,*),*),a,e)', 'dfs': 'dfs(a,a,e,r(a,*))'},
        1: {'bfs': 'bfs(k(g(a,*),*),a,h)', 'dfs': 'dfs(a,a,h,r(a,*))'},
    }

    def __init__(self, bk, lang, perception_module, valuation_module, clause,
                 meta_arg=None, device=None):
        """
        Args:
            bk: background-knowledge atoms. An ``edge`` atom is true iff its
                string form equals the string form of one of these.
            lang: language object; stored but not read by this class.
            perception_module: only its ``e`` and ``d`` attributes are read.
            valuation_module: object providing ``get_params()``.
            clause: sequence of clauses. ``clause[0]`` is the template that
                ``asserts``/``clause``/``clause2`` atoms are checked against.
            meta_arg: 0 or 1 selects which ``bfs``/``dfs`` atom is true. Any
                other value disables those two rules.
            device: torch device on which valuation tensors are allocated.
        """
        super().__init__()
        self.e = perception_module.e
        self.d = perception_module.d
        self.lang = lang
        self.vm = valuation_module  # valuation functions
        self.device = device
        self.object_clause = clause
        self.bk = bk
        self.meta_arg = meta_arg

    def __str__(self):
        return "FactsConverter(entities={}, dimension={})".format(self.e, self.d)

    __repr__ = __str__

    def forward(self, Z, G, V):
        """Alias for :meth:`convert` so the module can be called directly."""
        return self.convert(Z, G, V)

    def get_params(self):
        """Return the parameters exposed by the valuation module."""
        return self.vm.get_params()

    def init_valuation(self, n, batch_size):
        """Return a ``(batch_size, n)`` zero tensor whose column 1 is set to 1."""
        valuation = torch.zeros((batch_size, n)).to(self.device)
        valuation[:, 1] = 1.0
        return valuation

    def filter_by_datatype(self):
        """Placeholder; not implemented."""
        pass

    def to_vec(self, term, zs):
        """Placeholder; not implemented."""
        pass

    def __convert(self, Z, G):
        """Apply :meth:`convert_i` to every sample of ``Z`` and stack the results."""
        return torch.stack([self.convert_i(zs, G) for zs in tqdm(Z)])

    def convert(self, object_atoms, G, V_object):
        """Build the valuation of the ground atoms ``G``.

        Args:
            object_atoms: sequence of object-level atoms. Their string forms
                are matched against the first argument of ``assert``,
                ``prob`` and neural ``solve``/``mi`` atoms.
            G: sequence of ground atoms to score.
            V_object: indexable valuation aligned with ``object_atoms``
                (presumably a 1-D tensor -- TODO confirm).

        Returns:
            A ``(1, len(G))`` float32 tensor on ``self.device``.

        Raises:
            IndexError: an atom that copies an object value has no matching
                object atom, or ``len(G) < 2``.
        """
        batch_size = 1  # the valuation is always built for a single sample
        bk_strings = {str(fact) for fact in self.bk}
        V = torch.zeros((batch_size, len(G)), dtype=torch.float32,
                        device=self.device)

        # First index of each object atom, keyed by its string form.
        first_index = {}
        for k, obj in enumerate(object_atoms):
            first_index.setdefault(str(obj), k)

        def object_value(key):
            idx = first_index.get(key)
            if idx is None:
                raise IndexError(f"no object atom matches {key!r}")
            return V_object[idx]

        search_targets = self._search_targets()
        for i, atom in enumerate(G):
            value = self._atom_value(atom, bk_strings, object_value,
                                     search_targets)
            if value is not None:
                V[:, i] = value
        # Index 1 is always forced to 1 (presumably G[1] is the 'true' atom
        # -- TODO confirm against the caller that builds G).
        V[:, 1] = 1
        return V

    def _search_targets(self):
        """Return the bfs/dfs target atoms selected by ``self.meta_arg``.

        Returns an empty dict when ``meta_arg`` equals neither 0 nor 1.
        """
        for key, targets in self._SEARCH_TARGETS.items():
            if self.meta_arg == key:
                return targets
        return {}

    def _atom_value(self, atom, bk_strings, object_value, search_targets):
        """Return the value for one ground atom, or None to leave it at 0.

        The rule is chosen by ``atom.pred.name``. The names handled here are
        mutually exclusive, so at most one rule applies.
        """
        name = atom.pred.name
        if name in self._ALWAYS_TRUE:
            return 1.0
        if name == 'edge':
            return 1.0 if str(atom) in bk_strings else None
        if name == 'condition_met':
            return 1.0 if atom.terms[0] in atom.terms[1].all_consts() else None
        if name == 'equal':
            return 1.0 if atom.terms[1] in atom.terms[0].all_consts() else None
        if name == 'equalbfs':
            target = atom.terms[2]
            hit = (atom.terms[0].all_consts()[0] == target
                   or atom.terms[1].all_consts()[0] == target)
            return 1.0 if hit else None
        if name == 'move':
            return 1.0 if self._move_holds(atom) else None
        if name == 'plan':
            return 1.0 if str(atom) in self._PLAN_TARGETS else None
        if name == 'do':
            return 0.8 if str(atom) == 'do(medicinea(a,s))' else None
        if name == 'ins':
            return 1.0 if str(atom.terms[0]) != str(atom.terms[1]) else None
        if name == 'ins2':
            return 1.0 if str(atom.terms[0].dtype) != 'atom_head' else None
        if name == 'probs':
            return 1.0 if len(atom.all_consts()) == 1 else None
        if name in ('assert', 'prob'):
            return object_value(str(atom.terms[0]))
        if name in ('solve', 'mi'):
            if type(atom.pred) is NeuralPredicate:
                return object_value(str(atom.terms[0]))
            return None
        if name in ('asserts', 'clause', 'clause2'):
            return 1.0 if self._clause_matches(atom) else None
        if name in search_targets:
            return 1.0 if str(atom) == search_targets[name] else None
        return None

    def _move_holds(self, atom):
        """Check ``move(action, src, dst)``.

        Source and destination must share the ``pos_h``/``pos_v`` prefix, and
        the number of 'f' symbols must change by the amount the action
        requires (see ``_MOVE_DELTAS``).
        """
        action = str(atom.terms[0])
        src = str(atom.terms[1])
        dst = str(atom.terms[2])
        axis = src[:5]
        if axis != dst[:5]:
            return False
        required = self._MOVE_DELTAS.get((axis, action))
        return required is not None and dst.count('f') - src.count('f') == required

    @staticmethod
    def _functor_name(term):
        """Return the text of ``str(term)`` before its first '('."""
        return str(term).split('(')[0]

    @staticmethod
    def _arg_names(term):
        """Return the comma-separated text after the first '(' of ``str(term)``.

        Only the text up to the next ')' is kept, so nested arguments are not
        parsed: ``'p(f(X),Y)'`` yields ``['f']``. Raises IndexError when
        ``str(term)`` contains no '('.
        """
        return str(term).split('(')[1].split(')')[0].split(',')

    @staticmethod
    def _positions(var, candidates):
        """Return every index k with ``var == candidates[k]``, or ``[None]``."""
        hits = [k for k, cand in enumerate(candidates) if var == cand]
        return hits or [None]

    def _clause_matches(self, atom):
        """Check a clause-encoding atom against ``self.object_clause[0]``.

        ``atom.terms[0]`` encodes the head. ``atom.terms[1]`` is a functor
        whose constants encode the body literals, optionally followed by a
        trailing '*' placeholder. The atom matches when two conditions hold:

        1. Its head and body functor names agree with the template clause.
        2. Argument names are bound consistently wherever the template
           reuses a variable.
        """
        template = self.object_clause[0]
        head_vars = list(template.head.all_vars())
        # Plain lists rather than np.array: body literals may differ in arity
        # and NumPy >= 1.24 refuses to build ragged arrays.
        body_vars = [list(lit.all_vars()) for lit in template.body]
        head_name = template.head.pred.name
        body_names = [lit.pred.name for lit in template.body]

        head_term, body_term = atom.terms[0], atom.terms[1]
        atom_head_name = self._functor_name(head_term)
        # Read the body constants once; slice below instead of popping so the
        # list returned by all_consts() is never mutated.
        body_consts = list(body_term.all_consts())
        atom_body_names = [self._functor_name(c) for c in body_consts]

        if '*' in str(head_term.all_consts()):
            return False
        if '*' in str(body_consts):
            if str(body_consts[-1]) != '*':
                # A '*' that is not the trailing placeholder cannot be parsed
                # into argument names, so treat the atom as unmatched rather
                # than reusing variables left over from another atom.
                return False
            body_consts = body_consts[:-1]
            atom_body_names = atom_body_names[:-1]
        # atom_var[0]: head argument names; atom_var[k]: k-th body literal's.
        atom_var = [self._arg_names(head_term)]
        atom_var += [self._arg_names(c) for c in body_consts]

        names_match = (
            head_name == atom_head_name
            and len(atom_body_names) >= len(body_vars)
            and not any(t != a for t, a in zip(body_names, atom_body_names))
        )
        if not names_match:
            return False

        # For each template variable, list its position(s) in the head, then
        # its position(s) in each body literal (None where it is absent).
        var_positions = [
            self._positions(var, head_vars)
            + [k for lits in body_vars for k in self._positions(var, lits)]
            for var in template.all_vars()
        ]
        return self._bindings_consistent(var_positions, atom_var, len(body_vars))

    @staticmethod
    def _bindings_consistent(var_positions, atom_var, n_body):
        """Check that each template variable is bound to one argument name.

        For each variable, the first occurrence (head if present, otherwise
        the first body literal) fixes a base name. Later occurrences must
        equal it. Where a literal lacks the variable and it is not the last
        slot, the next literal's entry is compared instead. Requires exactly
        ``n_body + 1`` template variables, as the matching scheme assumes.
        """
        if len(var_positions) != n_body + 1:
            return False
        for positions in var_positions:
            if positions[0] is not None:
                base, start = atom_var[0][positions[0]], 1
            else:
                base, start = atom_var[1][positions[1]], 2
            last = len(positions) - 1
            for p in range(start, len(positions)):
                if positions[p] is not None:
                    if base != atom_var[p][positions[p]]:
                        return False
                elif p != last:
                    if base != atom_var[p + 1][positions[p + 1]]:
                        return False
        return True

    def check_value(self, Z, Z_pred, atom):
        """Return the entries of ``Z`` whose label in ``Z_pred`` equals ``str(atom.pred)``.

        Assumes ``Z_pred`` is a NumPy array of labels aligned with ``Z``
        (TODO confirm).
        """
        index = np.where(Z_pred == str(atom.pred))
        return Z[index]

    def convert_i(self, zs, G):
        """Return the initial 1-D valuation of length ``len(G)`` for one sample.

        ``zs`` is currently unused: every sample starts from zeros with
        index 1 set to 1.
        """
        # init_valuation needs a batch size; build a batch of one and drop the
        # batch axis so __convert stacks samples into shape (B, len(G)).
        return self.init_valuation(len(G), 1)[0]

    def call(self, pred):
        """Return ``pred`` unchanged."""
        return pred