import itertools
import torch
from tqdm import tqdm
from fol.logic_ops import unify, subs_list, subs
import numpy as np

class TensorEncoder(object):
    """The tensor encoder for differentiable inference.

    A class for tensor encoding in differentiable forward-chaining approach.
    Every (clause, fact) pair whose head unifies with the fact is turned into
    rows of fact indices, one row per substitution of the remaining body
    variables.

    Args:
        lang (language): The language of first-order logic.
        facts (list(atom)): The set of ground atoms (facts).
        clauses (list(clause)): The set of clauses (rules).
        device (torch.device): The device to be used.

    Attrs:
        lang (language): The language of first-order logic.
        facts (list(atom)): The set of ground atoms (facts).
        clauses (list(clause)): The set of clauses (rules).
        device (torch.device): The device on which tensors are created.
        G (int): The number of ground atoms.
        C (int): The number of clauses.
        L (int): The maximum length of the clauses (at least 1).
        S (int): The maximum number of substitutions for body atoms (at least 1).
        fact_index_dic ({atom -> int}): The dictionary that maps an atom to its index.

    Note:
        The constructor does not build a head-unifier cache; unification is
        recomputed where needed. ``build_head_unifier_dic`` is available for
        callers that want such a table.
    """

    def __init__(self, lang, facts, clauses, device):
        self.lang = lang
        self.facts = facts
        self.clauses = clauses
        self.device = device
        self.G = len(facts)
        self.C = len(clauses)
        # The fact -> index table is needed by every body conversion.
        self.fact_index_dic = self.build_fact_index_dic()
        self.S = self.get_max_subs_num(clauses)
        # "+ [1]" keeps L >= 1 when there are no clauses or only empty bodies.
        self.L = max([len(clause.body) for clause in clauses] + [1])

    def get_max_subs_num(self, clauses):
        """Compute S (the maximum number of substitutions for body atoms) from clauses.

        Args:
            clauses (list(clause)): A set of clauses.

        Returns:
            S (int): The maximum number of substitutions for existentially
                quantified variables in the body atoms. Never smaller than 1,
                because ``body_to_tensor`` writes a grounded body into row 0.
        """
        # Seeding with 1 avoids max() on an empty list (no clause head unifies
        # with any fact, or clauses/facts are empty) and guarantees that the
        # S axis has room for the single row used by grounded bodies.
        S_list = [1]
        for clause in clauses:
            for fact in self.facts:
                unify_flag, theta = unify([clause.head, fact])
                if unify_flag:
                    body = subs_list(clause, theta).body
                    S_list.append(len(self.generate_subs(body)))
        return max(S_list)

    def encode(self):
        """Compute the index tensor for the differentiable inference.

        Returns:
            I (tensor): The index tensor (C, G, S, L).
        """
        I = torch.zeros((self.C, self.G, self.S, self.L),
                        dtype=torch.long, device=self.device)
        for ci, clause in enumerate(self.clauses):
            I[ci] = self.build_I_c(clause)
        return I

    def build_I_c(self, clause):
        """Build index tensor for a given clause.

        Args:
            clause (clause): A clause.

        Returns:
            I_c (tensor): The index tensor for the given clause (G, S, L).
                Rows of facts that do not unify with the clause head stay 0.
        """
        # G * S * L
        I_c = torch.zeros((self.G, self.S, self.L),
                          dtype=torch.long, device=self.device)
        for fi, fact in enumerate(self.facts):
            unify_flag, theta = unify([clause.head, fact])
            if unify_flag:
                clause_ = subs_list(clause, theta)
                # convert body atoms into indices
                I_c[fi] = self.body_to_tensor(clause_.body)
        return I_c

    def build_fact_index_dic(self):
        """Build dictionary {fact -> index}

        Returns:
            dic ({atom -> int}): A dictionary to map the atoms to indices.
        """
        return {fact: i for i, fact in enumerate(self.facts)}

    def build_head_unifier_dic(self):
        """Build dictionary {(head, fact) -> unifier}.

        Returns:
            dic ({(atom,atom) -> substitution}): A dictionary to map each
                unifiable (clause head, fact) pair to its unifier.
        """
        dic = {}
        heads = set(c.head for c in self.clauses)
        for head in heads:
            for fact in self.facts:
                unify_flag, theta_list = unify([head, fact])
                if unify_flag:
                    dic[(head, fact)] = theta_list
        return dic

    def body_to_tensor(self, body):
        """Convert the body atoms into a tensor.

        Args:
            body (list(atom)): The body atoms.

        Returns:
            I_c_b (tensor;(S * L)): The tensor representation of the body atoms.
                Unused substitution rows are left at 0, which the comments in
                this file describe as the index of FALSE.
        """
        # S * L, zero-initialised, so unused rows need no explicit filling
        I_c_b = torch.zeros((self.S, self.L),
                            dtype=torch.long, device=self.device)

        # extract all vars in the body atoms
        var_list = []
        for atom in body:
            var_list += atom.all_vars()
        var_list = list(set(var_list))

        assert len(
            var_list) <= 10, 'Too many existentially quantified variables in an atom: ' + str(atom)

        if not var_list:
            # the body atoms are already grounded: a single row suffices
            I_c_b[0] = self.pad_by_true(self.facts_to_index(body))
            return I_c_b

        # the body has existentially quantified variable!!
        # e.g. body atoms: [in(img,O1),shape(O1,square)]
        # theta_list: [(O1,obj1), (O1,obj2)]
        theta_list = self.generate_subs(body)
        n_substs = len(theta_list)
        assert n_substs <= self.S, 'Exceeded the maximum number of substitution patterns to existential variables: n_substs is: ' + \
            str(n_substs) + ' but max num is: ' + str(self.S)

        # compute the grounded clause for each possible substitution,
        # convert to the index tensor, and store it.
        for i, theta in enumerate(theta_list):
            ground_body = [subs_list(bi, theta) for bi in body]
            I_c_b[i] = self.pad_by_true(self.facts_to_index(ground_body))
        return I_c_b

    def pad_by_true(self, x):
        """Fill the tensor by ones for the clause which has less body atoms than the longest clause.

        Args:
            x (tensor): A 1-D tensor of at most L fact indices.

        Return:
            x_padded (tensor): The tensor padded with ones to length L.
        """
        assert x.size(
            0) <= self.L, 'x.size(0) exceeds max_body_len: ' + str(self.L)
        diff = self.L - x.size(0)
        if diff == 0:
            return x
        x_pad = torch.ones(diff, dtype=torch.long, device=self.device)
        return torch.cat([x, x_pad])

    def generate_subs(self, body):
        """Generate substitutions from given body atoms.

        Generate the possible substitutions from given list of atoms. If the body contains any variables,
        then generate the substitutions by enumerating constants that matches the data type.
        !!! ASSUMPTION: The body has variables that have the same data type
            e.g. variables O1(object) and Y(color) cannot appear in one clause !!!

        Constants are assigned with ``itertools.permutations``, so two
        different variables are never bound to the same constant within one
        substitution.

        Args:
            body (list(atom)): The body atoms which may contain existentially quantified variables.

        Returns:
            theta_list (list(substitution)): The list of substitutions of the given body atoms.
                Empty when the body contains no variables.
        """
        # extract all variables and corresponding data types from given body atoms
        dtypes = []
        body_vars = []
        for atom in body:
            for i, term in enumerate(atom.terms):
                if term.is_var():
                    dtypes.append(atom.pred.dtypes[i])
                    body_vars.append(term)
        # in case there is no variables in the body
        if not dtypes:
            return []
        # check the data type consistency
        n_dtypes = len(set(dtypes))
        assert n_dtypes == 1, "Invalid existentially quantified variables. " + \
            str(n_dtypes) + " data types in the body: " + str(body) + " dypes: " + str(dtypes)

        # De-duplicate in first-occurrence order so that the order of the
        # generated substitutions does not depend on hash ordering.
        body_vars = list(dict.fromkeys(body_vars))
        consts = self.lang.get_by_dtype(dtypes[0])

        # e.g. for one variable Z over colors:
        # theta_list: [[(Z, red)], [(Z, yellow)], [(Z, blue)]]
        return [list(zip(body_vars, subs_consts))
                for subs_consts in itertools.permutations(consts, len(body_vars))]

    def facts_to_index(self, atoms):
        """Convert given ground atoms into the indices.

        Args:
            atoms (list(atom)): Ground atoms.

        Returns:
            tensor: A 1-D long tensor with one index per atom.
        """
        return torch.tensor([self.get_fact_index(nf) for nf in atoms],
                            dtype=torch.long, device=self.device)

    def get_fact_index(self, fact):
        """Convert a fact to the index in the ordered set of all facts.

        Returns:
            int: The index of the fact, or 0 when the fact is not in
                ``fact_index_dic``.
        """
        return self.fact_index_dic.get(fact, 0)



class TensorEncoder_mi(object):
    """The tensor encoder for differentiable inference.

    A class for tensor encoding in differentiable forward-chaining approach, adjusted for the meta-interpreter

    Args:
        lang (language): The language of first-order logic.
        facts (list(atom)): The set of ground atoms (facts).
        clauses (list(clause)): The set of clauses (rules).
        terms (list(term)): The terms used to instantiate body variables;
            each term is selected by comparing its ``dtype`` attribute.
        device (torch.device): The device to be used.

    Attrs:
        lang (language): The language of first-order logic.
        facts (list(atom)): The set of ground atoms (facts).
        clauses (list(clause)): The set of clauses (rules).
        device (torch.device): The device on which tensors are created.
        terms_all (list(term)): The ``terms`` argument.
        G (int): The number of ground atoms.
        C (int): The number of clauses.
        L (int): The maximum length of the clauses (at least 1).
        S (int): The maximum number of substitutions for body atoms (at least 1).
        head_unifier_dic ({(atom, atom) -> List[(var, const)]}): The dictionary to save the computed unification results.
        fact_index_dic ({str -> int}): The dictionary that maps the string form of an atom to its index.
    """

    def __init__(self, lang, facts, clauses, terms, device):
        self.lang = lang
        self.facts = facts
        self.clauses = clauses
        self.device = device
        self.G = len(facts)
        self.C = len(clauses)
        # call before computing S and L
        self.head_unifier_dic = self.build_head_unifier_dic()
        self.fact_index_dic = self.build_fact_index_dic()
        self.terms_all = terms
        self.S = self.get_max_subs_num(clauses)
        # "+ [1]" keeps L >= 1 when there are no clauses or only empty bodies.
        self.L = max([len(clause.body) for clause in clauses] + [1])

    @staticmethod
    def _is_func_term(term):
        """Return True when ``term`` is exactly a ``fol.logic.FuncTerm`` instance."""
        # The class is identified by its printed type so that this module
        # does not need to import fol.logic.
        return str(type(term)) == "<class 'fol.logic.FuncTerm'>"

    def get_max_subs_num(self, clauses):
        """Compute S (the maximum number of substitutions for body atoms) from clauses.

        Args:
            clauses (list(clause)): A set of clauses.

        Returns:
            S (int): The maximum number of substitutions for existentially
                quantified variables in the body atoms. Never smaller than 1.
        """
        # Seeding with 1 gives the lower bound of 1 and also avoids max() on
        # an empty list when no (head, fact) pair is unifiable.
        S_list = [1]
        for clause in clauses:
            for fact in self.facts:
                key = (clause.head, fact)
                if key in self.head_unifier_dic:
                    theta = self.head_unifier_dic[key]
                    body = subs_list(clause, theta).body
                    S_list.append(len(self.generate_subs(body)))
        return max(S_list)

    def generate_subs(self, body):
        """Generate substitutions from given body atoms.

        Generate the possible substitutions from given list of atoms. If the body contains any variables,
        then generate the substitutions by enumerating terms of ``terms_all`` that match the data type.
        the function is used for the rules solve(A):-clause(A,B),solve(B)
        since we can only get the clause(A,B) from clause(A), we need to get solve(B) from the first body predicate

        Unlike ``TensorEncoder.generate_subs``, variables of different data
        types are allowed: candidates are enumerated per variable and then
        combined with a Cartesian product.

        NOTE(review): arguments of a function term that are not function terms
        themselves are registered without an ``is_var()`` check, and variables
        of a nested function term are paired positionally with its
        ``in_dtypes`` -- confirm both are intended.

        Args:
            body (list(atom)): The body atoms which may contain existentially quantified variables.

        Returns:
            theta_list (list(substitution)): The list of substitutions of the given body atoms.
                With a single variable each entry is ``[(var, term)]``; with
                several variables each entry is a list of such one-element
                lists, one per variable (``body_to_tensor`` flattens them).
                Empty when no variable is found.
        """
        # variables in first-occurrence order, with the data type recorded at
        # the same position
        var_list = []
        dtypes = []

        def register(v, dtype):
            if v not in var_list:
                var_list.append(v)
                dtypes.append(dtype)

        for atom in body:
            for i, term in enumerate(atom.terms):
                if not term.is_var():
                    continue
                if self._is_func_term(term):
                    # the term is a function term that contains variables
                    for j, arg in enumerate(term.args):
                        if self._is_func_term(arg):
                            for v, dtype in zip(arg.all_vars(),
                                                arg.func_symbol.in_dtypes):
                                register(v, dtype)
                        else:
                            register(arg, term.func_symbol.in_dtypes[j])
                else:
                    register(term, atom.pred.dtypes[i])

        # in case there is no variables in the body
        if not var_list:
            return []

        # one list of single-binding substitutions per variable,
        # e.g. [[(Z, red)], [(Z, yellow)], [(Z, blue)]]
        theta_lists = []
        for v, dtype in zip(var_list, dtypes):
            consts = [term for term in self.terms_all if term.dtype == dtype]
            theta_lists.append([[(v, const)] for const in consts])

        if len(theta_lists) == 1:
            return theta_lists[0]
        # combine the per-variable choices
        return [list(item) for item in itertools.product(*theta_lists)]

    def encode(self):
        """Compute the index tensor for the differentiable inference.

        Returns:
            I (tensor): The index tensor (C, G, S, L).
        """
        I = torch.zeros((self.C, self.G, self.S, self.L),
                        dtype=torch.long, device=self.device)
        for ci, clause in enumerate(self.clauses):
            I_c = self.build_I_c(clause)
            # progress message, one line per encoded clause
            print("build_encoder_complete_this_rules")
            I[ci] = I_c
        return I

    def build_I_c(self, clause):
        """Build index tensor for a given clause.

        Args:
            clause (clause): A clause.

        Returns:
            I_c (tensor): The index tensor for the given clause (G, S, L).
                Rows of facts that do not unify with the clause head stay 0.
        """
        # G * S * L
        I_c = torch.zeros((self.G, self.S, self.L),
                          dtype=torch.long, device=self.device)

        for fi, fact in enumerate(self.facts):
            key = (clause.head, fact)
            if key in self.head_unifier_dic:
                clause_ = subs_list(clause, self.head_unifier_dic[key])
                # convert body atoms into indices
                I_c[fi] = self.body_to_tensor(clause_.body)

        return I_c

    def build_fact_index_dic(self):
        """Build dictionary {str(fact) -> index}

        Returns:
            dic ({str -> int}): A dictionary to map the string form of the atoms to indices.
        """
        return {str(fact): i for i, fact in enumerate(self.facts)}

    def build_head_unifier_dic(self):
        """Build dictionary {(head, fact) -> unifier}.

        Returns:
            dic ({(atom,atom) -> substitution}): A dictionary to map each
                unifiable (clause head, fact) pair to its unifier.
        """
        dic = {}
        heads = set(c.head for c in self.clauses)
        for head in heads:
            for fact in self.facts:
                unify_flag, theta_list = unify([head, fact])
                if unify_flag:
                    dic[(head, fact)] = theta_list
        return dic

    def body_to_tensor(self, body):
        """Convert the body atoms into a tensor.

        Args:
            body (list(atom)): The body atoms.

        Returns:
            I_c_b (tensor;(S * L)): The tensor representation of the body atoms.
                Unused substitution rows are left at 0, which the comments in
                this file describe as the index of FALSE.
        """
        # S * L, zero-initialised, so unused rows need no explicit filling
        I_c_b = torch.zeros((self.S, self.L),
                            dtype=torch.long, device=self.device)
        var_list = []
        for atom in body:
            var_list += atom.all_vars()
        var_list = list(set(var_list))  # remove duplicate variables

        assert len(
            var_list) <= 10, 'Too many existentially quantified variables in an atom: ' + str(atom)

        if not var_list:
            # the body atoms are already grounded: a single row suffices
            I_c_b[0] = self.pad_by_true(self.facts_to_index(list(body)))
            return I_c_b

        # the body has existentially quantified variable!!
        # e.g. body atoms: [in(img,O1),shape(O1,square)]
        # theta_list: [(O1,obj1), (O1,obj2)]
        theta_list = self.generate_subs(body)

        # compute the grounded clause for each possible substitution,
        # convert to the index tensor, and store it.
        for i, theta in enumerate(theta_list):
            if len(theta) > 1:
                # several variables: each entry is a one-element list
                # [(var, term)], so unwrap it into a flat list of pairs
                theta = [item[0] for item in theta]
            else:
                theta = list(theta)

            ground_body = [subs_list(bi, theta) for bi in body]
            I_c_b[i] = self.pad_by_true(self.facts_to_index(ground_body))

        return I_c_b

    def pad_by_true(self, x):
        """Fill the tensor by ones for the clause which has less body atoms than the longest clause.

        Args:
            x (tensor): A 1-D tensor of at most L fact indices.

        Return:
            x_padded (tensor): The tensor padded with ones to length L.
        """
        assert x.size(
            0) <= self.L, 'x.size(0) exceeds max_body_len: ' + str(self.L)
        diff = self.L - x.size(0)
        if diff == 0:
            return x
        x_pad = torch.ones(diff, dtype=torch.long, device=self.device)
        return torch.cat([x, x_pad])

    def facts_to_index(self, atoms):
        """Convert given ground atoms into the indices.

        Args:
            atoms (list(atom)): Ground atoms.

        Returns:
            tensor: A 1-D long tensor with one index per atom.
        """
        return torch.tensor([self.get_fact_index(nf) for nf in atoms],
                            dtype=torch.long, device=self.device)

    def get_fact_index(self, fact):
        """Convert a fact to the index in the ordered set of all facts.

        The lookup key is ``str(fact)``. A successful lookup is silent; a
        missing fact is reported on stdout and mapped to index 0.

        Returns:
            int: The index of the fact, or 0 when it is unknown.
        """
        try:
            return self.fact_index_dic[str(fact)]
        except KeyError:
            print('something is wrong')
            print (fact)
            return 0