import dgl
import dgl.function as fn
from dgl.nn.pytorch import SAGEConv

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import accumulate

from graph_learning.module import ModuleConfig, get_module

@ModuleConfig.register('agnn', help='[Encoder] AGNN')
class AGNNModuleConfig(ModuleConfig):
    """Configuration entry for the AGNN encoder.

    Registered with ``ModuleConfig`` under the key ``'agnn'``; its
    ``builder`` property hands back the ``AGNN`` class.
    """

    def __init__(self, args, context):
        super().__init__(args, context)

    @property
    def builder(self):
        """Return the class that this configuration builds."""
        return AGNN

    @classmethod
    def define_parser(cls, parser):
        """Add the AGNN command-line options on top of the base options."""
        super().define_parser(parser)

        # Options that take a single typed value, in registration order.
        # NOTE(review): AGNN.__init__ has no ``anchor_size`` parameter --
        # confirm how ``--anchor-size`` is consumed by the base class.
        typed_options = (
            ('--in-size', int),
            ('--hidden-size', int),
            ('--num-layers', int),
            ('--anchor-size', int),
            ('--dropout', float),
        )
        for flag, value_type in typed_options:
            parser.add_argument(flag, type=value_type)

        parser.add_argument('--bn', action='store_true',
                            help='use batchnorm')

def get_anchor_graph(g, anchors):
    """Build a graph with an edge from every anchor to every node of ``g``.

    The result has as many nodes as ``g`` and lives on ``anchors.device``.
    Edges are laid out anchor-major: first all edges leaving
    ``anchors[0]`` (to nodes ``0 .. n-1`` in order), then all edges
    leaving ``anchors[1]``, and so on. Each anchor is also connected to
    itself.

    Args:
        g: graph exposing ``number_of_nodes()``.
        anchors: 1-D tensor of anchor node ids (assumed to be valid node
            ids of ``g`` -- TODO confirm against caller).

    Returns:
        The newly created DGL graph.
    """
    device = anchors.device
    num_nodes = g.number_of_nodes()
    num_anchors = anchors.size(0)

    # Source side: each anchor id repeated once per destination node.
    src_ids = anchors.repeat_interleave(num_nodes)
    # Destination side: the full node range, tiled once per anchor.
    dst_ids = torch.arange(num_nodes, device=device).repeat(num_anchors)

    result = dgl.DGLGraph().to(device)
    result.add_nodes(num_nodes)
    result.add_edges(src_ids, dst_ids)
    return result

class AGNNLayer(nn.Module):
    """One message-passing layer over an anchor graph.

    Every edge produces a message from the distance-scaled concatenation
    of its endpoint features. Each node then receives the mean of its
    incoming messages (``h``) and one scalar score per incoming message
    (``z``).
    """

    def __init__(self, in_size, out_size, actication):
        """
        Args:
            in_size: feature size of the node inputs.
            out_size: size of the message / hidden representation.
            actication: callable applied to the messages after the linear
                map (parameter name kept as-is for caller compatibility).
        """
        super().__init__()
        self.activation = actication
        # Messages are built from [src ; dst] features, hence 2 * in_size.
        self.fc_hidden = nn.Linear(in_size * 2, out_size)
        # Projects every message to a single scalar.
        self.pos_out = nn.Linear(out_size, 1)

    def message_func(self, edges):
        """Compute the per-edge message ``m``."""
        endpoint_feats = torch.cat((edges.src['x'], edges.dst['x']), dim=-1)
        scaled = edges.data['dist'] * endpoint_feats
        return {'m': self.activation(self.fc_hidden(scaled))}

    def reduce_func(self, nodes):
        """Aggregate the mailbox into ``h`` (mean) and ``z`` (scores)."""
        # mailbox layout: (nodes in batch, incoming messages, hidden)
        messages = nodes.mailbox['m']
        pooled = messages.mean(dim=1)
        scores = self.pos_out(messages).squeeze(-1)
        return {'h': pooled, 'z': scores}

    def forward(self, graph, dists, x):
        """Run one round of message passing.

        Args:
            graph: DGL graph to propagate on; a local view is used so the
                caller's feature storage is not overwritten.
            dists: per-edge values stored as edge feature ``'dist'``.
            x: per-node features stored as node feature ``'x'``.

        Returns:
            Tuple ``(z, h)``: the per-message scalar scores and the
            mean-aggregated messages of every node.
        """
        scoped = graph.local_var()
        scoped.edata['dist'] = dists
        scoped.ndata['x'] = x

        scoped.update_all(self.message_func, self.reduce_func)

        return scoped.ndata['z'], scoped.ndata['h']

class AGNN(nn.Module):
    """Stack of ``AGNNLayer`` modules applied over a precomputed anchor graph.

    Each layer yields a pair ``(h_pos, h)``. ``h`` (optionally
    batch-normalised, then passed through dropout) feeds the next layer;
    the ``h_pos`` of the last layer is the encoder output.
    """

    def __init__(self, in_size, hidden_size, bn,
                 num_layers, dropout, name):
        """
        Args:
            in_size: input feature size of the first layer.
            hidden_size: output size of every layer.
            bn: if true, apply ``BatchNorm1d`` to ``h`` after every layer.
            num_layers: number of stacked layers; must be at least 1.
            dropout: dropout probability applied to ``h`` after every layer.
            name: identifier stored on the module as ``self.name``.

        Raises:
            ValueError: if ``num_layers`` is smaller than 1.
        """
        # With no layers, forward() would never assign the value it
        # returns; reject that configuration up front with a clear error.
        if num_layers < 1:
            raise ValueError(
                f'AGNN requires num_layers >= 1, got {num_layers}')

        super().__init__()
        self.use_bn = bn
        self.name = name
        self.num_layers = num_layers

        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

        # bn_layers only exists when batch norm is enabled.
        if self.use_bn:
            self.bn_layers = nn.ModuleList()
        self.agnn_layers = nn.ModuleList()

        for i in range(num_layers):
            # Only the first layer sees the raw input size.
            conv_in_size = in_size if i == 0 else hidden_size
            conv_out_size = hidden_size
            self.agnn_layers.append(AGNNLayer(
                conv_in_size, conv_out_size, self.activation))
            if self.use_bn:
                self.bn_layers.append(nn.BatchNorm1d(conv_out_size))

    def forward(self, data, x):
        """Encode node features ``x``.

        Args:
            data: object whose ``graph()`` returns the project graph
                wrapper; its ``gdata`` must hold ``'_anchor_graph'`` and
                ``'_pgnn_dist'``.
            x: input node features.

        Returns:
            Dict with ``'hidden'`` set to the last layer's ``h_pos`` and
            an empty ``'outputs'`` dict.
        """
        graph = data.graph()
        graph = graph.adapt(graph.local_var())

        anchor_graph = graph.gdata['_anchor_graph']
        dists = graph.gdata['_pgnn_dist']
        # Look up one distance per anchor-graph edge via its (src, dst)
        # ids, then add a trailing feature axis so it multiplies the edge
        # features. Presumably dists is a node-by-node matrix -- TODO confirm.
        dists_e = dists[anchor_graph.edges()].unsqueeze(-1).float()

        h = x
        for i, layer in enumerate(self.agnn_layers):
            h_pos, h = layer(anchor_graph, dists_e, h)
            if self.use_bn:
                h = self.bn_layers[i](h)
            h = self.dropout(h)

        # NOTE: only h_pos of the final layer is returned; the final h is
        # still normalised / dropped out above but not used afterwards.
        return {'hidden': h_pos,
                'outputs': {}}
