import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Linear
from torch_geometric.nn import GCNConv

from ChebnetII_pro import ChebnetII_prop, SGCConv, Bern_prop, GPR_prop


class NormLayer(nn.Module):
    """Optional feature normalization selected by ``args.norm_type``.

    Supported ``norm_type`` values:

    * ``'none'`` -- identity.
    * ``'cn'``   -- subtract ``scale`` times a softmax-weighted average of the
      rows of ``x``; weights come from the cosine similarity between rows
      (``normalize(x) @ normalize(x).T``).
    * ``'dcn'``  -- same update, but the similarity matrix is built between
      the columns of the row-normalized input (``norm_x.T @ norm_x``) and is
      applied on the right (``x @ sim``).
    * ``'zca'``  -- ZCA whitening using the covariance of the current input
      (recomputed on every call; no running statistics are kept).
    * ``'bn'``   -- ``nn.BatchNorm1d(hidden)``.

    Args:
        args: Config object read for ``norm_type``, ``norm_x``, ``plusone``,
            ``scale`` and ``layer_norm``.
        hidden: Feature size given to ``nn.LayerNorm`` ('cn'/'dcn') or
            ``nn.BatchNorm1d`` ('bn').
    """

    def __init__(self, args, hidden):
        super(NormLayer, self).__init__()
        self.norm_type = args.norm_type
        self.norm_x = args.norm_x
        self.plusone = args.plusone
        self.scale = args.scale
        self.layer_norm = args.layer_norm
        # Attribute names are kept unchanged so existing state_dict keys load.
        if self.norm_type in ['dcn', 'cn']:
            self.LayerNorm = nn.LayerNorm(hidden)
        if self.norm_type == 'bn':
            self.BatchNorm = nn.BatchNorm1d(hidden)

    def _contrast(self, x, x_neg):
        """Push ``x`` away from ``x_neg`` by ``scale``; optionally LayerNorm the result."""
        if self.plusone:
            x = (1 + self.scale) * x - self.scale * x_neg
        else:
            x = x - self.scale * x_neg
        if self.layer_norm:
            x = self.LayerNorm(x)
        return x

    def forward(self, x, tau=1.0):
        """Apply the configured normalization.

        Args:
            x: Input features; assumed 2-D ``(num_rows, num_features)``
                (the code uses ``.T`` and ``normalize(dim=1)``) -- TODO confirm.
            tau: Temperature dividing the similarity logits before the
                softmax ('cn' and 'dcn' only).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            NotImplementedError: If ``norm_type`` is not one of the supported values.
        """
        if self.norm_type == 'none':
            return x

        if self.norm_type == 'cn':
            norm_x = nn.functional.normalize(x, dim=1)
            # Row-to-row similarity; each row of `sim` sums to 1.
            sim = nn.functional.softmax(norm_x @ norm_x.T / tau, dim=1)
            if self.norm_x:
                x = norm_x
            return self._contrast(x, sim @ x)

        if self.norm_type == 'dcn':
            norm_x = nn.functional.normalize(x, dim=1)
            # Column-to-column similarity of the row-normalized input.
            sim = nn.functional.softmax(norm_x.T @ norm_x / tau, dim=1)
            if self.norm_x:
                x = norm_x
            return self._contrast(x, x @ sim)

        if self.norm_type == 'zca':
            eps = 1e-6
            x = x - torch.mean(x, dim=0)
            # max(..., 1) avoids 0/0 (a NaN covariance) for a single-row input.
            cov = (x.T @ x) / max(x.size(0) - 1, 1)
            U, S, _ = torch.linalg.svd(cov)
            # Clamp tiny singular values so 1/sqrt(S) stays finite.
            s = torch.sqrt(torch.clamp(S, min=eps))
            whiten = (U @ torch.diag(1. / s)) @ U.T
            return x @ whiten.T

        if self.norm_type == 'bn':
            return self.BatchNorm(x)

        raise NotImplementedError(f"Unsupported norm_type: {self.norm_type!r}")


class ChebNetII(torch.nn.Module):
    """``ChebnetII_prop`` propagation followed by a two-layer MLP head.

    Args:
        dataset: Object exposing ``num_features`` (input feature size).
        args: Config read for ``hidden``, ``K``, ``dprate``, ``dropout`` and
            ``residual``.
    """

    def __init__(self, dataset, args):
        super(ChebNetII, self).__init__()
        in_dim = dataset.num_features
        if args.residual:
            # Raw and propagated features are concatenated; this assumes the
            # propagation keeps the feature dimension -- TODO confirm.
            in_dim = in_dim * 2
        self.lin1 = Linear(in_dim, args.hidden)
        self.lin2 = Linear(args.hidden, args.hidden)

        self.props = nn.ModuleList()
        self.props.append(ChebnetII_prop(args.K))

        self.dprate = args.dprate
        self.dropout = args.dropout
        self.residual = args.residual
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the propagation module and both linear layers."""
        self.props[0].reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def get_embeddings(self, x, edge_index):
        """Return the same output as ``forward``."""
        return self(x, edge_index)

    def forward(self, x, edge_index):
        """Propagate, optionally concatenate with the input, then apply the MLP.

        Input dropout (``dprate``) is skipped entirely when ``dprate == 0.0``;
        with ``residual`` the (possibly dropped-out) input is the one concatenated.
        """
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        prop_x = self.props[0](x, edge_index)

        hidden = torch.concat((x, prop_x), dim=1) if self.residual else prop_x

        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        hidden = F.relu(self.lin1(hidden))

        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.lin2(hidden)


class PropChebNetII(torch.nn.Module):
    """``ChebnetII_prop`` propagation without an MLP head.

    Args:
        dataset: Unused here (kept for a uniform constructor signature).
        args: Config read for ``K``, ``num_node_features``, ``dprate``,
            ``dropout``, ``residual`` and the ``NormLayer`` options.
    """

    def __init__(self, dataset, args):
        super(PropChebNetII, self).__init__()
        self.props = nn.ModuleList()
        self.props.append(ChebnetII_prop(args.K))
        # NOTE(review): norm_layer is built (and registered as a submodule) but
        # forward() never applies it, whereas PropGPRGNN does -- confirm intent.
        self.norm_layer = NormLayer(args, hidden=args.num_node_features)
        self.dprate, self.dropout, self.residual = args.dprate, args.dropout, args.residual
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the propagation module."""
        self.props[0].reset_parameters()

    def get_embeddings(self, x, edge_index):
        """Return the same output as ``forward``."""
        return self(x, edge_index)

    def forward(self, x, edge_index):
        """Propagate, optionally concatenate with the input, dropout, then ReLU."""
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        prop_x = self.props[0](x, edge_index)

        out = torch.concat((x, prop_x), dim=1) if self.residual else prop_x
        out = F.dropout(out, p=self.dropout, training=self.training)
        return F.relu(out)


class SGC_Net(nn.Module):
    """``SGCConv`` propagation (followed by ReLU) and a stack of linear layers.

    No activation is applied between the linear layers.

    Args:
        dataset: Object exposing ``num_node_features``.
        args: Config read for ``dropout``, ``residual``, ``num_layers``,
            ``K`` and ``hidden``.
    """

    def __init__(self, dataset, args):
        super(SGC_Net, self).__init__()
        self.dropout = args.dropout
        self.residual = args.residual
        self.num_layers = args.num_layers
        self.K = args.K

        self.conv = SGCConv(K=args.K)

        feat_dim = dataset.num_node_features
        # With `residual`, raw features are concatenated to the propagated ones.
        first_in = 2 * feat_dim if self.residual else feat_dim
        self.lins = nn.ModuleList([nn.Linear(first_in, args.hidden)])
        self.lins.extend(
            nn.Linear(args.hidden, args.hidden) for _ in range(self.num_layers - 1)
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the first ``num_layers`` linear layers."""
        for lin in self.lins[:self.num_layers]:
            lin.reset_parameters()

    def get_embeddings(self, x, edge_index, random=False, constant=False):
        """Return the same output as ``forward``; ``random``/``constant`` are ignored."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        """Propagate + ReLU, optional residual concat, then dropout->linear per layer."""
        prop_x = F.relu(self.conv(x, edge_index))
        h = torch.concat((x, prop_x), dim=1) if self.residual else prop_x

        for lin in self.lins[:self.num_layers]:
            h = lin(F.dropout(h, p=self.dropout, training=self.training))
        return h


class GPRGNN(torch.nn.Module):
    """``GPR_prop`` propagation followed by a two-layer MLP head.

    Args:
        dataset: Object exposing ``num_features``.
        args: Config read for ``hidden``, ``K``, ``alpha``, ``Init``,
            ``dprate``, ``dropout`` and ``residual``.
    """

    def __init__(self, dataset, args):
        super(GPRGNN, self).__init__()
        n_in = dataset.num_features * 2 if args.residual else dataset.num_features
        self.lin1 = Linear(n_in, args.hidden)
        self.lin2 = Linear(args.hidden, args.hidden)
        self.props = nn.ModuleList()
        self.props.append(GPR_prop(args.K, args.alpha, args.Init))

        self.Init = args.Init
        self.dprate = args.dprate
        self.dropout = args.dropout
        self.residual = args.residual
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the propagation module and both linear layers."""
        self.props[0].reset_parameters()
        for lin in (self.lin1, self.lin2):
            lin.reset_parameters()

    def get_embeddings(self, x, edge_index, random=False, constant=False):
        """Return the same output as ``forward``; ``random``/``constant`` are ignored."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        """Propagate, optionally concatenate with the input, then apply the MLP."""
        # Input dropout is skipped entirely when dprate is exactly 0.0.
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        prop_x = self.props[0](x, edge_index)

        if self.residual:
            feats = torch.concat((x, prop_x), dim=1)
        else:
            feats = prop_x

        feats = self.lin1(F.dropout(feats, p=self.dropout, training=self.training))
        feats = F.relu(feats)
        feats = F.dropout(feats, p=self.dropout, training=self.training)
        return self.lin2(feats)
    


class PropGPRGNN(torch.nn.Module):
    """``NormLayer`` + ``GPR_prop`` propagation, without an MLP head.

    Args:
        dataset: Unused here (kept for a uniform constructor signature).
        args: Config read for ``K``, ``alpha``, ``Init``, ``hidden``,
            ``dprate``, ``dropout``, ``residual`` and the ``NormLayer`` options.
    """

    def __init__(self, dataset, args):
        super(PropGPRGNN, self).__init__()
        self.props = nn.ModuleList()
        self.props.append(GPR_prop(args.K, args.alpha, args.Init))
        # NOTE(review): sized with args.hidden, yet forward() applies it to the
        # raw input x; for 'bn' (or 'cn'/'dcn' with layer_norm) this presumably
        # requires args.hidden == input feature size -- confirm.
        self.norm_layer = NormLayer(args, hidden=args.hidden)
        self.Init = args.Init
        self.dprate, self.dropout, self.residual = args.dprate, args.dropout, args.residual
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the propagation module."""
        self.props[0].reset_parameters()

    def get_embeddings(self, x, edge_index):
        """Return the same output as ``forward``."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        """Normalize, propagate, optionally concatenate, dropout, then ReLU."""
        x = self.norm_layer(x)
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        prop_x = self.props[0](x, edge_index)

        out = torch.concat((x, prop_x), dim=1) if self.residual else prop_x
        return F.relu(F.dropout(out, p=self.dropout, training=self.training))
    

class BernNet(torch.nn.Module):
    """``Bern_prop`` propagation followed by a two-layer MLP head.

    Args:
        dataset: Object exposing ``num_features``.
        args: Config read for ``hidden``, ``K``, ``dprate``, ``dropout`` and
            ``residual``.
    """

    def __init__(self, dataset, args):
        super(BernNet, self).__init__()
        width = dataset.num_features
        if args.residual:
            # Raw features are concatenated with the propagated ones.
            width *= 2
        self.lin1 = Linear(width, args.hidden)
        self.lin2 = Linear(args.hidden, args.hidden)
        self.props = nn.ModuleList()
        self.props.append(Bern_prop(args.K))

        self.dprate = args.dprate
        self.dropout = args.dropout
        self.residual = args.residual
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the propagation module and both linear layers."""
        self.props[0].reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def get_embeddings(self, x, edge_index, random=False, constant=False):
        """Return the same output as ``forward``; ``random``/``constant`` are ignored."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        """Propagate, optionally concatenate with the input, then apply the MLP."""
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        prop_x = self.props[0](x, edge_index)

        z = torch.concat((x, prop_x), dim=1) if self.residual else prop_x
        z = F.dropout(z, p=self.dropout, training=self.training)
        z = F.relu(self.lin1(z))
        z = F.dropout(z, p=self.dropout, training=self.training)
        return self.lin2(z)


class PropBernNet(torch.nn.Module):
    """``Bern_prop`` propagation without an MLP head.

    Args:
        dataset: Unused here (kept for a uniform constructor signature).
        args: Config read for ``K``, ``dprate``, ``dropout`` and ``residual``.
    """

    def __init__(self, dataset, args):
        super(PropBernNet, self).__init__()
        # (A redundant second `nn.ModuleList()` assignment was removed here.)
        self.props = nn.ModuleList()
        self.props.append(Bern_prop(args.K))

        self.dprate = args.dprate
        self.dropout = args.dropout
        self.residual = args.residual
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the propagation module."""
        self.props[0].reset_parameters()

    def get_embeddings(self, x, edge_index, random=False, constant=False):
        """Return the same output as ``forward``; ``random``/``constant`` are ignored."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        """Propagate, optionally concatenate with the input, dropout, then ReLU.

        Input dropout (``dprate``) is skipped entirely when ``dprate == 0.0``.
        """
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        prop_x = self.props[0](x, edge_index)

        x = torch.concat((x, prop_x), dim=1) if self.residual else prop_x

        x = F.dropout(x, p=self.dropout, training=self.training)
        return F.relu(x)
    

class GCN_Net(torch.nn.Module):
    """Two-layer ``GCNConv`` network with an optional input-to-hidden residual concat.

    Args:
        dataset: Object exposing ``num_features``.
        args: Config read for ``dropout``, ``residual`` and ``hidden``.
    """

    def __init__(self, dataset, args):
        super(GCN_Net, self).__init__()
        self.dropout = args.dropout
        self.residual = args.residual
        self.convs = torch.nn.ModuleList()

        in_dim = dataset.num_features
        self.convs.append(GCNConv(in_dim, args.hidden))
        # With `residual`, the second conv sees [input || first-layer output].
        second_in = args.hidden + in_dim if self.residual else args.hidden
        self.convs.append(GCNConv(second_in, args.hidden))

    def reset_parameters(self):
        """Re-initialize every convolution layer."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Hidden convs (conv -> ReLU -> dropout, optional concat), then the final conv."""
        for layer in self.convs[:-1]:
            h = F.relu(layer(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = torch.cat((x, h), dim=1) if self.residual else h
        return self.convs[-1](x, edge_index)

    def get_embeddings(self, x, edge_index, random=None, constant=None):
        """Return the same output as ``forward``; ``random``/``constant`` are ignored."""
        return self(x, edge_index)

    