import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from ChebnetII_pro import ChebnetII_prop, SGCConv, Bern_prop, GPR_prop


class NormLayer(nn.Module):
    """Configurable normalisation applied to a feature matrix ``x``.

    The variant is selected by ``args.norm_type``:

    * ``'none'`` -- return ``x`` unchanged.
    * ``'cn'``   -- subtract ``scale`` times the row-mixing term
      ``softmax(x_n @ x_n.T / tau) @ x`` where ``x_n`` is ``x`` L2-normalised
      per row (presumably ContraNorm-style contrastive normalisation).
    * ``'dcn'``  -- like ``'cn'`` but mixes feature columns:
      ``x @ softmax(x_n.T @ x_n / tau)``.
    * ``'zca'``  -- ZCA-whiten the mean-centred features.
    * ``'bn'``   -- ``nn.BatchNorm1d`` over ``hidden`` features.

    Args:
        args: namespace read for ``norm_type``; ``norm_x`` (mix the row-normalised
            features instead of the raw ones); ``plusone`` (use
            ``(1 + scale) * x`` as the base term instead of ``x``); ``scale``;
            and ``layer_norm`` (apply LayerNorm after ``'cn'`` / ``'dcn'``).
        hidden: feature width used to size the LayerNorm / BatchNorm.
    """

    def __init__(self, args, hidden):
        super(NormLayer, self).__init__()
        self.norm_type = args.norm_type
        self.norm_x = args.norm_x
        self.plusone = args.plusone
        self.scale = args.scale
        self.layer_norm = args.layer_norm
        # Sub-modules are only created for the variants that use them; their
        # attribute names are part of the state_dict keys.
        if self.norm_type in ['dcn', 'cn']:
            self.LayerNorm = nn.LayerNorm(hidden)
        if self.norm_type == 'bn':
            self.BatchNorm = nn.BatchNorm1d(hidden)

    def reset_parameters(self):
        """Re-initialise the LayerNorm / BatchNorm owned by this layer, if any."""
        for module in self.children():
            module.reset_parameters()

    def forward(self, x, tau=1.0):
        """Normalise ``x`` according to ``self.norm_type``.

        Args:
            x: feature matrix. The ``'cn'``, ``'dcn'`` and ``'zca'`` paths use
                ``.T`` and matrix products, so ``x`` is assumed to be 2-D with
                rows as samples.
            tau: softmax temperature for ``'cn'`` / ``'dcn'``; ignored otherwise.

        Raises:
            NotImplementedError: if ``norm_type`` is not a supported value.
        """
        if self.norm_type == 'none':
            return x
        if self.norm_type in ('cn', 'dcn'):
            return self._contrastive_norm(x, tau)
        if self.norm_type == 'zca':
            return self._zca_whiten(x)
        if self.norm_type == 'bn':
            return self.BatchNorm(x)
        raise NotImplementedError(
            f"NormLayer: unsupported norm_type {self.norm_type!r} "
            "(expected 'none', 'cn', 'dcn', 'zca' or 'bn')")

    def _contrastive_norm(self, x, tau):
        """Shared 'cn' / 'dcn' path: subtract a softmax-similarity mixture of ``x``."""
        norm_x = nn.functional.normalize(x, dim=1)
        if self.norm_type == 'cn':
            # Row-by-row similarity: mixes samples.
            sim = nn.functional.softmax(norm_x @ norm_x.T / tau, dim=1)
        else:
            # Column-by-column similarity: mixes feature dimensions.
            sim = nn.functional.softmax(norm_x.T @ norm_x / tau, dim=1)
        if self.norm_x:
            x = norm_x
        x_neg = sim @ x if self.norm_type == 'cn' else x @ sim
        if self.plusone:
            x = (1 + self.scale) * x - self.scale * x_neg
        else:
            x = x - self.scale * x_neg
        if self.layer_norm:
            x = self.LayerNorm(x)
        return x

    @staticmethod
    def _zca_whiten(x, eps=1e-6):
        """ZCA-whiten ``x`` (rows = samples); singular values are clamped at ``eps``."""
        x = x - torch.mean(x, dim=0)
        # Unbiased covariance. Guard the divisor so a single-row input yields
        # zeros rather than a 0/0 NaN.
        cov = (x.T @ x) / max(x.size(0) - 1, 1)
        U, S, _ = torch.linalg.svd(cov)
        s_inv = torch.diag(1. / torch.sqrt(torch.clamp(S, min=eps)))
        whiten = (U @ s_inv) @ U.T
        return x @ whiten.T
    

class ChebNetII(torch.nn.Module):
    """ChebNetII encoder: ``ChebnetII_prop`` filtering followed by a two-layer MLP.

    Pipeline: optional dropout with ``args.dprate`` -> ``ChebnetII_prop(args.K)``
    -> dropout -> ``lin1`` -> ReLU -> dropout -> ``lin2``. Both MLP dropouts use
    ``args.dropout``.
    """

    def __init__(self, dataset, args):
        super(ChebNetII, self).__init__()
        in_dim, hid_dim = dataset.num_features, args.hidden
        self.lin1 = Linear(in_dim, hid_dim)
        self.lin2 = Linear(hid_dim, hid_dim)
        self.prop1 = ChebnetII_prop(args.K)

        self.dprate = args.dprate
        self.dropout = args.dropout
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the propagation filter, then both linear layers (in that order)."""
        for module in (self.prop1, self.lin1, self.lin2):
            module.reset_parameters()

    def get_embeddings(self, x, edge_index, **kwargs):
        """Return the forward output; any extra keyword flags are ignored."""
        return self(x, edge_index)

    def forward(self, x, edge_index):
        h = x
        # Input dropout is skipped entirely when dprate is zero.
        if self.dprate != 0.0:
            h = F.dropout(h, p=self.dprate, training=self.training)
        h = self.prop1(h, edge_index)
        h = F.relu(self.lin1(F.dropout(h, p=self.dropout, training=self.training)))
        return self.lin2(F.dropout(h, p=self.dropout, training=self.training))


class GCN_Net(torch.nn.Module):
    """Stack of ``args.num_layers`` GCNConv layers with ReLU + dropout between them.

    One ``BatchNorm1d`` per layer is allocated in ``self.bns`` and reset by
    ``reset_parameters``, but ``forward`` does not apply it. It still shows up in
    ``parameters()`` / ``state_dict()``.
    """

    def __init__(self, dataset, args):
        super(GCN_Net, self).__init__()
        self.dropout = args.dropout
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()

        # First layer maps input features to hidden; the rest are hidden -> hidden.
        in_widths = [dataset.num_features] + [args.hidden] * (args.num_layers - 1)
        for in_dim in in_widths:
            self.convs.append(GCNConv(in_dim, args.hidden))
            self.bns.append(torch.nn.BatchNorm1d(args.hidden))

    def reset_parameters(self):
        """Reset all convolutions first, then all batch norms."""
        for module in (*self.convs, *self.bns):
            module.reset_parameters()

    def forward(self, x, edge_index):
        *hidden_convs, out_conv = self.convs
        for conv in hidden_convs:
            # BatchNorm (self.bns) is intentionally not applied here.
            h = F.relu(conv(x, edge_index))
            x = F.dropout(h, p=self.dropout, training=self.training)
        return out_conv(x, edge_index)

    def get_embeddings(self, x, edge_index, random1=None, random2=None):
        """Return the forward output; ``random1`` / ``random2`` are accepted but unused."""
        return self(x, edge_index)
    

class SGC_Net(nn.Module):
    """SGC-style model: ``SGCConv(K=args.K)`` propagation, dropout, then Linear + ReLU.

    Exposes ``reset_parameters`` and ``get_embeddings`` like the other models in
    this module, so generic training/evaluation code can treat it uniformly.
    """

    def __init__(self, dataset, args):
        super(SGC_Net, self).__init__()
        self.dropout = args.dropout
        self.linear = nn.Linear(dataset.num_node_features, args.hidden)
        self.conv = SGCConv(K=args.K)

    def reset_parameters(self):
        """Re-initialise the linear layer, and the propagation module if it defines a reset."""
        self.linear.reset_parameters()
        conv_reset = getattr(self.conv, 'reset_parameters', None)
        if callable(conv_reset):
            conv_reset()

    def get_embeddings(self, x, edge_index, **kwargs):
        """Return the forward output; flags such as ``random1`` / ``random2`` are ignored."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        x = self.conv(x, edge_index)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.linear(x))
        return x
    

class DGCN_Net(nn.Module):
    """Decoupled model: ``SGCConv(K=args.K)`` propagation followed by an MLP with norms.

    forward: ``relu(conv(x))``, concatenated with the raw ``x`` when
    ``args.residual`` is set, then ``args.num_layers`` blocks of
    dropout -> Linear -> ReLU -> ``NormLayer``.
    """

    def __init__(self, dataset, args):
        super(DGCN_Net, self).__init__()
        self.dropout = args.dropout
        self.residual = args.residual
        self.num_layers = args.num_layers
        self.K = args.K
        self.conv = SGCConv(K=args.K)

        self.lins = nn.ModuleList()
        # Residual mode concatenates raw and propagated features -> double width.
        in_dim = dataset.num_node_features * 2 if self.residual else dataset.num_node_features
        self.lins.append(nn.Linear(in_dim, args.hidden))
        for _ in range(self.num_layers - 1):
            self.lins.append(nn.Linear(args.hidden, args.hidden))

        self.norms = nn.ModuleList()
        for _ in range(self.num_layers):
            self.norms.append(NormLayer(args, args.hidden))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear layer and any LayerNorm/BatchNorm inside the norm layers.

        Resetting the norms too keeps BatchNorm running statistics and LayerNorm
        affine weights from carrying over when the model is re-used across runs.
        """
        for lin in self.lins:
            lin.reset_parameters()
        for norm in self.norms:
            for child in norm.children():
                if hasattr(child, 'reset_parameters'):
                    child.reset_parameters()

    def get_embeddings(self, x, edge_index, random1=False, random2=False):
        """Return ``forward(x, edge_index)``, optionally after freezing random weights.

        Args:
            random1: replace ``lins[0].weight`` with a frozen N(0, 0.1^2) sample.
            random2: replace ``lins[1].weight`` with a frozen N(0, 0.01^2) sample.

        Raises:
            ValueError: if ``random2`` is requested but there are fewer than two
                linear layers (checked before any weight is modified).
        """
        if random2 and len(self.lins) < 2:
            raise ValueError(
                "random2=True needs at least 2 linear layers, got "
                f"{len(self.lins)}")
        if random1:
            self._freeze_random_weight(0, std=0.1)
        if random2:
            self._freeze_random_weight(1, std=0.01)
        return self(x=x, edge_index=edge_index)

    def _freeze_random_weight(self, idx, std):
        """Swap ``lins[idx].weight`` for a non-trainable Gaussian sample on its device/dtype."""
        weight = self.lins[idx].weight
        # Draw on the default generator, then move AND cast so that non-float32
        # models (e.g. float64) still receive a weight of matching dtype.
        sample = torch.normal(mean=0, std=std, size=weight.size())
        self.lins[idx].weight = nn.Parameter(
            sample.to(device=weight.device, dtype=weight.dtype), requires_grad=False)

    def forward(self, x, edge_index):
        prop_x = F.relu(self.conv(x, edge_index))
        x = torch.concat((x, prop_x), dim=1) if self.residual else prop_x

        for lin, norm in zip(self.lins, self.norms):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = norm(F.relu(lin(x)))
        return x


class GPRGNN(torch.nn.Module):
    """GPR-GNN encoder: ``GPR_prop`` propagation followed by a two-layer MLP.

    ``GPR_prop`` is built from ``args.K``, ``args.alpha`` and ``args.Init``.
    Pipeline: optional dropout (``args.dprate``) -> propagation -> dropout ->
    ``lin1`` -> ReLU -> dropout -> ``lin2`` (MLP dropouts use ``args.dropout``).
    """

    def __init__(self, dataset, args):
        super(GPRGNN, self).__init__()
        self.lin1 = Linear(dataset.num_features, args.hidden)
        self.lin2 = Linear(args.hidden, args.hidden)
        self.prop1 = GPR_prop(args.K, args.alpha, args.Init)

        self.Init, self.dprate, self.dropout = args.Init, args.dprate, args.dropout
        self.reset_parameters()

    def reset_parameters(self):
        """Reset propagation coefficients first, then ``lin1`` and ``lin2``."""
        for layer in [self.prop1, self.lin1, self.lin2]:
            layer.reset_parameters()

    def get_embeddings(self, x, edge_index, random1=False, random2=False):
        """Return the forward output; ``random1`` / ``random2`` are accepted but unused."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        keep_input = self.dprate == 0.0
        feats = x if keep_input else F.dropout(x, p=self.dprate, training=self.training)
        feats = self.prop1(feats, edge_index)

        hidden = F.dropout(feats, p=self.dropout, training=self.training)
        hidden = F.relu(self.lin1(hidden))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.lin2(hidden)
    

class BernNet(torch.nn.Module):
    """BernNet encoder: ``Bern_prop(args.K)`` filtering followed by a two-layer MLP.

    ``args.residual`` is stored on the instance but not used by ``forward``.
    """

    def __init__(self, dataset, args):
        super(BernNet, self).__init__()
        hidden = args.hidden
        self.lin1 = Linear(dataset.num_features, hidden)
        self.lin2 = Linear(hidden, hidden)
        self.prop1 = Bern_prop(args.K)

        self.dprate = args.dprate
        self.dropout = args.dropout
        self.residual = args.residual
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the Bernstein filter, then both linear layers."""
        self.prop1.reset_parameters()
        for lin in (self.lin1, self.lin2):
            lin.reset_parameters()

    def get_embeddings(self, x, edge_index, random1=False, random2=False):
        """Return the forward output; ``random1`` / ``random2`` are accepted but unused."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        out = self.prop1(x, edge_index)

        out = F.dropout(out, p=self.dropout, training=self.training)
        out = F.relu(self.lin1(out))
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)


class PropChebNetII(torch.nn.Module):
    """Propagation-only ChebNetII encoder (no learned linear layers).

    forward: optional input dropout (``args.dprate``) -> ``ChebnetII_prop(args.K)``
    -> optional concatenation with the raw input (``args.residual``) -> optional
    BatchNorm (``args.use_bn``) -> dropout (``args.dropout``) -> ReLU.
    """

    def __init__(self, dataset, args):
        super(PropChebNetII, self).__init__()
        self.dprate = args.dprate
        self.dropout = args.dropout
        self.residual = args.residual
        self.K = args.K
        self.use_bn = args.use_bn

        # With residual=True, forward() feeds the BN the [x, prop_x] concatenation,
        # which is twice the input width (assuming prop1 keeps the feature width).
        # A BN sized to the input width alone fails with a shape error there.
        # Widen only when the BN is actually applied, so configurations without
        # BN keep their existing state_dict shapes.
        bn_features = dataset.num_node_features
        if self.residual and self.use_bn:
            bn_features *= 2
        self.bn = nn.BatchNorm1d(bn_features)
        self.prop1 = ChebnetII_prop(args.K)

    def reset_parameters(self):
        """Re-initialise the propagation filter and the BatchNorm."""
        self.prop1.reset_parameters()
        self.bn.reset_parameters()

    def get_embeddings(self, x, edge_index, **kwargs):
        """Return the forward output; flags such as ``random1`` / ``random2`` are ignored."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        prop_x = self.prop1(x, edge_index)
        x = torch.concat((x, prop_x), dim=1) if self.residual else prop_x
        if self.use_bn:
            x = self.bn(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(x)
        return x
    

class PropBernNet(torch.nn.Module):
    """Propagation-only BernNet encoder (no learned linear layers).

    forward: optional input dropout (``args.dprate``) -> ``Bern_prop(args.K)``
    -> optional concatenation with the raw input (``args.residual``) -> optional
    BatchNorm (``args.use_bn``) -> dropout (``args.dropout``) -> ReLU.
    """

    def __init__(self, dataset, args):
        super(PropBernNet, self).__init__()
        self.dprate = args.dprate
        self.dropout = args.dropout
        self.residual = args.residual
        self.use_bn = args.use_bn

        self.prop1 = Bern_prop(args.K)
        # With residual=True, forward() feeds the BN the [x, prop_x] concatenation,
        # which is twice the input width (assuming prop1 keeps the feature width).
        # Widen only when the BN is actually applied, so configurations without
        # BN keep their existing state_dict shapes.
        bn_features = dataset.num_node_features
        if self.residual and self.use_bn:
            bn_features *= 2
        self.bn = nn.BatchNorm1d(bn_features)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the Bernstein filter and the BatchNorm."""
        self.prop1.reset_parameters()
        self.bn.reset_parameters()

    def get_embeddings(self, x, edge_index, **kwargs):
        """Return the forward output; flags such as ``random1`` / ``random2`` are ignored."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        prop_x = self.prop1(x, edge_index)
        x = torch.concat((x, prop_x), dim=1) if self.residual else prop_x
        if self.use_bn:
            x = self.bn(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(x)
        return x
    

class PropGPRGNN(torch.nn.Module):
    """Propagation-only GPR-GNN encoder (no learned linear layers).

    forward: optional input dropout (``args.dprate``) ->
    ``GPR_prop(args.K, args.alpha, args.Init)`` -> optional concatenation with the
    raw input (``args.residual``) -> optional BatchNorm (``args.use_bn``) ->
    dropout (``args.dropout``) -> ReLU.
    """

    def __init__(self, dataset, args):
        super(PropGPRGNN, self).__init__()
        self.Init = args.Init
        self.dprate = args.dprate
        self.dropout = args.dropout
        self.residual = args.residual
        self.use_bn = args.use_bn

        self.prop1 = GPR_prop(args.K, args.alpha, args.Init)
        # With residual=True, forward() feeds the BN the [x, prop_x] concatenation,
        # which is twice the input width (assuming prop1 keeps the feature width).
        # Widen only when the BN is actually applied, so configurations without
        # BN keep their existing state_dict shapes.
        bn_features = dataset.num_node_features
        if self.residual and self.use_bn:
            bn_features *= 2
        self.bn = nn.BatchNorm1d(bn_features)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the propagation coefficients and the BatchNorm."""
        self.prop1.reset_parameters()
        self.bn.reset_parameters()

    def get_embeddings(self, x, edge_index, **kwargs):
        """Return the forward output; flags such as ``random1`` / ``random2`` are ignored."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        prop_x = self.prop1(x, edge_index)
        x = torch.concat((x, prop_x), dim=1) if self.residual else prop_x
        if self.use_bn:
            x = self.bn(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(x)
        return x