import numpy as np
import torch
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F

# from utils.hadamard import hadamard_transform_torch, hadamard_transform_cuda

# Small constant added to the per-column standard deviation in CorrReg so the
# standardisation never divides by exactly zero.
EPS = 1E-10

def CorrReg(h1, h2, iden, alpha, beta):
    '''
    Correlation-based regularisation loss between two embedding matrices.

    Both inputs are standardised column-wise (mean and std taken over dim 0,
    with EPS added to the std). Two matrices are then formed and divided by
    the number of rows of h1:
      * the cross-correlation  z1^T z2
      * the self-correlation   z1^T z1

    The returned scalar is

        alpha * (-sum(diag(cross))) + beta * sum((iden - self)^2)

    Parameters:
        h1, h2: 2-D tensors (torch.mm requires matrices).
        iden:   tensor subtracted from the self-correlation matrix; presumably
                an identity matrix of matching size -- confirm against caller.
        alpha:  weight of the invariance (cross-correlation diagonal) term.
        beta:   weight of the decorrelation term.
    '''
    def _standardize(h):
        # Zero-centre every column and scale it by (std + EPS).
        return (h - h.mean(0)) / (h.std(0) + EPS)

    z1 = _standardize(h1)
    z2 = _standardize(h2)

    cross = torch.mm(z1.T, z2)
    auto = torch.mm(z1.T, z1)

    # Only the row count is needed; the second dimension is ignored.
    num_rows, _ = h1.shape
    cross = cross / num_rows
    auto = auto / num_rows

    invariance_term = -torch.diagonal(cross).sum()
    decorrelation_term = (iden - auto).pow(2).sum()

    return alpha * invariance_term + beta * decorrelation_term

def preprocess_feature(graph, feat, max_order = 10):
    '''
    Build the graph-augmented feature tensor by repeated neighbourhood propagation.

    Every edge is weighted by the product of its two endpoints' 'norm' value,
    where norm = in_degree^(-1/2) and in-degrees are clamped to at least 1.
    Starting from `feat`, the features are propagated along the weighted edges
    `max_order` times; the result of each step is kept.

    Parameters:
        graph:     DGL graph (only used inside a local scope, so the temporary
                   'norm' / 'weight' / 'h' fields are not left on it).
        feat:      node feature tensor; the norm tensor is moved to feat.device.
        max_order: number of propagation steps to perform.

    Returns:
        The per-step propagated features stacked along a new leading dimension
        (so the first dimension has size `max_order`).

    Note: per the original author, this is meant to be run once per dataset,
    with the resulting tensor saved for later use.
    '''
    with graph.local_scope():
        in_deg = graph.in_degrees().float().clamp(min=1)
        graph.ndata['norm'] = torch.pow(in_deg, -0.5).to(feat.device).unsqueeze(1)

        # Edge weight = norm[src] * norm[dst].
        graph.apply_edges(fn.u_mul_v('norm', 'norm', 'weight'))

        propagated = []
        current = feat
        for _ in range(max_order):
            graph.ndata['h'] = current
            # Weighted sum of neighbour features over incoming edges.
            graph.update_all(fn.u_mul_e('h', 'weight', 'm'), fn.sum('m', 'h'))
            current = graph.ndata.pop('h')
            propagated.append(current)

        return torch.stack(propagated, dim=0)
    

class MLP(nn.Module):
    '''
    Stack of linear layers with optional batch normalisation.

    The first linear layer maps in_dim -> hid_dim; the remaining
    (num_layer - 1) layers map hid_dim -> hid_dim. One BatchNorm1d(hid_dim)
    per layer is always constructed, but it is only applied in forward()
    when use_bn is True. ReLU and dropout follow every layer except the one
    at index num_layer - 1.
    '''

    def __init__(self, in_dim, hid_dim, num_layer, dropout = 0.0, use_bn = False):
        super().__init__()

        self.num_layer = num_layer

        # Input projection first, then the hidden-to-hidden layers.
        linears = [nn.Linear(in_dim, hid_dim, bias=True)]
        linears.extend(
            nn.Linear(hid_dim, hid_dim, bias=True) for _ in range(num_layer - 1)
        )
        self.layers = nn.ModuleList(linears)

        # One batch-norm module per layer, created regardless of use_bn.
        self.bns = nn.ModuleList(
            [nn.BatchNorm1d(hid_dim) for _ in range(num_layer)]
        )

        self.use_bn = use_bn
        self.act_fn = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        '''Run x through every layer and return the final layer's output.'''
        final_idx = self.num_layer - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if self.use_bn:
                x = self.bns[idx](x)

            # The final layer gets neither activation nor dropout.
            if idx == final_idx:
                continue
            x = self.dropout(self.act_fn(x))
        return x


class Model(nn.Module):
    '''
    MLP encoder followed by a linear classification head.

    The same MLP is applied to two inputs; class logits are computed only
    from the first input's embedding.

    Note: rff_dim, temp and approx are accepted by the constructor but are
    not used anywhere in this class.
    '''

    def __init__(self, in_dim, hid_dim, num_class, num_layer, dropout,
                 use_bn = False, rff_dim = 4096, temp = 0.5, approx = 'sorf'):
        super().__init__()

        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.num_class = num_class

        # Shared encoder, then the linear head mapping embeddings to classes.
        self.mlp = MLP(in_dim, hid_dim, num_layer, dropout, use_bn)
        self.predictor = nn.Linear(hid_dim, num_class)

    def forward(self, feat, sfeat):
        '''
        Encode both inputs with the shared MLP.

        Returns a tuple (logits, h1, h2) where h1 = mlp(feat),
        h2 = mlp(sfeat) and logits = predictor(h1).
        '''
        # Encode `feat` first, then `sfeat`, with the same MLP.
        emb_main = self.mlp(feat)
        emb_aux = self.mlp(sfeat)

        class_logits = self.predictor(emb_main)
        return class_logits, emb_main, emb_aux
                
