
import torch.nn.functional as F
import torch
import torch.nn as nn
import random
import numpy as np
# random.seed(a=1)


def logsumexp(inputs, dim=None, keepdim=False):
    """Numerically stable logsumexp.

    Adapted from
    https://github.com/pytorch/pytorch/issues/2591#issuecomment-364474328

    Args:
        inputs: A tensor of any shape.
        dim: Dimension to reduce. If ``None``, ``inputs`` is flattened and
            reduced over all of its elements.
        keepdim: Whether to keep the reduced dimension (with size 1).

    Returns:
        Equivalent of ``log(sum(exp(inputs), dim=dim, keepdim=keepdim))``.
        A slice whose elements are all ``-inf`` yields ``-inf`` and a slice
        containing ``+inf`` yields ``+inf`` (rather than ``nan``).
    """
    # For a 1-D array x (any array along a single dimension),
    # log sum exp(x) = s + log sum exp(x - s)
    # with s = max(x), so the largest exponent is exp(0) and nothing overflows.
    if dim is None:
        # reshape (not view) so that non-contiguous tensors are accepted too.
        inputs = inputs.reshape(-1)
        dim = 0
    s, _ = torch.max(inputs, dim=dim, keepdim=True)
    # If the max is +/-inf, `inputs - s` would compute inf - inf = nan.
    # Shifting by 0 in that case gives the exact answer (-inf or +inf).
    s = torch.where(torch.isinf(s), torch.zeros_like(s), s)
    outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
    if not keepdim:
        outputs = outputs.squeeze(dim)
    return outputs


def _softor(xs, dim=0, gamma=0.01):
    """
    Soft (differentiable) OR over valuation vectors.

    ``xs`` is either a tensor or a list of tensors; a list is stacked along
    ``dim`` first. The result is ``gamma * logsumexp(xs / gamma)`` along
    ``dim``. If any value of that result exceeds 1.0, the whole result is
    divided by its maximum, taken over all elements.
    """
    stacked = xs if torch.is_tensor(xs) else torch.stack(xs, dim)
    smooth_max = gamma*logsumexp(stacked * (1/gamma), dim=dim)
    peak = smooth_max.max()
    return smooth_max / peak if peak > 1.0 else smooth_max


def __softor(xs, dim=0, gamma=0.01):
    """
    Unnormalised soft OR over valuation vectors.

    ``xs`` is either a tensor or a list of tensors; a list is stacked along
    ``dim`` first. Returns ``gamma * logsumexp(xs / gamma)`` along ``dim``
    with no rescaling of the result.
    """
    stacked = torch.stack(xs, dim) if not torch.is_tensor(xs) else xs
    scaled = stacked * (1/gamma)
    return gamma*logsumexp(scaled, dim=dim)


def weight_sum(W_l, H):
    """
    Weighted sum of stacked valuations along the first axis.

    Args:
        W_l: weight vector with one entry per slice of ``H`` (shape C).
        H: stacked valuations of shape C * B * G.

    Returns:
        B * G tensor equal to ``sum_c W_l[c] * H[c]``.
    """
    # C -> C * 1 * 1 -> C * B * G
    weights = W_l[..., None, None].expand_as(H)
    return (weights * H).sum(dim=0)


class FCReasoner(nn.Module):
    """
    Differentiable forward-chaining inference.

    Each inference step applies every clause function to the current
    valuation, mixes the C clause outputs into m weighted combinations
    (softmax over the rows of ``W``), soft-ORs those m combinations, and
    finally soft-ORs the result with the current valuation.
    """

    def __init__(self, X, m, infer_step, gamma=0.01, device=None):
        """
        Args:
            X: clause index tensor. ``X.size(0)`` is the number of clauses C
                and ``X.size(1)`` the number of atoms G; it is passed
                unchanged to every ``ClauseFunction``.
            m: number of weighted clause combinations (rows of ``W``).
            infer_step: number of inference steps run by ``forward``.
            gamma: smoothing parameter of the soft-or.
            device: device the clause weights ``W`` are moved to
                (``None`` leaves them where they were created).
        """
        super(FCReasoner, self).__init__()
        self.X = X
        self.infer_step = infer_step
        self.m = m
        self.C = self.X.size(0)
        self.G = self.X.size(1)
        self.gamma = gamma
        self.device = device
        # Fixed (non-learned) clause weights of shape m * C.
        self.W = self.init_identity_weights(m, device)
        # To learn the clause weights instead, initialize W as follows:
        # self.W = nn.Parameter(torch.Tensor(
        #    np.random.normal(size=(m, X.size(0)))).to(device))
        # One clause function per clause; ModuleList registers them as
        # submodules so .train()/.eval()/module traversal reach them.
        self.cs = nn.ModuleList(
            ClauseFunction(i, X, gamma=gamma) for i in range(self.C))

    def init_identity_weights(self, m, device):
        """
        Build an m * C weight matrix with 100 on the diagonal, 0 elsewhere.

        After the row-wise softmax in ``r`` the large diagonal makes row l
        put almost all of its weight on clause l. When m == C this is
        100 times the identity matrix; rows l >= C (if m > C) are all zeros
        and therefore become uniform after the softmax.
        """
        W = torch.eye(m, self.C, dtype=torch.float32) * 100
        return W.to(device) if device is not None else W

    def forward(self, x):
        """
        Run ``infer_step`` steps of differentiable forward chaining.

        Args:
            x: initial valuation, B * G.

        Returns:
            Valuation after ``infer_step`` steps, B * G.
        """
        R = x
        for _ in range(self.infer_step):
            # Keep the current valuation and add what one step derives.
            R = self.softor([R, self.r(R)], dim=1)
        return R

    def softor(self, xs, dim=0, gamma=None):
        """
        Soft OR over valuation vectors.

        Args:
            xs: a tensor, or a list of tensors stacked along ``dim``.
            dim: dimension reduced by the soft-or.
            gamma: smoothing parameter; ``None`` (default) uses
                ``self.gamma``.

        Returns:
            ``gamma * logsumexp(xs / gamma)`` along ``dim``. If any value
            exceeds 1.0, the whole result is divided by its maximum over
            ALL elements (so across the batch as well).
        """
        if gamma is None:
            gamma = self.gamma
        return _softor(xs, dim=dim, gamma=gamma)

    def r(self, x):
        """
        One inference step: apply all clauses and combine them.

        Args:
            x: valuation, B * G.

        Returns:
            Derived valuation, B * G.
        """
        # C * B * G: output of every clause function
        C = torch.stack([clause(x) for clause in self.cs], 0)
        # m * C: each row is a distribution over the C clauses
        W_star = torch.softmax(self.W, 1)
        # (m * C * 1 * 1) * (1 * C * B * G) -> m * C * B * G, summed -> m * B * G
        H = torch.sum(W_star[:, :, None, None] * C.unsqueeze(0), dim=1)
        # soft-or over the m combinations -> B * G
        return self.softor(H, dim=0)

    def __r(self, x):
        """
        Loop-based variant of ``r`` (mathematically equivalent).
        """
        # C * B * G
        C = torch.stack([clause(x) for clause in self.cs], 0)
        # m * B * G: one weighted sum of clause outputs per row of W
        H = torch.stack([weight_sum(W_l, C)
                         for W_l in torch.softmax(self.W, 1)], 0)
        # soft-or over the m combinations -> B * G
        return self.softor(H, dim=0)


class ClauseFunction(nn.Module):
    """
    Differentiable evaluation of a single clause.

    For clause ``i``, ``I[i]`` (G * S * L) holds, for every output atom g and
    each of the S substitutions, the indices of L body atoms. A substitution
    is scored by the product of its body-atom valuations, and the S scores
    are combined with a soft-or.
    """

    def __init__(self, i, I, gamma=0.01):
        """
        Args:
            i: index of this clause along the first axis of ``I``.
            I: index tensor of shape C * G * S * L (S is the max number of
                possible substitutions, L the number of body atoms).
            gamma: smoothing parameter of the soft-or.
        """
        super(ClauseFunction, self).__init__()
        self.i = i  # clause index
        self.I = I  # index tensor C * G * S * L
        self.L = I.size(-1)  # number of body atoms
        self.S = I.size(-2)  # max number of possible substitutions
        self.gamma = gamma

    def softor(self, xs, dim=0, gamma=None):
        """
        Soft OR over valuation vectors.

        Args:
            xs: a tensor, or a list of tensors stacked along ``dim``.
            dim: dimension reduced by the soft-or.
            gamma: smoothing parameter; ``None`` (default) uses
                ``self.gamma``.

        Returns:
            ``gamma * logsumexp(xs / gamma)`` along ``dim``. If any value
            exceeds 1.0, the whole result is divided by its maximum over
            ALL elements (so across the batch as well).
        """
        if gamma is None:
            gamma = self.gamma
        return _softor(xs, dim=dim, gamma=gamma)

    def forward(self, x):
        """
        Evaluate the clause on a batch of valuations.

        Args:
            x: valuation tensor, B * G.

        Returns:
            B * G tensor whose entry (b, g) is the soft-or over substitutions
            s of ``prod_l x[b, I[i, g, s, l]]``.
        """
        # G * S * L body-atom indices of this clause
        I_i = self.I[self.i]
        # B * G * S * L: out[b, g, s, l] = x[b, I_i[g, s, l]]. Advanced
        # indexing gives the same values as repeat + gather along dim 1
        # without materialising B copies of the index tensor and S * L
        # copies of x.
        body = x[:, I_i]
        # B * G * S: conjunction over the body atoms
        conj = torch.prod(body, 3)
        # B * G: soft-or over the substitutions
        return self.softor(conj, dim=2)
