import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpecialSpmmFunction(torch.autograd.Function):
    """Sparse x dense matmul whose backward only touches the sparse pattern.

    ``forward(indices, values, shape, b)`` builds the COO matrix ``A`` of the
    given ``shape`` from ``indices`` (2 x nnz) and ``values`` (nnz,) and
    returns ``A @ b``.

    Gradients:
        * ``values``: only the stored positions of ``dL/dA = grad_output @ b.T``
          are evaluated, i.e. ``grad_values[k] = grad_output[i_k] . b[j_k]``.
        * ``b``: ``A.T @ grad_output``.
        * ``indices`` and ``shape``: none.
    """

    @staticmethod
    def forward(ctx, indices, values, shape, b):
        # indices are integer positions and must not be part of the autograd graph.
        assert not indices.requires_grad
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            # Gather the two factor rows for each stored entry instead of forming
            # the dense shape[0] x shape[1] gradient. This needs O(nnz * F) memory
            # and is correct for non-square shapes as well.
            # _indices() returns the indices without coalescing, so entry k
            # lines up with values[k].
            rows, cols = a._indices()
            grad_values = (grad_output[rows] * b[cols]).sum(dim=1)
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    """``nn.Module`` front-end for ``SpecialSpmmFunction``.

    ``forward(indices, values, shape, b)`` returns the dense product of the
    sparse COO matrix described by ``(indices, values, shape)`` with ``b``,
    using the custom autograd function for the backward pass.
    """

    def forward(self, indices, values, shape, b):
        product = SpecialSpmmFunction.apply(indices, values, shape, b)
        return product


class JointAttentionLayer(nn.Module):
    """Sparse attention layer mixing feature attention with structural attention.

    Operates on an edge list rather than a dense adjacency matrix; this is the
    variant meant for large graphs such as Pubmed. For each edge column ``k``
    (source ``edge[0, k]``, target ``edge[1, k]``) two weights are computed,
    and each is normalised over all edges that share the same source index:

    * feature attention ``exp(-LeakyReLU(a . [h_src || h_dst]))`` with
      ``h = input @ W``;
    * structural attention ``exp(ssc_coef)``, optionally restricted to
      strictly positive coefficients.

    The two are blended (learned softmax weights when ``autobalance`` is set,
    otherwise 0.5 / 0.5), renormalised, and used to aggregate ``h`` over the
    edges.

    Args:
        in_features: input feature size per node.
        out_features: output feature size per node.
        dropout: dropout probability applied to the blended edge weights.
        alpha: negative slope of the LeakyReLU.
        scale: multiplier for the ``cadinality * h`` term added to the aggregate.
        autobalance: learn the blend weights instead of using a fixed 0.5.
        roughconstrain: if True every coefficient contributes ``exp(coef)``;
            if False only strictly positive coefficients contribute.
        concat: apply ELU to the output (used when this is not the last layer).
    """

    def __init__(self, in_features, out_features, dropout, alpha, scale, autobalance, roughconstrain, concat=True):
        super().__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.autobalance = autobalance
        self.roughconstrain = roughconstrain
        self.scale = scale

        if self.autobalance:
            # Logits for the two attention sources; softmax-ed in forward().
            self.aw = nn.Parameter(torch.zeros(size=(1, 2)))
            nn.init.xavier_normal_(self.aw.data, gain=1.414)
        else:
            # Fixed share of the feature attention; the remainder goes to satt.
            self.fw = 0.5

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_normal_(self.W.data, gain=1.414)

        self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features)))
        nn.init.xavier_normal_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()

    def _normalize_by_source(self, edge, values, n):
        """Divide each edge value by the sum of the values sharing its source edge[0]."""
        # Build the ones-vector on the same device/dtype as the values
        # (a bare 'cuda' string would ignore the device index).
        ones = torch.ones(size=(n, 1), dtype=values.dtype, device=values.device)
        rowsum = self.special_spmm(edge, values, torch.Size([n, n]), ones)  # (n, 1)
        # A source row whose values are all zero would otherwise give 0/0 = NaN;
        # with the guard its edges simply keep weight 0.
        rowsum = rowsum.clamp_min(torch.finfo(values.dtype).tiny)
        return values / rowsum[edge[0, :], 0]

    def forward(self, input, edge, cadinality, ssc_coef):
        """Aggregate projected node features with the blended edge attention.

        Args:
            input: node features, shape (N, in_features).
            edge: integer index tensor of shape (2, E); column k is the edge
                (edge[0, k], edge[1, k]).
            cadinality: multiplier for the residual ``h`` term; it must broadcast
                against (N, out_features) — TODO confirm expected shape with callers.
            ssc_coef: one structural coefficient per edge, shape (E,), aligned
                with the columns of ``edge``.

        Returns:
            Tensor of shape (N, out_features); ELU-activated when ``concat`` is True.
        """
        n = input.size(0)

        h = torch.mm(input, self.W)  # (N, out_features)
        assert not torch.isnan(h).any()

        # Feature attention, shared across edges: a . [h_src || h_dst].
        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()  # (2*out, E)
        # squeeze(0), not squeeze(): keeps shape (E,) even for a single edge.
        edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze(0)))
        assert not torch.isnan(edge_e).any()
        edge_e = self._normalize_by_source(edge, edge_e, n)
        assert not torch.isnan(edge_e).any()

        if self.roughconstrain:
            # Every coefficient on an existing edge contributes.
            satt = torch.exp(ssc_coef)
        else:
            # Only strictly positive coefficients contribute; the others get weight 0.
            satt = torch.where(ssc_coef > 0, torch.exp(ssc_coef), torch.zeros_like(ssc_coef))
        satt = self._normalize_by_source(edge, satt, n)
        assert not torch.isnan(satt).any()

        if self.autobalance:
            weight = F.softmax(self.aw, dim=1)
            edge_e = weight[0, 0] * edge_e + weight[0, 1] * satt
        else:
            edge_e = self.fw * edge_e + (1 - self.fw) * satt

        # Renormalise: needed when a row received no structural weight.
        edge_e = self._normalize_by_source(edge, edge_e, n)
        edge_e = F.dropout(edge_e, self.dropout, training=self.training)

        h_prime = self.special_spmm(edge, edge_e, torch.Size([n, n]), h)
        h_prime = h_prime + self.scale * cadinality * h

        # ELU for hidden layers; raw output for the last layer.
        return F.elu(h_prime) if self.concat else h_prime

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"


class StructuralSubspaceLayer(nn.Module):
    """Learnable ``batch x batch`` coefficient matrix with a zeroed diagonal.

    ``forward`` multiplies a sparse COO matrix by that zero-diagonal
    coefficient matrix. (The name suggests sparse-subspace-clustering style
    self-representation — NOTE(review): confirm with the training code.)

    Args:
        batch: number of samples, i.e. the side length of ``W``.
    """

    def __init__(self, batch):
        super().__init__()
        self.batch = batch

        initial = torch.zeros(size=(batch, batch))
        self.W = nn.Parameter(initial)
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.special_spmm = SpecialSpmm()

    def forward(self, edget, edget_v):
        """Return ``(selfrep, ssc_coef)``.

        Args:
            edget: COO indices of the sparse ``batch x batch`` operand, shape (2, nnz).
            edget_v: values of the sparse operand, one per column of ``edget``.

        Returns:
            ``selfrep``: sparse operand @ ``ssc_coef``, shape (batch, F) where F
            is the column count of ``ssc_coef`` (here ``batch``).
            ``ssc_coef``: ``W`` with its diagonal set to zero.
        """
        # Subtract the diagonal so each W[i, i] becomes exactly zero.
        diagonal_part = torch.diag(self.W.diagonal())
        ssc_coef = self.W - diagonal_part
        side = self.batch
        selfrep = self.special_spmm(edget, edget_v, torch.Size([side, side]), ssc_coef)
        return selfrep, ssc_coef
