import numpy as np
import torch
import torch.nn as nn
import dgl.function as fn


# Small constant added to the per-feature standard deviation in
# Model.embedding_reg so the standardisation never divides by zero.
sigma = 1e-6

class MLP(nn.Module):
    """Stack of linear layers with optional batch norm, ReLU and dropout.

    The first layer maps ``in_dim -> hid_dim``; each of the remaining
    ``nlayer - 1`` layers maps ``hid_dim -> hid_dim``.  ReLU and dropout are
    applied after every layer except the one at index ``nlayer - 1``.

    Args:
        in_dim: Size of each input feature vector.
        hid_dim: Output size of every linear layer.
        nlayer: Number of layers (also the number of batch-norm modules).
        dropout: Dropout probability used between layers.
        use_bn: When true, apply the matching ``BatchNorm1d`` after each
            linear layer.  The batch-norm modules are created regardless of
            this flag, so the ``state_dict`` layout does not depend on it.
    """

    def __init__(self, in_dim, hid_dim, nlayer, dropout=0.5, use_bn=False):
        super().__init__()

        self.nlayer = nlayer

        # Input width of each linear layer, in creation order.
        input_widths = [in_dim] + [hid_dim] * (nlayer - 1)
        self.layers = nn.ModuleList(
            [nn.Linear(width, hid_dim, bias=True) for width in input_widths]
        )
        self.bns = nn.ModuleList(
            [nn.BatchNorm1d(hid_dim) for _ in range(nlayer)]
        )

        self.use_bn = use_bn
        self.act_fn = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the final activations."""
        final_idx = self.nlayer - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if self.use_bn:
                x = self.bns[idx](x)
            if idx == final_idx:
                # The output of the last layer is returned without
                # non-linearity or dropout.
                continue
            x = self.dropout(self.act_fn(x))

        return x


class GRANDConv(nn.Module):
    """Parameter-free multi-hop mean propagation over a DGL graph.

    Starting from the input features, the layer repeatedly replaces every
    node's features with the mean of the features sent by its neighbours
    (``copy_u`` message, ``mean`` reduce) and returns the average of the
    collected hop representations.

    Args:
        order: Number of propagation steps.
        init_feat: When true, the un-propagated input features are included
            as an additional term of the average.
    """

    def __init__(self, order, init_feat=False):
        super().__init__()

        self.order = order
        self.init_feat = init_feat

    def forward(self, graph, feats):
        """Return the mean of the propagated (and optionally raw) features.

        Args:
            graph: DGL graph; self-loops are removed before propagation.
            feats: Node feature tensor assigned to ``graph.ndata``.

        Returns:
            Tensor shaped like ``feats``: the mean over ``order`` hop terms,
            plus the input itself when ``init_feat`` is true.  If there are
            no terms at all (``order == 0`` and ``init_feat`` false) a zero
            tensor is returned rather than the result of ``0 / 0``.
        """
        graph = graph.remove_self_loop()

        # zeros_like (rather than ``feats * 0``) keeps NaN/Inf entries of the
        # input from leaking into the accumulator when the raw features are
        # not supposed to be part of the average.
        total = feats if self.init_feat else torch.zeros_like(feats)
        n_terms = 1 if self.init_feat else 0

        with graph.local_scope():
            hop = feats
            for _ in range(self.order):
                graph.ndata['h'] = hop
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
                hop = graph.ndata.pop('h')

                # Out-of-place add: never mutates ``feats`` or a tensor that
                # autograd may still need.
                total = total + hop
                n_terms += 1

        if n_terms == 0:
            return total

        # Divide by the number of terms actually summed, so that including
        # the input features still yields a true mean.
        return total / n_terms


class Model(nn.Module):
    """MLP encoder with a linear prediction head and an embedding regulariser.

    The encoder output is used for prediction; a graph-propagated copy of
    the same embedding is compared against it by ``embedding_reg``.

    Args:
        in_dim: Input feature size.
        hid_dim: Embedding size produced by the encoder.
        out_dim: Output size of the prediction head.
        num_layers: Number of encoder layers.
        beta: Stored as ``self.beta``; not used inside this class.
        dropout: Dropout probability passed to the encoder.
        K: Number of propagation steps of the ``GRANDConv`` layer.
        use_bn: Whether the encoder applies batch normalisation.
    """

    def __init__(self, in_dim, hid_dim, out_dim, num_layers, beta, dropout, K=1, use_bn=False):
        super().__init__()

        self.encoder = MLP(in_dim, hid_dim, num_layers, dropout, use_bn=use_bn)
        self.mixhop = GRANDConv(K)

        self.pred = nn.Linear(hid_dim, out_dim, bias=True)
        self.beta = beta

        # Kept for interface compatibility; not used by the methods below.
        self.act_fn = nn.ReLU()

    def get_embedding(self, graph, feat):
        """Return the encoder output for ``feat`` without gradient tracking.

        ``graph`` is accepted for interface symmetry with ``forward`` but is
        not used.
        """
        # no_grad avoids building an autograd graph that would be discarded.
        with torch.no_grad():
            return self.encoder(feat)

    def embedding_reg(self, h1, h2):
        """Compute invariance and decorrelation losses between two embeddings.

        Both inputs are standardised per feature (column); ``c`` is their
        cross-correlation matrix divided by the number of rows.

        Args:
            h1: 2-D tensor of embeddings (rows are samples).
            h2: Tensor of the same shape as ``h1``.

        Returns:
            ``(loss_inv, loss_dec)``: the negated trace of ``c`` and the
            squared Frobenius distance between ``c`` and the identity.
        """
        z1 = (h1 - h1.mean(0)) / (h1.std(0) + sigma)
        z2 = (h2 - h2.mean(0)) / (h2.std(0) + sigma)

        c = torch.mm(z1.T, z2)
        n_rows, n_feats = h1.shape
        c = c / n_rows

        loss_inv = -torch.diagonal(c).sum()
        # Build the identity directly on c's device and in c's dtype: a NumPy
        # identity is float64, which silently promotes the loss to double and
        # forces a host-to-device copy on every call.
        iden = torch.eye(n_feats, dtype=c.dtype, device=c.device)
        loss_dec = (iden - c).pow(2).sum()

        return loss_inv, loss_dec

    def forward(self, graph, feat):
        """Encode ``feat``, predict from it and compute the regulariser.

        Returns:
            ``((loss_inv, loss_dec), pred)`` where ``pred`` is the output of
            the prediction head on the un-propagated embedding.
        """
        # encoding
        h1 = self.encoder(feat)
        h2 = self.mixhop(graph, h1)

        pred = self.pred(h1)
        loss_inv, loss_dec = self.embedding_reg(h1, h2)
        return (loss_inv, loss_dec), pred

