from graph_learning.module import ModuleConfig, register_module
from graph_learning.config import config_dispatch

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import function as fn
from dgl.utils import expand_as_pair, check_eq_shape, dgl_warning

from .multi_layer_mp import CommonMultiLayerMPConfig

class SAGEConv(nn.Module):
    """GraphSAGE layer adapted from DGL's ``SAGEConv``.

    Differences from upstream DGL visible in this implementation:

    * ``weighted=True`` adds a learned projection (``fc_w``) of the scalar
      per-edge feature ``graph.edata['weight']``. For ``'mean'`` the projected
      mean edge weight is added to the neighbour term. For ``'pool'`` the
      projected weight is added to each projected source feature before the
      max reduction. ``'lstm'`` does not use edge weights; ``'gcn'`` rejects
      ``weighted=True``.
    * ``'pool'`` has no separate pooling MLP/ReLU. Source features are
      projected by ``fc_neigh`` and max-reduced directly.

    Parameters
    ----------
    in_feats : int or pair of int
        Input feature size. A pair gives (source, destination) sizes.
    out_feats : int
        Output feature size.
    aggregator_type : str
        One of ``'mean'``, ``'gcn'``, ``'pool'``, ``'lstm'``.
    weighted : bool
        Whether to use the edge feature ``'weight'`` (see above).
    feat_drop : float
        Dropout rate applied to the input features.
    bias : bool
        If True, add a learnable bias (initialised to zero) to the output.
    norm : callable, optional
        Applied to the output after ``activation``.
    activation : callable, optional
        Applied to the output before ``norm``.

    Raises
    ------
    KeyError
        If ``aggregator_type`` is not one of the supported values.
    NotImplementedError
        If ``weighted=True`` is combined with the ``'gcn'`` aggregator.
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 weighted=False,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        # Fail fast instead of erroring only at the first forward pass.
        if aggregator_type not in ('mean', 'gcn', 'pool', 'lstm'):
            raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
        if weighted and aggregator_type == 'gcn':
            raise NotImplementedError(
                "weighted=True is not supported with the 'gcn' aggregator.")

        self.weighted = weighted

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        # 'gcn' folds the node's own feature into the neighbour average,
        # so it has no separate self projection.
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=False)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
        if weighted:
            # Projects one scalar edge weight to the output width.
            self.fc_w = nn.Linear(1, out_feats, bias=False)
        if bias:
            self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
        else:
            self.register_buffer('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        r"""Reinitialize learnable parameters.

        Linear weights use Xavier (Glorot) uniform initialisation with the
        ReLU gain. The LSTM of the ``'lstm'`` aggregator is reset with its own
        ``nn.LSTM.reset_parameters``. The bias, if present, is reset to zero.
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
        if self.weighted:
            nn.init.xavier_uniform_(self.fc_w.weight, gain=gain)
        # getattr: models restored from old DGL versions may lack 'bias' until
        # _compatibility_check runs.
        bias = getattr(self, 'bias', None)
        if bias is not None:
            nn.init.zeros_(bias)

    def _compatibility_check(self):
        """Address the backward compatibility issue brought by #2747"""
        # Older checkpoints kept biases inside fc_neigh / fc_self. Merge them
        # into a single top-level bias and strip them from the linear layers.
        if not hasattr(self, 'bias'):
            dgl_warning("You are loading a GraphSAGE model trained from a old version of DGL, "
                        "DGL automatically convert it to be compatible with latest version.")
            bias = self.fc_neigh.bias
            self.fc_neigh.bias = None
            if hasattr(self, 'fc_self'):
                if bias is not None:
                    bias = bias + self.fc_self.bias
                    self.fc_self.bias = None
            self.bias = bias

    def _lstm_reducer(self, nodes):
        """Reduce each node's incoming messages with the LSTM.

        Returns the LSTM's final hidden state as ``'neigh'``.
        """
        m = nodes.mailbox['m']  # (B, L, D): B nodes, L messages each (DGL degree bucket)
        batch_size = m.shape[0]
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat):
        """Compute the layer output.

        Parameters
        ----------
        graph : DGLGraph
            Graph or block to convolve over. When ``weighted`` is set,
            ``graph.edata['weight']`` must hold one scalar per edge.
        feat : Tensor or pair of Tensors
            Node features. A pair gives (source, destination) features. For a
            single tensor on a block, the destination features are its first
            ``graph.number_of_dst_nodes()`` rows.

        Returns
        -------
        Tensor
            Output features with ``out_feats`` columns, one row per
            destination node.

        Raises
        ------
        KeyError
            If the aggregator type is not recognised.
        """
        self._compatibility_check()
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            # copy_u supersedes the deprecated copy_src (already used below).
            msg_fn = fn.copy_u('h', 'm')

            h_self = feat_dst

            # Graphs without edges: pre-fill the reduction outputs so they exist
            # even if no reducer runs (e.g. the LSTM UDF). The width must match
            # what the chosen branch reduces. 'pool' reduces features already
            # projected to out_feats; the others reduce raw source features.
            if graph.number_of_edges() == 0:
                num_dst = feat_dst.shape[0]
                neigh_dim = (self._out_feats if self._aggre_type == 'pool'
                             else self._in_src_feats)
                graph.dstdata['neigh'] = feat_dst.new_zeros(num_dst, neigh_dim)
                if self.weighted:
                    graph.dstdata['nw'] = feat_dst.new_zeros(num_dst)

            # Message Passing
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(msg_fn, fn.mean('m', 'neigh'))
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
                if self.weighted:
                    graph.update_all(fn.copy_e('weight', 'm'),
                                     fn.mean('m', 'nw'))
                    h_neigh = h_neigh + self.fc_w(graph.dstdata['nw'].unsqueeze(1))
            elif self._aggre_type == 'gcn':
                # Average over neighbours plus the node itself:
                # (sum of neighbour features + own feature) / (in-degree + 1).
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst
                graph.update_all(msg_fn, fn.sum('m', 'neigh'))
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
                h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'pool':
                # No separate pooling MLP: project with fc_neigh, then max-reduce.
                graph.srcdata['h'] = self.fc_neigh(feat_src)
                if self.weighted:
                    graph.apply_edges(fn.copy_u('h', 'hn'))
                    graph.edata['hw'] = (self.fc_w(graph.edata['weight'].unsqueeze(1))
                                         + graph.edata['hn'])
                    graph.update_all(fn.copy_e('hw', 'm'), fn.max('m', 'neigh'))
                else:
                    graph.update_all(msg_fn, fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'lstm':
                graph.srcdata['h'] = feat_src
                graph.update_all(msg_fn, self._lstm_reducer)
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

            if self._aggre_type == 'gcn':
                # The self feature is already part of the gcn average.
                rst = h_neigh
            else:
                rst = self.fc_self(h_self) + h_neigh

            # bias term
            if self.bias is not None:
                rst = rst + self.bias

            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)

            return rst

@ModuleConfig.register('sage')
class GraphSageModuleConfig(CommonMultiLayerMPConfig):
    """Multi-layer message-passing config whose layers are ``SAGEConv``.

    Reads ``self.aggregator_type`` and ``self.weighted`` when building layers.
    These are presumably populated from the CLI options declared in
    ``define_parser``; verify against the config framework.
    """

    def _layer_builder(self, in_size, out_size):
        """Build one ``SAGEConv`` layer mapping ``in_size`` to ``out_size`` features."""
        layer_kwargs = dict(aggregator_type=self.aggregator_type,
                            weighted=self.weighted)
        return SAGEConv(in_size, out_size, **layer_kwargs)

    @classmethod
    def define_parser(cls, parser):
        """Extend the base parser with the GraphSAGE-specific options."""
        super().define_parser(parser)
        aggregators = ['mean', 'gcn', 'pool', 'lstm']
        parser.add_argument('--aggregator-type', choices=aggregators)
        parser.add_argument('--weighted', action='store_true')
