from functools import partial
import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import zeros,glorot
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch.nn import Linear, BatchNorm1d, Sequential, ReLU
from torch_geometric.nn import global_mean_pool, global_add_pool, GINConv, GATConv
from torch_geometric.data import Batch,Data
from torch.nn import Parameter
import random


def graph_pooling(data, num_vertices):
    """Average features over axis 1 using an explicit per-row divisor.

    The tensor is summed along dimension 1, and each resulting row is then
    divided by its entry in ``num_vertices``. For a padded
    ``(batch, nodes, features)`` tensor whose padding rows are zero, this is
    the mean over the real nodes of each graph.

    Args:
        data: tensor reduced along dimension 1. Assumed
            ``(batch, nodes, features)`` — TODO confirm with callers.
        num_vertices: 1-D tensor with one divisor per row of the reduced
            tensor. It must be expandable to the reduced shape after a
            trailing unsqueeze.

    Returns:
        Tensor shaped like ``data.sum(1)``. A divisor of 0 gives nan/inf
        in that row.
    """
    totals = torch.sum(data, dim=1)
    divisors = num_vertices.unsqueeze(-1).expand_as(totals)
    return totals / divisors


class GCNConv(MessagePassing):
    """Graph convolution with additive aggregation and optional edge normalisation.

    Args:
        in_channels: width of each input node feature vector.
        out_channels: width of each output node feature vector.
        improved: when True, the self-loops added in ``norm`` get weight 2
            instead of 1.
        cached: when True, the ``(edge_index, norm)`` pair computed on the
            first ``forward`` call is stored and reused. The ``edge_index`` /
            ``edge_weight`` arguments of later calls are then ignored.
        bias: when True, a learnable per-channel bias is added in ``update``.
        edge_norm: when True, existing self-loops are replaced by one
            self-loop per node and each message is scaled by
            ``deg^-1/2[row] * w * deg^-1/2[col]``. When False, messages are
            passed unscaled over the given edges and no self-loops are added.
        gfn: when True, ``forward`` only returns ``x @ weight``. No message
            passing is done and no bias is applied.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 improved=False,
                 cached=False,
                 bias=True,
                 edge_norm=True,
                 gfn=False):
        super().__init__(aggr='add')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.cached_result = None  # (edge_index, norm) once computed
        self.edge_norm = edge_norm
        self.gfn = gfn
        self.message_mask = None  # set but not read inside this class

        # Parameters are registered in the order weight -> bias.
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-initialise the weight, zero the bias and drop any cached norm."""
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        """Add self-loops and compute symmetric per-edge normalisation factors.

        Args:
            edge_index: ``(2, E)`` edge index tensor.
            num_nodes: number of nodes, used for the self-loops and the
                degree vector.
            edge_weight: optional per-edge weights. ``None`` means all ones.
            improved: self-loop weight is 2 instead of 1 when True.
            dtype: dtype of the all-ones weights created when
                ``edge_weight`` is ``None``.

        Returns:
            ``(edge_index, norm)``. The edge index has exactly one self-loop
            per node appended. ``norm`` holds
            ``deg^-1/2[row] * weight * deg^-1/2[col]`` per edge, with
            zero-degree nodes contributing 0 instead of inf.
        """
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),),
                                     dtype=dtype,
                                     device=edge_index.device)

        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        # Replace whatever self-loops exist with exactly one per node.
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        self_loop_value = 2 if improved else 1
        loop_weight = torch.full((num_nodes,),
                                 self_loop_value,
                                 dtype=edge_weight.dtype,
                                 device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        src, dst = edge_index
        degree = scatter_add(edge_weight, src, dim=0, dim_size=num_nodes)
        inv_sqrt_degree = degree.pow(-0.5)
        # Isolated nodes have degree 0 -> inf; zero them so they send nothing.
        inv_sqrt_degree[inv_sqrt_degree == float('inf')] = 0

        scaled = inv_sqrt_degree[src] * edge_weight * inv_sqrt_degree[dst]
        return edge_index, scaled

    def forward(self, x, edge_index, edge_weight=None):
        """Project node features and, unless ``gfn`` is set, propagate them.

        Args:
            x: node feature matrix, multiplied on the right by ``weight``.
            edge_index: ``(2, E)`` edge index. It is ignored after the first
                call when ``cached`` is True.
            edge_weight: optional per-edge weights, used only when
                ``edge_norm`` is True.

        Returns:
            Updated node features.
        """
        x = x @ self.weight
        if self.gfn:
            # Feature-only mode: no neighbourhood aggregation, no bias.
            return x

        if self.cached_result is None or not self.cached:
            norm = None
            if self.edge_norm:
                edge_index, norm = GCNConv.norm(
                    edge_index,
                    x.size(0),
                    edge_weight,
                    self.improved,
                    x.dtype)
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        """Scale each neighbour feature by its edge factor when normalising."""
        if not self.edge_norm:
            return x_j
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Add the bias (if any) to the aggregated messages."""
        return aggr_out if self.bias is None else aggr_out + self.bias

    def __repr__(self):
        name = self.__class__.__name__
        return '{}({}, {})'.format(name, self.in_channels, self.out_channels)

class PlainGCN(torch.nn.Module):
    """GCN that maps each input graph to one scalar prediction.

    Pipeline used by ``forward``:
    1. batch-normalise the node features;
    2. apply a feature-only ``GCNConv`` (``gfn=True``) projection to
       ``hidden`` channels;
    3. run ``num_conv_layers`` ``GCNConv`` layers, each followed by ReLU;
    4. average node features per graph;
    5. apply ``fc1 -> dropout -> fc2`` and flatten.

    NOTE(review): ``global_pool``, ``dropout``, ``fc``, ``bns_conv``,
    ``bn_hidden``, ``bns_fc``, ``lins`` and ``lin_class`` are constructed but
    not used by ``forward``, and ``num_classes`` only sizes ``lin_class``.
    They are kept so that state_dict keys, parameter order and RNG-dependent
    initialisation stay compatible with existing checkpoints.
    """

    def __init__(self, num_features,
                 hidden, num_classes=1,
                 num_conv_layers=3,
                 num_fc_layers=2, gfn=False, dropout=0, fc_dropout=0.1, p_dim=80,
                 edge_norm=True):
        super(PlainGCN, self).__init__()

        self.global_pool = global_add_pool  # not used by forward
        self.dropout = dropout  # not used by forward

        # Factory for the hidden convolutions. If gfn=True they skip message
        # passing entirely (see GCNConv.forward).
        GConv = partial(GCNConv, edge_norm=edge_norm, gfn=gfn)

        hidden_in = num_features
        self.p_dim = p_dim
        self.bn_feat = BatchNorm1d(hidden_in)
        # Input projection only: gfn=True makes it return x @ weight.
        self.conv_feat = GCNConv(hidden_in, hidden, gfn=True)
        self.bns_conv = torch.nn.ModuleList()
        self.convs = torch.nn.ModuleList()
        self.fc1 = nn.Linear(hidden, p_dim, bias=False)
        self.fc2 = nn.Linear(p_dim, 1, bias=False)
        self.fc_dropout = nn.Dropout(fc_dropout)

        self.fc = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.PReLU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.PReLU(),
            nn.Linear(hidden, 1)
        )

        for m in self.fc:
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

        # Module construction order is intentionally unchanged: it determines
        # random initialisation and parameter ordering.
        for i in range(num_conv_layers):
            self.bns_conv.append(BatchNorm1d(hidden))
            self.convs.append(GConv(hidden, hidden))
        self.bn_hidden = BatchNorm1d(hidden)
        self.bns_fc = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()

        for i in range(num_fc_layers - 1):
            self.bns_fc.append(BatchNorm1d(hidden))
            self.lins.append(Linear(hidden, hidden))
        self.lin_class = Linear(hidden, num_classes)

        # BN initialization.
        for m in self.modules():
            if isinstance(m, (torch.nn.BatchNorm1d)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0.0001)

    def to_pyg(self, xxs, edge_index_list, num_vertice):
        """Pack per-graph tensors into a single PyG ``Batch``.

        Args:
            xxs: indexable of per-graph node-feature tensors. Only the first
                ``n`` rows of each are kept, so it may be padded.
            edge_index_list: one edge index per graph.
            num_vertice: node count ``n`` per graph.

        Returns:
            ``Batch`` whose nodes are concatenated in graph order. Its
            ``batch`` vector maps each node to its graph position.
        """
        assert xxs.shape[0] == len(edge_index_list)
        assert xxs.shape[0] == len(num_vertice)

        data_list = [Data(x=xx[:n], edge_index=e)
                     for xx, e, n in zip(xxs, edge_index_list, num_vertice)]
        return Batch.from_data_list(data_list)

    def forward(self, data):
        """Predict one scalar per graph.

        Args:
            data: mapping with keys ``"operations"`` (per-graph node features,
                possibly padded), ``"edge_index_list"`` (one edge index per
                graph) and ``"num_vertices"`` (1-D tensor of node counts).

        Returns:
            1-D tensor with one value per graph.
        """
        pyg_batch = self.to_pyg(data["operations"], data["edge_index_list"], data["num_vertices"])
        num_vertices = data["num_vertices"]
        batch_size = num_vertices.shape[0]

        x = pyg_batch.x if pyg_batch.x is not None else pyg_batch.feat
        edge_index, batch = pyg_batch.edge_index, pyg_batch.batch
        x = x.to(torch.float32)
        x = self.bn_feat(x)
        x = F.relu(self.conv_feat(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))

        # The node tensor is flat, and graphs may have different node counts.
        # Reshaping to (batch_size, -1, hidden) would therefore misassign
        # nodes across graphs, or fail outright. Sum nodes per graph with the
        # batch assignment vector instead, then divide by the vertex counts.
        # This matches graph_pooling when all graphs have the same size.
        batch = batch.to(x.device)  # some PyG versions build `batch` on CPU
        x = scatter_add(x, batch, dim=0, dim_size=batch_size)
        x = x / num_vertices.to(device=x.device, dtype=x.dtype).unsqueeze(-1)

        x = self.fc1(x)
        x = self.fc_dropout(x)
        x = self.fc2(x).view(-1)
        return x


