from torch.nn import Linear
from torch.nn import Parameter
from torch_geometric.nn import GINConv, SAGEConv, GCNConv
import numpy as np
import torch.nn.init as init
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import add_remaining_self_loops, degree
from torch_scatter import scatter
from torch_geometric.nn import APPNP

class MLP_classifier(torch.nn.Module):
    """Linear classification head mapping ``args.hidden2`` features to ``args.num_classes`` logits.

    Args:
        args: config object; reads ``hidden2``, ``num_classes`` and (in
            ``clip_parameters``) ``clip_c``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.lin = Linear(args.hidden2, args.num_classes)

    def clip_parameters(self):
        """Clamp every weight and bias of the head in place to ``[-clip_c, clip_c]``."""
        bound = self.args.clip_c
        for param in self.lin.parameters():
            param.data.clamp_(-bound, bound)

    def reset_parameters(self):
        """Re-initialise the linear layer."""
        self.lin.reset_parameters()

    def forward(self, h, edge_index=None):
        """Return logits for ``h``; ``edge_index`` is accepted but ignored."""
        return self.lin(h)

def propagate2(x, edge_index):
    """Propagate node features with symmetric GCN normalisation.

    Adds any missing self-loops, then for every edge ``(src, dst)`` sends
    ``x[src] * deg[src]^-1/2 * deg[dst]^-1/2`` to ``dst`` and sums the
    messages per node. Degrees are counted on the destination index after
    the self-loops are added. Every edge has unit weight; no edge weights
    are accepted.

    Args:
        x: node feature matrix; row ``i`` belongs to node ``i``.
        edge_index: two-row index tensor of (source, target) pairs.

    Returns:
        Tensor with the same number of rows as ``x``.
    """
    num_nodes = x.size(0)
    looped_index, _ = add_remaining_self_loops(edge_index, num_nodes=num_nodes)
    src, dst = looped_index

    # Each node has a self-loop now, so every degree is >= 1 and pow(-0.5) stays finite.
    inv_sqrt_deg = degree(dst, num_nodes, dtype=x.dtype).pow(-0.5)
    norm = inv_sqrt_deg[src] * inv_sqrt_deg[dst]

    # Scale each source feature by its edge's normalisation, then sum per target.
    messages = norm.unsqueeze(-1) * x[src]
    return scatter(messages, dst, dim=0, dim_size=num_nodes, reduce='add')

class GCN_encoder_scatter(torch.nn.Module):
    """Single-layer GCN encoder built on ``propagate2``.

    Applies a bias-free linear projection, propagates the result with
    ``propagate2``, then adds a learnable bias vector.

    Args:
        args: config object; reads ``num_features`` and ``hidden``.
    """

    def __init__(self, args):
        super(GCN_encoder_scatter, self).__init__()

        self.args = args

        self.lin = Linear(args.num_features, args.hidden, bias=False)

        # torch.Tensor(n) allocates uninitialised memory. The original bias only
        # became meaningful if the caller remembered to call reset_parameters(),
        # so start from the same zero state that reset_parameters() produces.
        self.bias = Parameter(torch.zeros(args.hidden))

    def reset_parameters(self):
        """Re-initialise the projection and zero the bias."""
        self.lin.reset_parameters()
        self.bias.data.fill_(0.0)

    def forward(self, x, edge_index, adj_norm_sp=None):
        """Encode node features.

        Args:
            x: node feature matrix.
            edge_index: two-row edge index passed to ``propagate2``.
            adj_norm_sp: unused; accepted so callers that pass it keep working.

        Returns:
            ``propagate2(lin(x), edge_index) + bias``.
        """
        h = self.lin(x)
        return propagate2(h, edge_index) + self.bias

class MLP_encoder(torch.nn.Module):
    """Graph-agnostic encoder: one linear map from ``num_features`` to ``hidden``.

    Args:
        args: config object; reads ``num_features`` and ``hidden``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.lin = Linear(args.num_features, args.hidden)

    def reset_parameters(self):
        """Re-initialise the linear layer."""
        self.lin.reset_parameters()

    def forward(self, x, edge_index=None, mask_node=None):
        """Project ``x``; ``edge_index`` and ``mask_node`` are ignored."""
        return self.lin(x)

class Discriminator(nn.Module):
    """Two-layer MLP (``input_dim -> 128 -> 1``, ReLU between) returning one raw score per row.

    Weights are Xavier-uniform initialised and biases zeroed at construction.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_layer = nn.Linear(input_dim, 128)
        output_layer = nn.Linear(128, 1)
        self.layer = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)
        self.reset_parameters()

    def forward(self, x):
        """Return the unnormalised score (no sigmoid applied)."""
        return self.layer(x)

    def reset_parameters(self):
        """Apply Xavier-uniform weights and zero biases to every Linear in ``layer``."""
        linears = [module for module in self.layer if isinstance(module, nn.Linear)]
        for lin in linears:
            init.xavier_uniform_(lin.weight)
            init.constant_(lin.bias, 0)

class GCN_encoder(nn.Module):
    """Two-layer GCN encoder.

    Layout: ``GCNConv -> ReLU -> BatchNorm1d -> Dropout -> GCNConv``. The
    second convolution's output is returned without activation.

    Args:
        args: config object; reads ``num_features``, ``hidden`` and ``dropout``.
    """

    def __init__(self, args):
        super(GCN_encoder, self).__init__()
        self.conv1 = GCNConv(args.num_features, args.hidden)
        self.transition = nn.Sequential(
            nn.ReLU(),
            nn.BatchNorm1d(args.hidden),
            nn.Dropout(p=args.dropout)
        )
        self.conv2 = GCNConv(args.hidden, args.hidden)

    def reset_parameters(self):
        """Re-initialise both convolutions and the BatchNorm inside ``transition``."""
        self.conv1.reset_parameters()
        # BatchNorm1d has learnable affine params and running statistics. Without
        # resetting them, state from a previous run survives a "reset".
        for module in self.transition:
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Encode node features.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity passed to both convolutions.
            edge_weight: optional per-edge weights passed to both convolutions.
        """
        x = self.conv1(x, edge_index, edge_weight)
        x = self.transition(x)
        return self.conv2(x, edge_index, edge_weight)


class Projector(nn.Module):
    """Stack of three Linear layers: ``in_dim -> out_dim -> out_dim -> out_dim``.

    NOTE(review): no activation sits between the layers, so the whole stack is
    mathematically a single affine map. Confirm this is intended.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin1 = nn.Linear(in_dim, out_dim)
        self.lin2 = nn.Linear(out_dim, out_dim)
        self.lin3 = nn.Linear(out_dim, out_dim)

    def _layers(self):
        # Application order used by both reset_parameters and forward.
        return (self.lin1, self.lin2, self.lin3)

    def reset_parameters(self):
        """Re-initialise all three layers."""
        for lin in self._layers():
            lin.reset_parameters()

    def forward(self, h):
        """Apply ``lin1``, ``lin2``, ``lin3`` in sequence."""
        out = h
        for lin in self._layers():
            out = lin(out)
        return out







class Encoder_X(nn.Module):
    """MLP encoder producing Gaussian parameters ``(mu, log_std)``.

    One ReLU hidden layer feeds two linear heads. The hidden activation from
    the latest forward pass is cached in ``self.h_layer``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = nn.Linear(in_channels, hidden_channels)
        self.lin_mu = nn.Linear(hidden_channels, out_channels)
        self.lin_log_std = nn.Linear(hidden_channels, out_channels)
        self.h_layer = None  # hidden activation of the latest forward pass

    def forward(self, x):
        """Return ``(mu, log_std)``, with ``log_std`` clamped to ``[-20, 10]`` so exp() stays finite."""
        hidden = torch.relu(self.lin1(x))
        self.h_layer = hidden
        mean = self.lin_mu(hidden)
        log_sigma = torch.clamp(self.lin_log_std(hidden), min=-20, max=10)
        return mean, log_sigma

    def reset_parameters(self):
        """Re-initialise the hidden layer and both heads."""
        for lin in (self.lin1, self.lin_mu, self.lin_log_std):
            lin.reset_parameters()

    def kl_loss(self, mu, log_std):
        """KL(N(mu, exp(log_std)^2) || N(0, 1)) in closed form, averaged over all elements."""
        variance = log_std.exp().pow(2)
        elementwise = 1 + 2 * log_std - mu.pow(2) - variance
        return -0.5 * elementwise.mean()


class Encoder_D(nn.Module):
    """MLP encoder producing Gaussian parameters ``(mu, log_std)``.

    One ReLU hidden layer feeds two linear heads. The hidden activation from
    the latest forward pass is cached in ``self.h_layer``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = nn.Linear(in_channels, hidden_channels)
        self.lin_mu = nn.Linear(hidden_channels, out_channels)
        self.lin_log_std = nn.Linear(hidden_channels, out_channels)
        self.h_layer = None  # hidden activation of the latest forward pass

    def forward(self, x):
        """Return ``(mu, log_std)``, with ``log_std`` clamped to ``[-20, 10]``."""
        hidden = F.relu(self.lin1(x))
        self.h_layer = hidden
        raw_log_std = self.lin_log_std(hidden)
        return self.lin_mu(hidden), raw_log_std.clamp(min=-20, max=10)

    def reset_parameters(self):
        """Re-initialise the hidden layer and both heads."""
        self.lin1.reset_parameters()
        self.lin_mu.reset_parameters()
        self.lin_log_std.reset_parameters()

    def kl_loss(self, mu, log_std):
        """KL(N(mu, exp(log_std)^2) || N(0, 1)) in closed form, averaged over all elements."""
        sigma_sq = log_std.exp().pow(2)
        return -0.5 * (1 + 2 * log_std - mu.pow(2) - sigma_sq).mean()



class Encoder_A(nn.Module):
    """GCN-based encoder producing Gaussian parameters ``(mu, log_std)``.

    A ReLU GCNConv hidden layer feeds two GCNConv heads over the same
    ``edge_index``. The hidden activation from the latest forward pass is
    cached in ``self.h_layer``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, out_channels)
        self.conv_log_std = GCNConv(hidden_channels, out_channels)
        self.h_layer = None  # hidden activation of the latest forward pass

    def forward(self, x, edge_index):
        """Return ``(mu, log_std)``, with ``log_std`` clamped to ``[-20, 10]``."""
        hidden = torch.relu(self.conv1(x, edge_index))
        self.h_layer = hidden
        mean = self.conv_mu(hidden, edge_index)
        log_sigma = self.conv_log_std(hidden, edge_index).clamp(min=-20, max=10)
        return mean, log_sigma

    def reset_parameters(self):
        """Re-initialise all three graph convolutions."""
        for conv in (self.conv1, self.conv_mu, self.conv_log_std):
            conv.reset_parameters()

    def kl_loss(self, mu, log_std):
        """KL(N(mu, exp(log_std)^2) || N(0, 1)) in closed form, averaged over all elements."""
        variance = log_std.exp().pow(2)
        return -0.5 * (1 + 2 * log_std - mu.pow(2) - variance).mean()
    
class GIN_encoder(nn.Module):
    """Single-layer GIN encoder whose update MLP is one Linear(num_features -> hidden).

    The Linear stays wrapped in an ``nn.Sequential`` so parameter names in the
    state dict keep the ``.0.`` index.

    Args:
        args: config object; reads ``num_features`` and ``hidden``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.mlp = nn.Sequential(nn.Linear(args.num_features, args.hidden))
        self.conv = GINConv(self.mlp)

    def reset_parameters(self):
        """Delegate re-initialisation to the GIN convolution."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index, adj_norm_sp=None):
        """Apply the GIN convolution; ``adj_norm_sp`` is accepted but ignored."""
        return self.conv(x, edge_index)


class SAGE_encoder(nn.Module):
    """Two-layer GraphSAGE encoder.

    Layout: ``SAGEConv -> ReLU -> BatchNorm1d -> Dropout -> SAGEConv``. Both
    convolutions are built with ``normalize=True``.

    Args:
        args: config object; reads ``num_features``, ``hidden`` and ``dropout``.
    """

    def __init__(self, args):
        super(SAGE_encoder, self).__init__()

        self.args = args
        self.conv1 = SAGEConv(args.num_features, args.hidden, normalize=True)
        # NOTE(review): recent PyG versions build the aggregation module at
        # construction time, so reassigning ``.aggr`` afterwards may have no
        # effect. 'mean' is already SAGEConv's default, so results are unaffected.
        self.conv1.aggr = 'mean'
        self.transition = nn.Sequential(
            nn.ReLU(),
            nn.BatchNorm1d(args.hidden),
            nn.Dropout(p=args.dropout)
        )
        self.conv2 = SAGEConv(args.hidden, args.hidden, normalize=True)
        self.conv2.aggr = 'mean'

    def reset_parameters(self):
        """Re-initialise both convolutions and the BatchNorm inside ``transition``."""
        self.conv1.reset_parameters()
        # BatchNorm1d has learnable affine params and running statistics. Without
        # resetting them, state from a previous run survives a "reset".
        for module in self.transition:
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Encode node features; ``edge_weight`` is forwarded to both convolutions as given."""
        x = self.conv1(x, edge_index, edge_weight)
        x = self.transition(x)
        return self.conv2(x, edge_index, edge_weight)

class PropensityModel(nn.Module):
    """Small MLP (``in_dim -> 64 -> 1``, ReLU hidden, sigmoid output) giving one value in (0, 1) per row.

    Presumably a propensity-score estimator, as the name suggests; confirm
    the target with the caller.
    """

    def __init__(self, in_dim):
        super().__init__()
        hidden_width = 64
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return sigmoid-squashed scores for each row of ``x``."""
        scores = self.mlp(x)
        return scores




