# coding:utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv as GCNConv_ASAP
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import TopKPooling, SAGPooling, ASAPooling, PANPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_sparse import SparseTensor
from torch_geometric.utils import degree
from torch_sparse import SparseTensor


###########################################################
#             graph convolution structure                 #
###########################################################
# GIN convolution along the graph structure
class GINConv(MessagePassing):
    """GIN-style convolution with optional additive edge embeddings.

    Computes ``mlp((1 + eps) * x_i + sum_j relu(x_j [+ e_ij]))``, where the
    neighbour sum is an ``add`` aggregation and ``e_ij`` is a linear
    projection of the 2-dimensional ``edge_attr`` (used only when
    ``use_edge_attr`` is True).

    Args:
        emb_dim (int): node embedding dimensionality (input and output).
        use_edge_attr (bool): whether edge features are added to messages.
    """

    def __init__(self, emb_dim, use_edge_attr=True):
        super(GINConv, self).__init__(aggr="add")
        self.use_edge_attr = use_edge_attr
        hidden_dim = 2 * emb_dim
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, emb_dim),
        )
        # Learnable weight of the centre node, initialised to a small positive value.
        self.eps = torch.nn.Parameter(torch.Tensor([0.001]))
        # Always created (even if unused) so the state_dict layout is fixed.
        self.edge_encoder = torch.nn.Linear(2, emb_dim)

    def forward(self, x, edge_index, edge_attr=None):
        """Return the updated node embeddings, same shape as ``x``."""
        edge_embedding = self.edge_encoder(edge_attr) if self.use_edge_attr else None
        self_term = (1 + self.eps) * x
        neighbour_term = self.propagate(edge_index, x=x, edge_attr=edge_embedding)
        return self.mlp(self_term + neighbour_term)

    def message(self, x_j, edge_attr):
        # Parameter names are read by MessagePassing; keep them unchanged.
        return F.relu(x_j + edge_attr) if self.use_edge_attr else F.relu(x_j)

    def update(self, aggr_out):
        # Identity update: the aggregated messages are returned unchanged.
        return aggr_out


### GCN convolution along the graph structure
class GCNConv(MessagePassing):
    """GCN-style convolution with a learnable root embedding and optional edge features.

    The input is first transformed by a 2-layer MLP (``self.linear``). Each
    edge message is ``norm_ij * relu(x_j [+ e_ij])`` with symmetric
    normalisation ``norm_ij = 1 / sqrt(deg_i * deg_j)``. Here ``deg`` counts
    occurrences in ``edge_index[0]`` plus one for the node itself. Each node
    additionally receives ``relu(x_i + root_emb) / deg_i``.

    Args:
        emb_dim (int): node embedding dimensionality (input and output).
        use_edge_attr (bool): whether a 2-d ``edge_attr`` is encoded and added.
    """

    def __init__(self, emb_dim, use_edge_attr=True):
        super(GCNConv, self).__init__(aggr='add')

        self.use_edge_attr = use_edge_attr
        hidden_dim = 2 * emb_dim
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, emb_dim),
        )
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        if use_edge_attr:
            self.edge_encoder = torch.nn.Linear(2, emb_dim)

    def forward(self, x, edge_index, edge_attr=None):
        """Return updated node embeddings of shape ``(num_nodes, emb_dim)``."""
        x = self.linear(x)
        edge_embedding = self.edge_encoder(edge_attr) if self.use_edge_attr else None

        first = edge_index[0]
        second = edge_index[1]

        # +1 accounts for the implicit self-loop, so deg >= 1 everywhere.
        deg = degree(first, x.size(0), dtype=x.dtype) + 1
        inv_sqrt_deg = deg.pow(-0.5)
        inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
        norm = inv_sqrt_deg[first] * inv_sqrt_deg[second]

        neighbour_term = self.propagate(edge_index, x=x, edge_attr=edge_embedding, norm=norm)
        root_term = F.relu(x + self.root_emb.weight) * 1. / deg.view(-1, 1)
        return neighbour_term + root_term

    def message(self, x_j, edge_attr, norm):
        # Parameter names are read by MessagePassing; keep them unchanged.
        raw = x_j + edge_attr if self.use_edge_attr else x_j
        return norm.view(-1, 1) * F.relu(raw)

    def update(self, aggr_out):
        # Identity update: the aggregated messages are returned unchanged.
        return aggr_out


# PEM convolution
class PEMConv(MessagePassing):
    """Convolution using a symmetrically normalised sparse adjacency.

    With ``A[edge_index[1], edge_index[0]] = 1`` and ``D`` its row sums, it
    builds ``A_hat = D^-1/2 A D^-1/2`` (isolated rows get weight 0) and a
    per-node self weight ``w = rowsum(A_hat) + eps``. It returns
    ``mlp(w * x + A_hat @ x)``. When ``use_edge_attr`` is True, the summed
    messages ``relu(edge_embedding)`` over incoming edges are also added
    inside the MLP.

    Args:
        emb_dim (int): node embedding dimensionality (input and output).
        use_edge_attr (bool): whether a 2-d ``edge_attr`` contributes messages.
    """

    def __init__(self, emb_dim, use_edge_attr=True):
        super(PEMConv, self).__init__(aggr="add")
        self.use_edge_attr = use_edge_attr
        hidden_dim = 2 * emb_dim
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, emb_dim),
        )
        self.eps = torch.nn.Parameter(torch.Tensor([0.001]))
        # Always created (even if unused) so the state_dict layout is fixed.
        self.edge_encoder = torch.nn.Linear(2, emb_dim)

    def forward(self, x, edge_index, edge_attr=None):
        """Return updated node embeddings, same leading size as ``x``."""
        edge_embedding = self.edge_encoder(edge_attr) if self.use_edge_attr else None

        # The adjacency is indexed [edge_index[1], edge_index[0]], i.e. rows are
        # targets under PyG's default source_to_target convention.
        src, dst = edge_index
        num_nodes = x.shape[0]
        adj = SparseTensor(row=dst, col=src, sparse_sizes=(num_nodes, num_nodes))

        inv_sqrt_deg = adj.sum(dim=1).to(torch.float).pow(-0.5)
        inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0  # isolated nodes
        adj = inv_sqrt_deg.view(-1, 1) * adj * inv_sqrt_deg.view(1, -1)
        self_weight = adj.sum(dim=1).to(torch.float) + self.eps

        combined = self_weight.view(-1, 1) * x + adj @ x
        if self.use_edge_attr:
            combined = combined + self.propagate(edge_index, x=x, edge_attr=edge_embedding)
        return self.mlp(combined)

    def message(self, x_j, edge_attr):
        # Parameter names are read by MessagePassing; keep them unchanged.
        # NOTE: forward() only calls propagate when use_edge_attr is True, so
        # the x_j branch is not reached from forward().
        return F.relu(edge_attr) if self.use_edge_attr else F.relu(x_j)

    def update(self, aggr_out):
        # Identity update: the aggregated messages are returned unchanged.
        return aggr_out


###########################################################
#                      Permutation                        #
###########################################################
# Permutation Invariance Model
class NetPooling(torch.nn.Module):
    """Graph-level model: GNN convolutions interleaved with graph pooling.

    Each layer applies a convolution (``args.gnn`` in {'gin', 'gcn', 'pem'}),
    batch norm, ReLU (except after the last layer) and a pooling operator
    (``args.pool`` in {'sag', 'asap', 'pan', 'topk'}). After every layer the
    node embeddings are read out as ``[global_max_pool, global_mean_pool]``.
    The per-layer readouts are combined (``JK``, fixed to 'sum') and fed to
    two parallel MLP heads whose outputs are added. The result goes to
    ``max_seq_len`` independent linear heads of size ``num_vocab``.

    Args:
        args: configuration object; reads ``num_layer``, ``emb_dim``,
            ``num_vocab``, ``use_edge_attr``, ``max_seq_len``, ``pool`` and ``gnn``.
        node_encoder: module called as ``node_encoder(x, node_depth)`` that
            produces the initial node embeddings.

    Raises:
        ValueError: if ``args.gnn`` or ``args.pool`` is not a supported choice.
    """

    def __init__(self, args, node_encoder):
        super(NetPooling, self).__init__()
        self.num_layer = args.num_layer
        self.emb_dim = args.emb_dim
        self.num_tasks = args.num_vocab
        self.use_edge_attr = args.use_edge_attr
        self.max_seq_len = args.max_seq_len

        # How per-layer readouts are combined; only 'sum' is used currently.
        self.JK = 'sum'
        self.graph_pooling = args.pool
        self.node_encoder = node_encoder
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()

        conv_classes = {"gin": GINConv, "gcn": GCNConv, "pem": PEMConv}
        if args.gnn not in conv_classes:
            # Fail fast: otherwise self.convs stays empty and forward() dies
            # later with an IndexError.
            raise ValueError("Invalid gnn type.")
        conv_cls = conv_classes[args.gnn]

        for i in range(self.num_layer):
            # ASAP and PAN return edge *weights*, and forward() does not
            # re-index edge_attr for them. Only the first convolution may
            # therefore use edge_attr. This keeps the original behaviour of
            # switching self.use_edge_attr off for the later layers.
            if self.graph_pooling in ("asap", "pan") and i >= 1:
                self.use_edge_attr = False

            self.convs.append(conv_cls(emb_dim=self.emb_dim, use_edge_attr=self.use_edge_attr))
            self.batch_norms.append(torch.nn.BatchNorm1d(self.emb_dim))

        ### Pooling operator applied after every convolution layer
        pool_factories = {
            "sag": lambda: SAGPooling(self.emb_dim),
            "asap": lambda: ASAPooling(self.emb_dim, GNN=GCNConv_ASAP),
            "pan": lambda: PANPooling(self.emb_dim),
            "topk": lambda: TopKPooling(self.emb_dim),
        }
        if self.graph_pooling not in pool_factories:
            raise ValueError("Invalid graph pooling type.")
        make_pool = pool_factories[self.graph_pooling]
        for _ in range(self.num_layer):
            self.pools.append(make_pool())

        # Two parallel readout MLPs (same architecture); outputs are summed.
        self.lin_L = self._readout_mlp(self.emb_dim)
        self.lin_R = self._readout_mlp(self.emb_dim)
        # One linear head per output position i < max_seq_len.
        self.graph_pred_linear_list = torch.nn.ModuleList(
            nn.Sequential(nn.Linear(self.emb_dim, self.num_tasks))
            for _ in range(self.max_seq_len)
        )

    @staticmethod
    def _readout_mlp(emb_dim):
        """Build one readout head mapping ``2 * emb_dim`` features to ``emb_dim``."""
        return nn.Sequential(
            nn.Linear(emb_dim * 2, emb_dim // 2),
            nn.BatchNorm1d(emb_dim // 2),
            nn.ReLU(),
            # On a 2-D input PyTorch treats the tensor as unbatched (C, L) and
            # pools the feature axis from emb_dim // 2 down to emb_dim // 4.
            nn.AdaptiveAvgPool1d(emb_dim // 4),
            nn.Linear(emb_dim // 4, emb_dim // 16),
            nn.BatchNorm1d(emb_dim // 16),
            nn.ReLU(),
            nn.Linear(emb_dim // 16, emb_dim),
        )

    def forward(self, batched_data):
        """Return a list of ``max_seq_len`` logit tensors, one per output head."""
        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr
        node_depth = batched_data.node_depth
        batch = batched_data.batch

        edge_weight = None  # ASAPooling accepts None
        if self.graph_pooling == "pan":
            # Unit weights, created on the same device as the graph (the
            # previous hard-coded 'cuda' broke CPU execution).
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)

        h_list = [self.node_encoder(x, node_depth.view(-1,))]
        h_pools = []
        for layer in range(self.num_layer):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer != self.num_layer - 1:
                h = F.relu(h)

            if self.graph_pooling in ("sag", "topk"):
                h, edge_index, edge_attr, batch, _, _ = self.pools[layer](h, edge_index, edge_attr, batch)
            elif self.graph_pooling == "asap":
                h, edge_index, edge_weight, batch, _ = self.pools[layer](h, edge_index, edge_weight, batch)
            elif self.graph_pooling == "pan":
                # Size the adjacency by the real node count rather than
                # inferring it from the largest edge index.
                num_nodes = h.size(0)
                M = SparseTensor.from_edge_index(edge_index=edge_index, edge_attr=edge_weight,
                                                 sparse_sizes=(num_nodes, num_nodes))
                h, edge_index, edge_weight, batch, _, _ = self.pools[layer](h, M, batch)

            h_list.append(h)
            # Graph-level readout: [max-pool, mean-pool] -> (num_graphs, 2 * emb_dim).
            h_pools.append(torch.cat([gmp(h, batch), gap(h, batch)], dim=1))

        ### Different implementations of Jk-concat
        if self.JK == "last":
            # Must be the graph-level readout: lin_L/lin_R expect 2 * emb_dim features.
            h_graph = h_pools[-1]
        elif self.JK == "sum":
            h_graph = sum(h_pools)
        else:
            raise ValueError("Invalid JK mode.")

        h_graph = self.lin_L(h_graph) + self.lin_R(h_graph)
        return [head(h_graph) for head in self.graph_pred_linear_list]