from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, BatchNorm1d, Sequential, ReLU
from torch_geometric.nn import global_max_pool,global_mean_pool, global_add_pool

from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.nn.inits import glorot, zeros


# Lookup table from the `global_pool` option string to the graph-level
# readout function that is applied to node features.
pool_map = dict(
    max=global_max_pool,
    mean=global_mean_pool,
    sum=global_add_pool,
)

class GCNConv(MessagePassing):
    """Graph convolution layer built on ``MessagePassing`` with 'add' aggregation.

    The layer multiplies the input by a learnable ``(in_channels,
    out_channels)`` weight, optionally propagates the result along the edges
    with normalised edge weights, adds an optional bias and applies dropout.

    Args:
        in_channels: Size of the first dimension of the weight matrix.
        out_channels: Size of the second dimension of the weight matrix
            (and of the bias vector).
        improved: If true, self loops receive weight 2 instead of 1 during
            normalisation.
        cached: If true, the ``(edge_index, norm)`` pair computed on the
            first forward pass is reused on later passes.
        bias: If true, a learnable bias of size ``out_channels`` is created.
        edge_norm: If true, edge weights are normalised via :meth:`norm` and
            used to scale messages; otherwise messages are passed unscaled.
        gfn: If true, ``forward`` returns right after the linear transform.
        drop_out: Dropout probability applied to the propagated output.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 improved=False,
                 cached=False,
                 bias=True,
                 edge_norm=True,
                 gfn=False,
                 drop_out=0.2):
        super(GCNConv, self).__init__('add')

        # Plain configuration attributes.
        self.in_channels, self.out_channels = in_channels, out_channels
        self.improved = improved
        self.cached = cached
        self.edge_norm = edge_norm
        self.gfn = gfn
        self.cached_result = None
        self.message_mask = None

        # Learnable state; values are filled in by reset_parameters().
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        self.drop_out = torch.nn.Dropout(drop_out)
        self.register_parameter(
            'bias',
            Parameter(torch.Tensor(out_channels)) if bias else None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weight (glorot) and bias (zeros) and drop the cache."""
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        """Rebuild self loops and compute normalised edge weights.

        Args:
            edge_index: Edge tensor; ``edge_index.size(1)`` is taken as the
                number of edges and it is unpacked into two index rows.
            num_nodes: Number of nodes, used for self loops and degree size.
            edge_weight: Per-edge weights, or ``None`` for all-ones weights.
            improved: If true, self loops get weight 2 instead of 1.
            dtype: dtype of the all-ones weights created when
                ``edge_weight`` is ``None``.

        Returns:
            Tuple ``(edge_index, weight)`` where ``edge_index`` has its self
            loops removed and re-added, and ``weight`` is
            ``deg^-1/2[row] * edge_weight * deg^-1/2[col]`` with ``deg`` the
            per-node sum of edge weights grouped by ``row``.
        """
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        # Drop existing self loops, then add a fresh set sized by num_nodes.
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Weights for the appended loop edges go at the end.
        self_loop_value = 2 if improved else 1
        loop_weight = torch.full((num_nodes, ),
                                 self_loop_value,
                                 dtype=edge_weight.dtype,
                                 device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        row, col = edge_index
        degree = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        inv_sqrt_degree = degree.pow(-0.5)
        # Zero-degree entries become inf after pow(-0.5); neutralise them.
        inv_sqrt_degree[inv_sqrt_degree == float('inf')] = 0

        normalised = inv_sqrt_degree[row] * edge_weight * inv_sqrt_degree[col]
        return edge_index, normalised

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the linear transform, then propagate and apply dropout.

        When ``gfn`` is set, only ``x @ weight`` is returned. When ``cached``
        is set and a result is already stored, the stored
        ``(edge_index, norm)`` pair is used instead of the arguments.
        """
        x = torch.matmul(x, self.weight)
        if self.gfn:
            return x

        must_recompute = self.cached_result is None or not self.cached
        if must_recompute:
            norm = None
            if self.edge_norm:
                edge_index, norm = GCNConv.norm(
                    edge_index,
                    x.size(0),
                    edge_weight,
                    self.improved,
                    x.dtype)
            self.cached_result = (edge_index, norm)

        edge_index, norm = self.cached_result
        propagated = self.propagate(edge_index, x=x, norm=norm)
        return self.drop_out(propagated)

    def message(self, x_j, norm):
        """Return ``x_j`` scaled by ``norm`` when edge normalisation is on."""
        if not self.edge_norm:
            return x_j
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Add the bias, if one exists, to the aggregated output."""
        if self.bias is None:
            return aggr_out
        return aggr_out + self.bias

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels})'

class GCNNet(torch.nn.Module):
    """GCN with BN and residual connection.

    The network applies an input batch norm and a linear feature transform,
    then two residual stages of two graph convolutions each
    (``hidden -> 2*hidden`` and ``2*hidden -> 4*hidden``), a global pooling
    readout, ``num_fc_layers`` fully connected layers and a final classifier.

    Args:
        args: Accepted for interface compatibility; not read by this class.
        num_features: Size of the input node features.
        num_classes: Output size of the final classifier ``fc``.
        hidden: Base hidden width.
        num_feat_layers: Accepted but not read by this class.
        num_conv_layers: Accepted but not read by this class (the two
            stages are hard-wired with two convolutions each).
        num_fc_layers: Number of ``Linear(4*hidden, 4*hidden)`` layers
            applied after pooling.
        gfn: Forwarded to every convolution in the two residual stages.
        collapse: Accepted but not read by this class.
        residual: Accepted but not read by this class.
        res_branch: Accepted but not read by this class.
        global_pool: Key into ``pool_map`` ('max', 'mean' or 'sum').
        dropout: Dropout probability applied before every stage convolution.
        edge_norm: Forwarded to every convolution in the two residual stages.
    """
    def __init__(self,args, num_features,
                       num_classes, hidden, 
                       num_feat_layers=1, 
                       num_conv_layers=2,
                 num_fc_layers=1, gfn=False, collapse=False, residual=False,
                 res_branch="BNConvReLU", global_pool="sum", dropout=0, 
                 edge_norm=True):
        super(GCNNet, self).__init__()

        self.global_pool = pool_map[global_pool]
        self.dropout = nn.Dropout(dropout)
        GConv = partial(GCNConv, edge_norm=edge_norm, gfn=gfn)

        hidden_in = num_features
        self.bn_feat = BatchNorm1d(hidden_in)
        self.conv_feat = GCNConv(hidden_in, hidden, gfn=True) # linear transform
        # NOTE: bns_conv (and bns_fc below) are kept although they stay empty,
        # so the set of registered submodules is unchanged.
        self.bns_conv = torch.nn.ModuleList()
        self.convs_1 = torch.nn.ModuleList()
        self.convs_2 = torch.nn.ModuleList()

        # Stage 1: hidden -> 2*hidden -> 2*hidden.
        self.convs_1.append(GConv(hidden, hidden*2))
        self.convs_1.append(GConv(hidden*2, hidden*2))

        # Stage 2: 2*hidden -> 4*hidden -> 4*hidden.
        self.convs_2.append(GConv(hidden*2, hidden*4))
        self.convs_2.append(GConv(hidden*4, hidden*4))

        self.bns_fc = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()

        for _ in range(num_fc_layers):
            self.lins.append(Linear(hidden*4, hidden*4))

        self.fc = Linear(hidden*4, num_classes)
        # Projection heads used only by forward_bcl: `head` acts on pooled
        # graph features, `head_fc` acts on the rows of the classifier weight.
        self.head = nn.Sequential(nn.Linear(hidden*4, hidden), nn.BatchNorm1d(hidden), nn.ReLU(inplace=True), nn.Linear(hidden, hidden*4))
        self.head_fc = nn.Sequential(nn.Linear(hidden*4, hidden), nn.BatchNorm1d(hidden), nn.ReLU(inplace=True), nn.Linear(hidden, hidden*4))
        # BN initialization.
        for m in self.modules():
            if isinstance(m, (torch.nn.BatchNorm1d)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0.0001)

    def _residual_stage(self, x, convs, edge_index, edge_weight):
        """Run one stage: dropout + conv + ReLU per layer, then add the
        first layer's activation to the last layer's activation."""
        x_in = None
        for i, conv in enumerate(convs):
            x = self.dropout(x)
            x = F.relu(conv(x, edge_index, edge_weight))
            if i == 0:
                # Output of the first conv is the residual branch.
                x_in = x
        return x + x_in

    def _encode(self, x, edge_index, edge_weight, batch):
        """Shared encoder: input BN, feature transform, both residual stages,
        global pooling and the fully connected layers (before ``fc``)."""
        x = self.bn_feat(x.float())
        x = F.relu(self.conv_feat(x, edge_index, edge_weight))
        x = self._residual_stage(x, self.convs_1, edge_index, edge_weight)
        x = self._residual_stage(x, self.convs_2, edge_index, edge_weight)
        x = self.global_pool(x, batch)
        for lin in self.lins:
            x = F.relu(lin(x))
        return x

    def forward(self, x,edge_index, edge_weight,batch):
        """Return the output of the classifier ``fc`` on the encoded input.

        Args:
            x: Node features; cast to float before the input batch norm.
            edge_index: Edge tensor passed to every convolution.
            edge_weight: Edge weights passed to every convolution.
            batch: Passed to the global pooling function together with the
                node features.
        """
        x = self._encode(x, edge_index, edge_weight, batch)
        return self.fc(x)

    def forward_bcl(self, x,edge_index, edge_weight,batch):
        """Return classifier-weight embeddings, logits and feature embeddings.

        Takes the same arguments as :meth:`forward`.

        Returns:
            Tuple ``(centers_logits, logits, x_feat)`` where
            ``centers_logits`` is ``head_fc(fc.weight)`` normalised along
            dim 1, ``logits`` is ``fc`` applied to the encoded input, and
            ``x_feat`` is ``head`` applied to the encoded input, normalised
            along dim 1.
        """
        x = self._encode(x, edge_index, edge_weight, batch)

        # Keep the original evaluation order: head, then fc, then head_fc.
        x_feat = F.normalize(self.head(x), dim=1)
        x = self.fc(x)
        centers_logits = F.normalize(self.head_fc(self.fc.weight), dim=1)
        return centers_logits,x,x_feat


def model_gcn(args, num_features, num_classes):
    """Build a ``GCNNet`` with a base hidden width of 128.

    ``args``, ``num_features`` and ``num_classes`` are forwarded to the
    ``GCNNet`` constructor unchanged; all other options keep their defaults.
    """
    hidden_width = 128
    return GCNNet(args, num_features, num_classes, hidden=hidden_width)