import itertools
import torch
from .logic_ops import unify, subs_list, subs
from tqdm import tqdm


class TensorEncoder():
    """
    tensor encoder

    Builds the index tensors used by the differentiable inference: for every
    clause and every fact, the indexes (into ``facts``) of the ground body
    atoms needed to entail that fact with that clause. By the conventions used
    below, index 0 fills unused substitution rows ("FALSE") and index 1 pads
    short bodies ("TRUE").

    Parameters
    ----------
    lang : language object
        provides the constants (``consts``, ``get_by_dtype``) used to ground
        existentially quantified body variables
    facts : List[.logic.Atom]
        enumerated ground atoms
    clauses : List[.logic.Clause]
        generated clauses
    max_subs_num : int
        maximum number of substitutions to existentially quantified variables
        stored per (clause, fact) pair
    device : torch.device or str, optional
        device on which the index tensors are allocated
    """

    def __init__(self, lang, facts, clauses, max_subs_num=9, device=None):
        self.lang = lang  # preds, funcs, consts, mode_declarations
        self.facts = facts
        self.clauses = clauses
        self.max_subs_num = max_subs_num
        self.device = device
        # the longest clause body (at least 1) fixes the last tensor dimension
        self.max_body_len = max([len(clause.body)
                                 for clause in clauses] + [1])
        self.head_unifier_dic = self.build_head_unifier_dic()
        self.fact_index_dic = self.build_fact_index_dic()

    def encode(self):
        """
        compute index tensors for the differentiable inference

        Returns
        -------
        X : torch.tensor((|clauses|,|facts|, max_subs_num, max_body_len, ))
            index tensor
        """
        X = torch.zeros(
            (len(self.clauses), len(self.facts), self.max_subs_num, self.max_body_len), dtype=torch.long).to(self.device)
        print("Encoding to Tensors:", X.shape)

        for ci, clause in enumerate(tqdm(self.clauses)):
            X[ci, :, :, :] = self.build_X_c(clause)
        return X

    def print_tensor(self, X):
        """Print every clause followed by all facts.

        NOTE(review): ``X`` is currently not used by this method.
        """
        for clause in self.clauses:
            print(clause)
            for fact in self.facts:
                print(fact)

    def __encode_batch(self, batch_size=512):
        """
        compute index tensors for the differentiable inference, replicated
        over a batch dimension

        Returns
        -------
        X : torch.tensor((|clauses|, batch, |facts|, max_subs_num, max_body_len, ))
            index tensor
        """
        X = torch.zeros(
            (len(self.clauses), batch_size, len(self.facts), self.max_subs_num, self.max_body_len), dtype=torch.long).to(self.device)
        for ci, clause in enumerate(self.clauses):
            # X_c : (|facts|, max_subs_num, max_body_len); the assignment
            # broadcasts it over the batch dimension, so no explicit repeat
            # (which would allocate batch_size copies) is needed.
            X[ci] = self.build_X_c(clause)
        return X

    def build_fact_index_dic(self):
        """
        build dictionary [FACT -> INDEX]

        Returns
        -------
        dic : {.logic.Atom -> int}
            dictionary of ground atoms to indexes (later duplicates win)
        """
        return {fact: i for i, fact in enumerate(self.facts)}

    def build_head_unifier_dic(self):
        """
        build dictionary [(HEAD, FACT) -> THETA].

        Returns
        -------
        dic : {(.logic.Atom, .logic.Atom) -> List[(.logic.Var, .logic.Const)]}
            if the pair of atoms are unifiable, the dictionary returns the unifier for them
        """
        dic = {}
        heads = {c.head for c in self.clauses}
        for head in heads:
            for fact in self.facts:
                unify_flag, theta_list = unify([head, fact])
                if unify_flag:
                    dic[(head, fact)] = theta_list
        return dic

    def build_X_c(self, clause):
        """
        build index tensor for a given clause
        for ci in C, X_ci == X[i]
        X_c[i] is a list of indexes of facts that are needed to entail facts[i] using given clause.

        Inputs
        ------
        clause : .logic.Clause
            input clause

        Returns
        -------
        X_c : torch.tensor((|facts|, |max_subs_num|, max_body_len))
            index tensor for clause c; rows of facts that do not unify with
            the clause head stay all-zero
        """
        X_c = torch.zeros(
            (len(self.facts), self.max_subs_num, self.max_body_len), dtype=torch.long).to(self.device)

        for fi, fact in enumerate(tqdm(self.facts)):
            if (clause.head, fact) in self.head_unifier_dic:
                theta = self.head_unifier_dic[(clause.head, fact)]
                clause_ = subs_list(clause, theta)
                # the body may still contain existentially quantified variables
                X_c[fi] = self.body_to_tensor(clause_.body)
        return X_c

    # taking constant modes to reduce the number of substitutions
    def body_to_tensor(self, body):
        """
        convert a (head-substituted) clause body to an index tensor

        Inputs
        ------
        body : List[.logic.Atom]
            body atoms, possibly containing existentially quantified variables

        Returns
        -------
        X_c_b : torch.tensor((max_subs_num, max_body_len))
            row i holds the fact indexes of the body grounded by the i-th
            substitution, padded with 1 (TRUE); unused rows are 0 (FALSE)
        """
        # zero-initialised, so every row not written below means FALSE
        X_c_b = torch.zeros(
            (self.max_subs_num, self.max_body_len), dtype=torch.long).to(self.device)
        var_list = []
        const_list = []
        for atom in body:
            var_list += atom.all_vars()
            const_list += atom.all_consts()

        var_list = list(set(var_list))
        const_list = list(set(const_list))

        assert len(
            var_list) <= 10, 'Too many existentially quantified variables in the body: ' + str(body)

        if not var_list:
            # only ground atoms: a single substitution (the identity)
            X_c_b[0] = self.pad_by_true(self.ground_atoms_to_index(body))
            return X_c_b

        # the body has existentially quantified variables, e.g.
        # body: [p(a,Z),q(Z,a)] -> theta_list: [[(Z,a)], [(Z,b)]]
        # ASSUMPTION that all consts have a same mode e.g. img1.
        # A body may contain variables but no constants; cmode is then None
        # (it is currently not used by generate_subs anyway).
        cmode = const_list[0].mode if const_list else None
        theta_list = self.generate_subs(body, cmode)
        n_substs = len(theta_list)
        assert n_substs <= self.max_subs_num, 'Exceeded the maximum number of substitution patterns to existential variables: n_substs is: ' + \
            str(n_substs) + ' but max num is: ' + str(self.max_subs_num)

        for i, theta in enumerate(theta_list):
            ground_body = [subs_list(bi, theta) for bi in body]
            X_c_b[i] = self.pad_by_true(
                self.ground_atoms_to_index(ground_body))
        return X_c_b

    def ground_atoms_to_index(self, atoms):
        """Map ground atoms to their fact indexes (unknown atoms map to 0)."""
        return torch.tensor([self.get_fact_index(nf) for nf in atoms], dtype=torch.long).to(self.device)

    def pad_by_true(self, x):
        """Right-pad a 1-D index tensor with 1 (TRUE) up to max_body_len."""
        assert x.size(
            0) <= self.max_body_len, 'x.size(0) exceeds max_body_len: ' + str(self.max_body_len)
        if x.size(0) == self.max_body_len:
            return x
        diff = self.max_body_len - x.size(0)
        x_pad = torch.ones(diff, dtype=torch.long).to(self.device)
        return torch.cat([x, x_pad])

    # taking constant modes to reduce the number of substitutions
    def generate_subs(self, body, cmode):
        """
        enumerate substitutions for the variables occurring in ``body``

        Each variable is only bound to constants of the dtype required by the
        argument position it occupies, and distinct variables are bound to
        distinct constants.

        Inputs
        ------
        body : List[.logic.Atom]
            body atoms
        cmode : object or None
            constant mode (currently unused)

        Returns
        -------
        theta_list : List[List[(.logic.Var, .logic.Const)]]
            e.g. [[(Z, red)], [(Z, yellow)], [(Z, blue)]]
        """
        # assumption that there is no contradiction for variable dtypes e.g. x/object and x/color
        var_dtype_list = []
        for atom in body:
            for i, term in enumerate(atom.terms):
                if term.is_var():
                    var_dtype_list.append((term, atom.pred.dtypes[i]))

        # de-duplicate while keeping first-appearance order so the row order
        # of the resulting tensor is deterministic
        var_dtype_list = list(dict.fromkeys(var_dtype_list))
        ordered_var_list = [v for v, _ in var_dtype_list]

        # [[a,b,c], [a,b,c], ..] : candidate constants per variable
        consts_list = [self.lang.get_by_dtype(dtype)
                       for _, dtype in var_dtype_list]

        if consts_list and all(cs == consts_list[0] for cs in consts_list[1:]):
            # all variables share the same candidates: ordered selections of
            # distinct constants
            candidates = itertools.permutations(
                consts_list[0], len(consts_list))
        else:
            # different dtypes per variable (or no variable): outer product of
            # the per-variable candidates
            candidates = itertools.product(*consts_list)

        # make sure c_i != c_j (i != j)
        subs_consts_list = [
            ls for ls in candidates if len(set(ls)) == len(ls)]

        # bind the i-th variable to the i-th constant of each selection
        return [list(zip(ordered_var_list, subs_consts))
                for subs_consts in subs_consts_list]

    # general existentially quantified variables => not scaling
    # NOTE(review): unused; previously shadowed by the later __generate_subs
    # definition, renamed so both definitions survive.
    def __generate_subs_general(self, body, var_list):
        # TODO: FIX BUG for MULTI EXISTENTIALLY QUANTIFIED VARIABLES
        # assumption that there is no contradiction for variable dtypes e.g. x/object and x/color
        var_dtype_list = []
        for atom in body:
            var_dtype_list = var_dtype_list + self.lang.get_var_and_dtype(atom)

        var_dtype_list = list(set(var_dtype_list))
        ordered_var_list = [x[0] for x in var_dtype_list]

        var_consts_pairs = []  # List[(x, [a,b,c]), ...]
        for v, dtype in var_dtype_list:
            consts_by_dtype = [c for c in self.lang.consts if c.dtype == dtype]
            var_consts_pairs.append((v, consts_by_dtype))

        consts_list = [x[1]
                       for x in var_consts_pairs]  # [[a,b,c], [a,b,c], ..]

        # taking outer product List[(a,a), (a,b)]
        subs_consts_list = itertools.product(*consts_list)

        theta_list = []
        for subs_consts in subs_consts_list:
            theta = []
            for i, const in enumerate(subs_consts):
                theta.append((ordered_var_list[i], const))
            theta_list.append(theta)
        return theta_list  # List[[(x,a), (y,a)], [(x,a), (y,b)], ...]

    def ____generate_subs(self, var_list):
        assert len(var_list) == 1, 'Too many existentaiily quantified variable'
        theta_list = []
        for v in var_list:
            consts_filtered = [
                c for c in self.lang.consts if c.dtype == v.dtype]
            assert len(
                consts_filtered) > 0, 'No substitutions to existentially quantified variables.'
            for c in consts_filtered:
                theta_list.append((v, c))
        return list(set(theta_list))

    def __generate_subs(self, var_list):
        assert len(var_list) == 1, 'Too many existentaiily quantified variable'
        theta_list = []
        for v in var_list:
            consts_filtered = [
                c for c in self.lang.consts if c.dtype == v.dtype]
            assert len(
                consts_filtered) > 0, 'No substitutions to existentially quantified variables.'
            for c in consts_filtered:
                theta_list.append((v, c))
        return list(set(theta_list))

    def get_subs_by_dtype(self, atom, consts):
        # generate substitution depending on the variables in the input atom
        # TODO: not implemented; returns None
        pass

    def index_list_to_tensor(self, index_list):
        """
        convert list of indexes to torch tensor
        filling the gap by \\top (See Eq.4 in the paper)

        Inputs
        ------
        index_list : List[int]
            list of indexes of ground atoms in the body

        Returns
        -------
        body_tensor : torch.tensor((max_body_len, ))
            index tensor for a clause (dtype int32)
            indexes of ground atoms in the body after the unification of its head
        """
        diff = self.max_body_len - len(index_list)
        if diff > 0:
            return torch.tensor(index_list + [1] * diff, dtype=torch.int32).to(self.device)
        return torch.tensor(index_list, dtype=torch.int32).to(self.device)

    def get_fact_index(self, fact):
        """
        convert fact to index in the ordered set of all facts

        Inputs
        ------
        fact : .logic.Atom
            ground atom

        Returns
        -------
        index : int
            index of the input ground atom, or 0 if it is not a known fact
        """
        return self.fact_index_dic.get(fact, 0)
