import torch.nn.functional as F
import torch
import random
import numpy as np
# Seed Python's built-in `random` generator at import time for reproducibility.
# NOTE(review): this does not seed torch's RNG, so `torch.rand` (used by
# InferModule.init_weights) is unaffected — confirm torch is seeded elsewhere.
random.seed(a=1)


def convert(bk, facts, device=None):
    """
    convert background knowledge to the valuation vector

    An entry is 1.0 when the corresponding ground atom is contained in the
    background knowledge, or when its string form is '.(__T__)'
    (presumably the special always-true atom); otherwise it is 0.0.

    Inputs
    ------
    bk : List[.logic.Atom]
        background knowledge
    facts : List[.logic.Atom]
        enumerated ground atoms
    device : torch.device, optional
        device on which the vector is created; None uses torch's default

    Returns
    -------
    v_0 : torch.Tensor((|facts|, ))
        initial valuation vector (float32)
    """
    # This is a free function, so there is no `self` to read a device from;
    # the target device is taken from the optional argument instead.
    values = [1.0 if (fact in bk or str(fact) == '.(__T__)') else 0.0
              for fact in facts]
    return torch.tensor(values, dtype=torch.float32, device=device)


def bmul(vec, mat, axis=0):
    """
    multiply `mat` by `vec` along the dimension `axis`

    Example: [1,2,3] * [[1,1],[1,1],[1,1]] -> [[1,1],[2,2],[3,3]]
    """
    # Move the target axis to the last position so that `vec` lines up with it.
    swapped = mat.transpose(axis, -1)
    scaled = swapped * vec.expand_as(swapped)
    # Restore the original axis order.
    return scaled.transpose(axis, -1)


def weight_sum(W, H):
    """
    weighted sum of H along its first dimension

    W is expanded over the two trailing dimensions of H, multiplied
    element-wise with H, and the product is summed over dim 0.

    Shapes as annotated by the original author:
    W: |C|, H: |C| * Batch * |G|, result: Batch * |G|
    """
    weights = W[..., None, None].expand_as(H)
    weighted = weights * H
    return weighted.sum(dim=0)


class InferModule():
    """
    differentiable inference module

    Parameters
    ----------
    X : torch.Tensor((|clauses|, |facts|, n_substitutions, max_body_len))
        index tensor; X[c, g, s, :] holds the positions in the valuation
        vector that are gathered and multiplied for clause c, atom g and
        substitution s (must be an integer tensor usable by torch.gather)
    infer_step : int
        number of steps in forward-chaining inference
    m : int
        size of the logic program (number of weight rows)
    device : torch.device
        gpu id or cpu; None falls back to torch's default device

    Attributes
    ----------
    Ws : torch.Tensor((m, |clauses|))
        unnormalised clause weights (requires_grad=True)
    """

    def __init__(self, X, infer_step, m=3, device=None):
        self.X = X
        self.m = m
        self.max_body_len = X.size(-1)
        self.infer_step = infer_step
        self.device = device
        # NOTE: this only rejects the literal string 'None'. Passing
        # device=None is accepted (kept for backward compatibility) and makes
        # torch allocate on its default device.
        assert self.device != 'None', 'No device in Infer module'
        self.Ws = self.init_weights()

    def init_weights(self):
        """
        initialize the clause weights uniformly at random in [0, 1)

        Returns
        -------
        Ws : torch.Tensor((m, |clauses|))
            trainable weights (requires_grad=True)
        """
        return torch.rand((self.m, self.X.size(0)), requires_grad=True, device=self.device)

    def __init_weights(self):
        """
        alternative, deterministic initialization (not used by __init__)

        Row i is a one-hot vector selecting the i-th clause, so this raises
        IndexError when m is larger than the number of clauses. The returned
        tensor does not require gradients.

        Returns
        -------
        weights : torch.Tensor((m, |clauses|))
        """
        weights = torch.zeros((self.m, self.X.size(0)), device=self.device)
        for i in range(self.m):
            weights[i, i] = 1.0
        return weights

    def get_params(self):
        """
        return the list of trainable tensors (the raw weights only)
        """
        return [self.Ws]

    def get_weights(self):
        """
        return the clause weights normalised per row

        A softmax with temperature beta=0.1 is applied over the clause
        dimension (dim=1) of Ws.
        """
        return self.softmax(self.Ws, beta=0.1, dim=1)

    def infer(self, V_0):
        """
        f_infer function
        compute V_0, V_1, ..., V_T and return V_T

        Inputs
        ------
        V_0 : torch.tensor((batch, |facts|))
            initial valuation vectors (F_c gathers along dim 1, so a batch
            dimension is required)

        Returns
        -------
        V_T : torch.tensor((batch, |facts|))
            valuation vectors of the result of the forward-chaining inference
        """
        valuation = V_0
        clause_weights = self.get_weights()
        n_clauses = self.X.size(0)

        for _ in range(self.infer_step):
            # C * B * G : one-step consequences of every clause
            # (named `clause_vals` so that the module alias `F` is not shadowed)
            clause_vals = torch.stack([self.F_c(ci, valuation)
                                       for ci in range(n_clauses)], 0)
            # m tensors of B * G : weighted mixture of clauses per weight row
            H_t = [weight_sum(W, clause_vals) for W in clause_weights]
            # B * G
            h_t = self.softor(H_t)
            # keep every value that was already derived (element-wise max)
            valuation, _ = torch.max(
                torch.stack([valuation, h_t], dim=0), dim=0)

        # softor is a smooth upper bound of max and may slightly exceed 1.0;
        # a large excess indicates a bug, a small one is normalised away.
        assert valuation.max() < 1.1, 'Wrong valuation'
        if (valuation > 1.0).any():
            valuation = valuation / torch.max(valuation)

        return valuation

    def F_c(self, ci, V_t):
        """
        c_i function
        forward-chaining inference using a clause

        Inputs
        ------
        ci : int
            index of the clause along the first dimension of X
        V_t : torch.tensor((batch, |facts|))
            current valuation vectors

        Returns
        -------
        v_{t+1} : torch.tensor((batch, |facts|))
            valuation vectors after 1-step forward-chaining inference
        """
        batch_size = V_t.size(0)
        # G * subs * bodylen
        X_c = self.X[ci, :, :, :]

        # each possible assignment for existentially quantified variables:
        # one product of body-atom values per substitution pattern
        body_prod_list = []
        for subs_i in range(X_c.size(1)):
            # The same indices are used for every batch element, so a
            # broadcast view is enough (no repeated copy of X_c is needed).
            gathered_tensor = torch.stack([
                torch.gather(
                    V_t, 1,
                    X_c[:, subs_i, i].unsqueeze(0).expand(batch_size, -1))
                for i in range(self.max_body_len)])
            body_prod_list.append(torch.prod(gathered_tensor, 0))
        # soft disjunction over the substitution patterns
        return self.softor(body_prod_list, dim=1)

    def softmax(self, x, beta=1.0, dim=0):
        """
        softmax function for torch vectors with temperature beta
        """
        return F.softmax(x / beta, dim=dim)

    def amalgamate(self, x, y):
        """
        amalgamate function for valuation vectors

        Implemented as the soft-or (smooth maximum) of x and y, not as the
        probabilistic sum x + y - x*y.
        """
        return self.softor([x, y])

    def gather(self, x, y):
        """
        gather function for torch tensors

        For every body position i, the values of x are gathered along dim 0
        at the indices y[:, i]; the results are concatenated along a new
        last dimension.
        """
        tensors = [torch.gather(x, 0, y[:, i]).unsqueeze(-1)
                   for i in range(self.max_body_len)]
        return torch.cat(tensors, -1).to(self.device)

    def softor(self, xs, dim=0, gamma=0.01):
        """
        softor function for valuation vectors

        Computes gamma * logsumexp(xs / gamma) along `dim`, a smooth
        approximation of the maximum.

        Inputs
        ------
        xs : List[torch.Tensor] or torch.Tensor
            a list is first stacked along `dim`; a tensor is reduced as is
            (same convention as the module-level softor)
        """
        if not torch.is_tensor(xs):
            xs = torch.stack(xs, dim)
        return gamma * torch.logsumexp(xs * (1 / gamma), dim=dim)

    def softor_tensor(self, X, dim=0, gamma=1e-5):
        """
        softor for an already stacked tensor, with a sharper default gamma
        """
        return gamma * torch.logsumexp(X * (1 / gamma), dim=dim)

    def prod_body(self, gathered):
        """
        taking product along the body atoms

        Inputs
        ------
        gathered : torch.tensor((|facts|, max_body_len))
            body-atom values as produced by `gather`
            (shape assumed from the indexing gathered[:, i] — TODO confirm)

        Returns
        -------
        result : torch.tensor((|facts|, ))
            product over the body positions; entry 0 is forced to 0.0
        """
        # The class never defines `valuation_memory`; use it only when a
        # caller has attached it, otherwise take the shape from the input.
        memory = getattr(self, 'valuation_memory', None)
        shape = gathered[:, 0].shape if memory is None else memory[0].shape
        result = torch.ones(shape, device=self.device)
        result[0] = 0.0  # False = 0.0
        for i in range(self.max_body_len):
            result = result * gathered[:, i]
        return result


def softmax(x, beta=1.0, dim=0):
    """
    temperature-scaled softmax for torch tensors

    The input is divided by `beta` before the softmax is taken along `dim`.
    """
    scaled = x / beta
    return F.softmax(scaled, dim=dim)


def softor(xs, dim=0, gamma=0.01):
    """
    soft-or (smooth maximum) for valuation vectors

    `xs` is either a list of tensors, which is stacked along `dim` first,
    or an already stacked tensor. The result is
    gamma * logsumexp(xs / gamma) taken along `dim`.
    """
    stacked = xs if torch.is_tensor(xs) else torch.stack(xs, dim)
    scaled = stacked * (1/gamma)
    return gamma * torch.logsumexp(scaled, dim=dim)


def amalgamate(x, y):
    """
    amalgamate function for valuation vectors

    Returns the element-wise soft-or (smooth maximum) of x and y, computed
    with the module-level `softor` and its default gamma. Note that this is
    not the probabilistic sum x + y - x*y.
    """
    # This is a free function, so there is no `self`; use the module-level
    # softor defined in this file.
    return softor([x, y])
