import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class JointAttentionLayer(nn.Module):
    """
    Dense version of joint attention layer.

    Mixes two row-normalised attention matrices over the nodes:

    * a GAT-style feature attention computed from the projected features, and
    * a structural attention obtained by a masked softmax over ``ssc_coef``.

    Both are restricted to entries where ``adj > 0``. The mixture is
    L2-normalised per row, ``scale * diag(cadinality)`` is added, and the
    result aggregates the projected features.

    Args:
        in_features: size of each input feature vector.
        out_features: size of each projected / output feature vector.
        dropout: dropout probability applied to the final attention matrix.
        alpha: negative slope of the LeakyReLU applied to attention logits.
        scale: weight of the ``diag(cadinality)`` term added to the attention.
        autobalance: if True, the two attentions are mixed with a learned
            softmax over ``self.aw``; otherwise with fixed weights 0.75 / 0.25.
        roughconstrain: if True, the structural attention uses every entry with
            ``adj > 0``; otherwise only entries where ``adj * ssc_coef > 0``.
        concat: if True, ELU is applied to the output; otherwise it is
            returned unchanged.
    """

    def __init__(self, in_features, out_features, dropout, alpha, scale, autobalance, roughconstrain, concat=True):
        super(JointAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.autobalance = autobalance
        self.roughconstrain = roughconstrain
        self.scale = scale

        if self.autobalance:
            # learned logits for the (feature, structural) mixing weights
            self.aw = nn.Parameter(torch.zeros(size=(1, 2)))
            nn.init.xavier_uniform_(self.aw.data, gain=1.414)
        else:
            # fixed weight of the feature attention; structural gets 1 - fw
            self.fw = 0.75

        # kernel-size-1 Conv1d layers over a singleton length axis act as
        # bias-free linear maps on each node's feature vector
        self.conv = torch.nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=1, bias=False)
        self.conv1 = torch.nn.Conv1d(in_channels=out_features, out_channels=1, kernel_size=1, bias=False)
        self.conv2 = torch.nn.Conv1d(in_channels=out_features, out_channels=1, kernel_size=1, bias=False)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj, cadinality, ssc_coef):
        """
        Args:
            input: node features; assumed shape (N, in_features).
            adj: adjacency matrix, assumed (N, N); entries > 0 are edges.
            cadinality: per-node values placed on the diagonal via
                ``torch.diag``; expected to be a length-N vector.
            ssc_coef: structural coefficients, assumed (N, N).

        Returns:
            (N, out_features) aggregated features, ELU-activated if ``concat``.
        """
        h = torch.squeeze(self.conv(torch.unsqueeze(input, dim=2)), dim=2)
        N = h.size()[0]

        # e[i, j] = LeakyReLU(a1(h_i) + a2(h_j))
        a_input1 = torch.squeeze(self.conv1(torch.unsqueeze(h, dim=2)), dim=2)
        a_input2 = torch.squeeze(self.conv2(torch.unsqueeze(h, dim=2)), dim=2)
        a_input = a_input1.repeat(1, N)
        a_input = a_input + (a_input2.repeat(1, N)).t()
        e = self.leakyrelu(a_input)

        # large negative fill so masked entries vanish after softmax
        zero_vec = -9e15 * torch.ones_like(e)

        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)

        if self.roughconstrain:
            # all the coefficients where adj > 0 are considered
            smask = adj > 0
        else:
            # only positive coefficients of connected vertices are considered
            smask = adj * ssc_coef > 0

        satt = torch.where(smask, ssc_coef, zero_vec)
        satt = F.softmax(satt, dim=1)
        # A row with no admissible entry is all -9e15, and softmax would turn it
        # into a uniform 1/N distribution over *every* node, including
        # non-neighbours. Zero such rows instead; rows with at least one
        # admissible entry are unaffected because their masked entries are
        # already exactly 0 after the softmax.
        satt = torch.where(smask, satt, torch.zeros_like(satt))

        if self.autobalance:
            weight = F.softmax(self.aw, dim=1)
            attention = weight[0, 0] * attention + weight[0, 1] * satt
        else:
            attention = self.fw * attention + (1 - self.fw) * satt

        # F.normalize defaults to p=2: each row is rescaled to unit L2 norm
        attention = F.normalize(attention, dim=1)

        # diagonal perturbation the author added to preserve the model's
        # WL-test (Weisfeiler-Lehman) discriminative ability
        attention = attention + self.scale * torch.diag(cadinality, diagonal=0)

        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        # NOTE(review): releasing the CUDA allocator cache on every forward adds
        # overhead; kept to preserve existing behaviour.
        torch.cuda.empty_cache()
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class StructuralSubspaceLayer(nn.Module):
    """Learnable coefficient layer that reconstructs each sample from the others.

    Holds a square ``batch x batch`` parameter ``W``. On every forward pass the
    diagonal of ``W`` is removed, and the layer returns the product of the
    resulting coefficient matrix with its input, along with the coefficients.

    Args:
        batch: the number of samples in each mini batch (side length of ``W``).
    """

    def __init__(self, batch):
        super().__init__()
        self.batch = batch

        initial = torch.zeros(size=(batch, batch))
        self.W = nn.Parameter(initial)
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

    def forward(self, x):
        """Return ``(coef @ x, coef)`` where ``coef`` is ``W`` with a zeroed diagonal."""
        diagonal_part = torch.diag_embed(torch.diagonal(self.W))
        coef = self.W - diagonal_part
        return torch.mm(coef, x), coef


class SpecialSpmmFunction(torch.autograd.Function):
    """Sparse @ dense matmul whose backward only yields gradients for the stored sparse values.

    ``forward(indices, values, shape, b)`` builds
    ``a = torch.sparse_coo_tensor(indices, values, shape)`` and returns ``a @ b``.
    The gradient w.r.t. ``values`` is read from the dense gradient of ``a`` at
    the ``(row, col)`` positions listed in ``indices``; ``indices`` and ``shape``
    get no gradient.
    """
    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert indices.requires_grad == False
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        # Row stride of the dense (shape[0] x shape[1]) gradient of ``a``.
        # Using shape[0] here is only correct for square matrices.
        ctx.n_cols = shape[1]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            # dL/da = dL/d(a @ b) @ b^T, sampled only at the stored positions
            grad_a_dense = grad_output.matmul(b.t())
            indices = a._indices()
            edge_idx = indices[0, :] * ctx.n_cols + indices[1, :]
            grad_values = grad_a_dense.reshape(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    """Module wrapper around ``SpecialSpmmFunction``.

    Returns ``sparse_coo_tensor(indices, values, shape) @ b``, with gradients
    flowing only to ``values`` and ``b``.
    """

    def forward(self, indices, values, shape, b):
        spmm = SpecialSpmmFunction.apply
        product = spmm(indices, values, shape, b)
        return product


class SpJointAttentionLayer(nn.Module):
    """
    Sparse version of joint attention layer.

    Works on an edge list (entries with ``adj > 0``) instead of dense N x N
    matrices.

    * A GAT-style feature attention and a structural attention derived from
      ``ssc_coef`` are each normalised per row over the edges.
    * The two are mixed and renormalised by row sum.
    * The result aggregates the projected features via a sparse matmul.

    Args:
        in_features: size of each input feature vector.
        out_features: size of each projected / output feature vector.
        dropout: dropout probability applied to the edge attention values and
            to the cardinality term.
        alpha: negative slope of the LeakyReLU applied to attention logits.
        scale: weight of the ``cadinality * h`` term added to the output.
        autobalance: if True, the two attentions are mixed with a learned
            softmax over ``self.aw``; otherwise with fixed weights 0.75 / 0.25.
        roughconstrain: if True, the structural attention uses every edge;
            otherwise only edges whose coefficient is positive.
        concat: if True, ELU is applied to the output (non-last layer);
            otherwise it is returned unchanged (last layer).
    """

    def __init__(self, in_features, out_features, dropout, alpha, scale, autobalance, roughconstrain, concat=True):
        super(SpJointAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.autobalance = autobalance
        self.roughconstrain = roughconstrain
        self.scale = scale

        if self.autobalance:
            # learned logits for the (feature, structural) mixing weights
            self.aw = nn.Parameter(torch.zeros(size=(1, 2)))
            nn.init.xavier_normal_(self.aw.data, gain=1.414)
        else:
            # fixed weight of the feature attention; structural gets 1 - fw
            self.fw = 0.75

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_normal_(self.W.data, gain=1.414)

        self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features)))
        nn.init.xavier_normal_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()

    def _row_sums(self, edge, values, N):
        """Return the (N, 1) row sums of the N x N sparse matrix holding ``values`` at ``edge``."""
        # Match the values' device and dtype. A hard-coded 'cuda' string would
        # land on the default GPU, and the default float32 would clash with
        # float64 models.
        ones = torch.ones(size=(N, 1), device=values.device, dtype=values.dtype)
        return self.special_spmm(edge, values, torch.Size([N, N]), ones)

    def forward(self, input, adj, cadinality, ssc_coef):
        """
        Args:
            input: node features; ``torch.mm(input, self.W)`` requires shape
                (N, in_features).
            adj: adjacency matrix, assumed (N, N); entries > 0 are edges.
            cadinality: per-node values, assumed shape (N,); unsqueezed to
                (N, 1) and used to scale each node's own projected features.
            ssc_coef: structural coefficients, assumed (N, N), read at the
                edge positions.

        Returns:
            (N, out_features) aggregated features, ELU-activated if ``concat``.
        """
        N = input.size()[0]
        # Only strictly positive entries are edges. This matches the dense
        # layer and guarantees that ``edge`` and the coefficients selected with
        # ``adj > 0`` below enumerate the same entries in the same (row-major)
        # order. ``adj.nonzero()`` would also pick up negative entries and make
        # the two lists differ in length.
        edge = (adj > 0).nonzero().t()

        h = torch.mm(input, self.W)
        # h: N x out
        assert not torch.isnan(h).any()

        # Feature-attention on the nodes - Shared attention mechanism
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
        # edge_h: 2*out x E

        # squeeze(0) (not squeeze()) keeps a 1-D (E,) tensor even when E == 1
        edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze(0)))
        assert not torch.isnan(edge_e).any()
        # edge_e: E

        e_rowsum = self._row_sums(edge, edge_e, N)
        # e_rowsum: N x 1

        edge_e = edge_e.unsqueeze(dim=1)
        edge_e = edge_e.div(e_rowsum[edge[0, :], :])
        assert not torch.isnan(edge_e).any()

        satt = torch.masked_select(ssc_coef, adj > 0).unsqueeze(dim=1)
        if not self.roughconstrain:
            # only positive coefficients of connected vertices are considered:
            # the rest map to exp(-9e15) == 0 below
            zero_vec = -9e15 * torch.ones_like(satt)
            satt = torch.where(satt > 0, satt, zero_vec)
        satt = torch.exp(satt)

        satt_rowsum = self._row_sums(edge, satt.squeeze(1), N)
        satt = satt.div(satt_rowsum[edge[0, :], :])
        # rows whose coefficients were all masked have sum 0 -> 0/0 = NaN;
        # such rows contribute no structural attention
        satt[torch.isnan(satt)] = 0

        if self.autobalance:
            weight = F.softmax(self.aw, dim=1)
            edge_e = weight[0, 0] * edge_e + weight[0, 1] * satt
        else:
            edge_e = self.fw * edge_e + (1 - self.fw) * satt

        e_rowsum = self._row_sums(edge, edge_e.squeeze(1), N)

        edge_e = edge_e.div(e_rowsum[edge[0, :], :])
        edge_e[torch.isnan(edge_e)] = 0

        edge_e = F.dropout(edge_e, self.dropout, training=self.training)
        # edge_e: E x 1

        h_prime = self.special_spmm(edge, edge_e.squeeze(1), torch.Size([N, N]), h)

        # small perturbation the author added to preserve WL-test ability
        h_prime = h_prime + self.scale * F.dropout(cadinality.unsqueeze(dim=1), self.dropout,
                                                   training=self.training) * h

        if self.concat:
            # if this layer is not last layer,
            return F.elu(h_prime)
        else:
            # if this layer is last layer,
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class SpmmSSFunction(torch.autograd.Function):
    """Dense @ sparse matmul: returns ``ssc_coef @ sparse_coo_tensor(indices, values, shape)``.

    The product is computed as ``(sparse^T @ ssc_coef^T)^T`` so that only
    sparse @ dense matmuls are needed. The backward returns gradients for
    ``values`` (sampled at the stored positions) and ``ssc_coef``; ``indices``
    and ``shape`` get none.
    """
    @staticmethod
    def forward(ctx, indices, values, shape, ssc_coef):
        assert indices.requires_grad == False
        input = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(input, ssc_coef)
        # Row stride of the dense (shape[0] x shape[1]) gradient of ``input``.
        # Using shape[0] here is only correct for square matrices.
        ctx.n_cols = shape[1]
        return (torch.matmul(input.t(), ssc_coef.t())).t()

    @staticmethod
    def backward(ctx, grad_output):
        input, ssc_coef = ctx.saved_tensors
        # Initialise both: autograd skips inputs that need no gradient, and the
        # return statement must still be able to reference both names.
        grad_values = grad_ssc_coef = None
        if ctx.needs_input_grad[1]:
            # dL/d input = ssc_coef^T @ dL/d out, sampled at the stored positions
            grad_a_dense = ssc_coef.t().matmul(grad_output)
            indices = input._indices()
            edge_idx = indices[0, :] * ctx.n_cols + indices[1, :]
            grad_values = grad_a_dense.reshape(-1)[edge_idx]

        if ctx.needs_input_grad[3]:
            # dL/d ssc_coef = dL/d out @ input^T, computed as (input @ dL/d out^T)^T
            # to keep the sparse operand on the left, as in forward
            grad_ssc_coef = input.matmul(grad_output.t()).t()
        return None, grad_values, None, grad_ssc_coef


class SpmmSS(nn.Module):
    """Module wrapper around ``SpmmSSFunction``.

    Returns ``ssc_coef @ sparse_coo_tensor(indices, values, shape)``.
    """

    def forward(self, indices, values, shape, ssc_coef):
        apply_fn = SpmmSSFunction.apply
        result = apply_fn(indices, values, shape, ssc_coef)
        return result


class SpStructuralSubspaceLayer(nn.Module):
    """Sparse-input counterpart of the dense structural subspace layer.

    Holds a ``batch x batch`` parameter ``W``. The forward pass zeroes its
    diagonal and returns ``(ssc_coef @ x, ssc_coef)``, with ``x`` converted to
    a sparse COO operand built from its nonzero entries.

    Args:
        batch: the number of samples in each mini batch (side length of ``W``).
    """

    def __init__(self, batch):
        super(SpStructuralSubspaceLayer, self).__init__()
        self.batch = batch

        self.W = nn.Parameter(torch.zeros(size=(self.batch, self.batch)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.spssmm = SpmmSS()

    def forward(self, x):
        """
        Args:
            x: 2-D input. The sparse operand is declared ``N x N`` with
                ``N = batch``, so x is presumably (batch, batch); TODO confirm
                with callers.

        Returns:
            (selfrep, ssc_coef): ``ssc_coef @ x`` and ``W`` with a zeroed diagonal.
        """
        ssc_coef = self.W - torch.diag(torch.diag(self.W), diagonal=0)
        N = ssc_coef.size()[0]
        edge = x.nonzero().t()
        # Read the values at exactly the positions listed in ``edge``.
        # ``masked_select(x, x > 0)`` skips negative entries, which made the
        # values disagree in length with ``edge``. A bare ``.squeeze()`` also
        # collapsed a single value into a 0-d tensor.
        edge_v = x[edge[0], edge[1]]
        selfrep = self.spssmm(edge, edge_v, torch.Size([N, N]), ssc_coef)
        return selfrep, ssc_coef


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903

    Projects node features with a bias-free linear map and scores every pair
    ``(i, j)`` as ``LeakyReLU(s1(h_i) + s2(h_j))``. It then takes a softmax
    over the entries with ``adj > 0`` in each row and aggregates the projected
    features with those weights.

    Args:
        in_features: size of each input feature vector.
        out_features: size of each projected / output feature vector.
        dropout: dropout probability applied to the attention matrix.
        alpha: negative slope of the LeakyReLU on attention logits.
        concat: if True, ELU is applied to the output.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super().__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # kernel-size-1 Conv1d over a singleton length axis == per-node linear map
        self.conv = torch.nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=1, bias=False)
        self.conv1 = torch.nn.Conv1d(in_channels=out_features, out_channels=1, kernel_size=1, bias=False)
        self.conv2 = torch.nn.Conv1d(in_channels=out_features, out_channels=1, kernel_size=1, bias=False)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        """Aggregate ``input`` over the neighbourhoods given by ``adj > 0``."""
        h = self.conv(input.unsqueeze(2)).squeeze(2)

        h_col = h.unsqueeze(2)
        src_score = self.conv1(h_col).squeeze(2)
        dst_score = self.conv2(h_col).squeeze(2)
        # broadcasting: logits[i, j] = src_score[i] + dst_score[j]
        logits = self.leakyrelu(src_score + dst_score.t())

        # large negative fill so non-edges vanish after softmax
        masked = torch.where(adj > 0, logits, torch.full_like(logits, -9e15))
        weights = F.softmax(masked, dim=1)
        weights = F.dropout(weights, self.dropout, training=self.training)
        aggregated = torch.matmul(weights, h)

        torch.cuda.empty_cache()

        return F.elu(aggregated) if self.concat else aggregated

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"
