# %load_ext autoreload
# %autoreload 2
# %matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GraphConv, SAGEConv

# +
from dq.quantization import IntegerQuantizer
from dq.linear_quantized import LinearQuantized

def create_quantizer(qypte, ste, momentum, percentile, signed, sample_prop):
    """Return a zero-argument factory that builds one quantizer module.

    Parameters
    ----------
    qypte : str
        Quantization type. ``"FP32"`` disables quantization, ``"INT4"`` selects
        4 bits, and any other value selects 8 bits.
    ste, momentum, percentile, signed, sample_prop
        Forwarded to ``IntegerQuantizer`` as ``use_ste``, ``use_momentum``,
        ``percentile``, ``signed`` and ``sample`` respectively.

    Returns
    -------
    callable
        ``nn.Identity`` (the class itself) for ``"FP32"``; otherwise a lambda
        that creates a fresh ``IntegerQuantizer`` on every call.
    """
    if qypte == "FP32":
        # The bare name `Identity` was never imported in this file; use the
        # torch.nn pass-through module via the existing `nn` alias.
        return nn.Identity

    num_bits = 4 if qypte == "INT4" else 8
    return lambda: IntegerQuantizer(
        num_bits,
        signed=signed,
        use_ste=ste,
        use_momentum=momentum,
        percentile=percentile,
        sample=sample_prop,
    )


# +
from dgl import function as fn
from dgl.utils import expand_as_pair, check_eq_shape, dgl_warning

class SAGEDQConv(nn.Module):
    """GraphSAGE convolution whose linear projections are quantized.

    This follows the SAGEConv layer from DGL, with the ``nn.Linear`` self and
    neighbour projections (``fc_self`` / ``fc_neigh``) replaced by
    ``LinearQuantized`` layers.

    Parameters
    ----------
    in_feats : int or pair of int
        Input feature size, expanded into (source, destination) sizes with
        ``expand_as_pair``.
    out_feats : int
        Output feature size.
    aggregator_type : str
        One of ``'mean'``, ``'gcn'``, ``'pool'`` or ``'lstm'``. Any other value
        is only rejected (``KeyError``) when ``forward`` runs.
    feat_drop : float
        Dropout probability applied to the input features.
    bias : bool
        If True, a learnable bias initialised to zeros is added to the output.
    norm : callable, optional
        Applied to the output after the activation.
    activation : callable, optional
        Applied to the output after the bias.
    layer_quantizers : optional
        Passed unchanged as the third argument of every ``LinearQuantized``.
        NOTE(review): presumably the dict of quantizer factories produced by
        ``create_quantizer``; confirm that ``LinearQuantized`` accepts the
        default ``None``.
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None,
                 layer_quantizers=None):
        super().__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # Aggregator-specific modules (aggregator type: mean/pool/lstm/gcn).
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)

        # Quantized projections. They replace
        #   nn.Linear(self._in_dst_feats, out_feats, bias=False)  (fc_self)
        #   nn.Linear(self._in_src_feats, out_feats, bias=False)  (fc_neigh)
        # NOTE(review): the replaced layers had no bias of their own; the bias
        # default of LinearQuantized is not visible here -- confirm it does not
        # add a second bias on top of `self.bias`.
        if aggregator_type != 'gcn':
            self.fc_self = LinearQuantized(self._in_dst_feats, out_feats, layer_quantizers)
        self.fc_neigh = LinearQuantized(self._in_src_feats, out_feats, layer_quantizers)

        if bias:
            self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
        else:
            # Registered as an empty buffer so that `self.bias` always exists.
            self.register_buffer('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform
        initialization with the ReLU gain. The LSTM module, when present, is
        reinitialized through its own ``reset_parameters``.
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _compatibility_check(self):
        """Address the backward compatibility issue brought by #2747.

        When the module has no ``bias`` attribute, the biases stored on
        ``fc_neigh`` (and ``fc_self`` if present) are summed into ``self.bias``
        and cleared on the linear layers.
        """
        if not hasattr(self, 'bias'):
            dgl_warning("You are loading a GraphSAGE model trained from a old version of DGL, "
                        "DGL automatically convert it to be compatible with latest version.")
            bias = self.fc_neigh.bias
            self.fc_neigh.bias = None
            if hasattr(self, 'fc_self'):
                if bias is not None:
                    bias = bias + self.fc_self.bias
                    self.fc_self.bias = None
            self.bias = bias

    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m'] # (B, L, D)
        batch_size = m.shape[0]
        # Zero initial hidden and cell states for the single-layer LSTM.
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        # The final hidden state is the aggregated neighbour representation.
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat, edge_weight=None):
        r"""

        Description
        -----------
        Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        edge_weight : torch.Tensor, optional
            Optional tensor on the edge. If given, the convolution will weight
            with regard to the message.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N_{dst}, D_{out})`
            where :math:`N_{dst}` is the number of destination nodes in the input graph,
            :math:`D_{out}` is size of output feature.

        Raises
        ------
        KeyError
            If the aggregator type given at construction is not recognized.
        """
        self._compatibility_check()
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            # `copy_u` is the current name of the deprecated `copy_src` alias.
            msg_fn = fn.copy_u('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                msg_fn = fn.u_mul_e('h', '_edge_weight', 'm')

            h_self = feat_dst

            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats).to(feat_dst)

            # Determine whether to apply linear transformation before message passing A(XW)
            lin_before_mp = self._in_src_feats > self._out_feats

            # Message Passing
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
                graph.update_all(msg_fn, fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
                if isinstance(feat, tuple):  # heterogeneous
                    graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
                else:
                    if graph.is_block:
                        graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()]
                    else:
                        graph.dstdata['h'] = graph.srcdata['h']
                graph.update_all(msg_fn, fn.sum('m', 'neigh'))
                # divide in_degrees (+1 accounts for the node's own feature)
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(msg_fn, fn.max('m', 'neigh'))
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            elif self._aggre_type == 'lstm':
                graph.srcdata['h'] = feat_src
                graph.update_all(msg_fn, self._lstm_reducer)
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == 'gcn':
                rst = h_neigh
            else:
                rst = self.fc_self(h_self) + h_neigh

            # bias term
            if self.bias is not None:
                rst = rst + self.bias

            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst

# -

'''
Adapted from the SAGE implementation in the official DGL example:
https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-products/graphsage/main.py
'''
class SAGEDQ(nn.Module):
    """Stack of quantized GraphSAGE layers applied to a sequence of blocks.

    Every layer is a ``SAGEDQConv`` with the ``'gcn'`` aggregator, and all
    layers share one ``layer_quantizers`` dict of quantizer factories.

    Parameters
    ----------
    num_layers : int
        Number of convolution layers. With 1, a single layer maps
        ``input_dim`` directly to ``output_dim``.
    input_dim, hidden_dim, output_dim : int
        Feature sizes of the input, hidden and output representations.
    dropout_ratio : float
        Dropout probability applied after every layer except the last.
    activation : callable
        Applied after every layer except the last.
    norm_type : str
        Accepted for interface compatibility; not used by this class.
    qtype : str
        Quantization type forwarded to ``create_quantizer`` (``"FP32"``,
        ``"INT4"``, otherwise 8-bit). Defaults to the previously hard-coded
        ``"INT4"``.
    ste, momentum, percentile, sample_prop
        Forwarded to ``create_quantizer``. The defaults reproduce the
        previously hard-coded configuration. The disabled argument handling
        this replaces used a percentile of 0.01 for INT4 and 0.001 otherwise.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, dropout_ratio,
                 activation, norm_type='none', qtype='INT4', ste=True, momentum=True,
                 percentile=0.01, sample_prop=None):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.layers = nn.ModuleList()

        # Inputs use an unsigned quantizer; weights and features a signed one.
        layer_quantizers = {
            "inputs": create_quantizer(
                qtype, ste, momentum, percentile, False, sample_prop
            ),
            "weights": create_quantizer(
                qtype, ste, momentum, percentile, True, sample_prop
            ),
            "features": create_quantizer(
                qtype, ste, momentum, percentile, True, sample_prop
            ),
        }

        if num_layers == 1:
            self.layers.append(SAGEDQConv(input_dim, output_dim, 'gcn', layer_quantizers=layer_quantizers))
        else:
            self.layers.append(SAGEDQConv(input_dim, hidden_dim, 'gcn', layer_quantizers=layer_quantizers))
            for _ in range(num_layers - 2):
                self.layers.append(SAGEDQConv(hidden_dim, hidden_dim, 'gcn', layer_quantizers=layer_quantizers))
            self.layers.append(SAGEDQConv(hidden_dim, output_dim, 'gcn', layer_quantizers=layer_quantizers))

        self.dropout = nn.Dropout(dropout_ratio)
        self.activation = activation

    def forward(self, blocks, x):
        """Run the layers over ``blocks``, starting from node features ``x``.

        Layers and blocks are paired with ``zip``, so iteration stops at the
        shorter of the two. Returns the output of the last layer applied.
        """
        h = x
        last_idx = len(self.layers) - 1
        for idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            # We need to first copy the representation of nodes on the RHS from the
            # appropriate nodes on the LHS.
            # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
            # would be (num_nodes_RHS, D)
            h_dst = h[:block.num_dst_nodes()]
            # Then we compute the updated representation on the RHS.
            # The shape of h now becomes (num_nodes_RHS, D)
            h = layer(block, (h, h_dst))
            if idx != last_idx:
                h = self.activation(h)
                h = self.dropout(h)
        return h


class ModelDQ(nn.Module):
    """Wrapper that selects and builds one of the quantized GNN/MLP encoders.

    Parameters
    ----------
    conf : dict
        Configuration. ``conf['model_name']`` selects the encoder: a name
        containing ``'SAGEDQ'`` builds ``SAGEDQ``, otherwise a name containing
        ``'MLPDQ'`` builds ``MLPDQ``. The keys ``num_layers``, ``feat_dim``,
        ``hidden_dim``, ``label_dim``, ``dropout_ratio``, ``norm_type`` and
        ``device`` are read when building the encoder.

    Raises
    ------
    ValueError
        If ``conf['model_name']`` matches neither encoder.
    """

    def __init__(self, conf):
        super().__init__()
        self.model_name = conf['model_name']
        if 'SAGEDQ' in conf['model_name']:
            self.encoder = SAGEDQ(
                num_layers=conf['num_layers'],
                input_dim=conf['feat_dim'],
                hidden_dim=conf['hidden_dim'],
                output_dim=conf['label_dim'],
                dropout_ratio=conf['dropout_ratio'],
                activation=F.relu,
                norm_type=conf['norm_type']).to(conf['device'])
        elif 'MLPDQ' in conf['model_name']:
            self.encoder = MLPDQ(
                num_layers=conf['num_layers'],
                input_dim=conf['feat_dim'],
                hidden_dim=conf['hidden_dim'],
                output_dim=conf['label_dim'],
                dropout_ratio=conf['dropout_ratio'],
                norm_type=conf['norm_type']).to(conf['device'])
        else:
            # Fail at construction instead of leaving `self.encoder` unset.
            raise ValueError(
                'Unsupported model_name {!r}; expected a name containing '
                "'SAGEDQ' or 'MLPDQ'.".format(conf['model_name']))

    def forward(self, data, feats):
        """Run the wrapped encoder.

        ``data`` is passed to the encoder only for ``SAGEDQ`` models. The
        ``MLPDQ`` encoder's ``forward`` accepts features only, so ``data`` is
        ignored for it.
        """
        # Same precedence as in __init__: 'SAGEDQ' is checked first.
        if 'SAGEDQ' in self.model_name:
            return self.encoder(data, feats)
        return self.encoder(feats)



class MLPDQ(nn.Module):
    """Multi-layer perceptron built from ``LinearQuantized`` layers.

    Parameters
    ----------
    num_layers : int
        Number of linear layers. With 1, a single layer maps ``input_dim``
        directly to ``output_dim`` and no normalization layer is created.
    input_dim, hidden_dim, output_dim : int
        Feature sizes of the input, hidden and output representations.
    dropout_ratio : float
        Dropout probability applied after every layer except the last.
    norm_type : str
        ``'batch'`` adds ``nn.BatchNorm1d`` and ``'layer'`` adds
        ``nn.LayerNorm`` after every layer except the last. ``'none'`` adds
        no normalization.
    qtype : str
        Quantization type forwarded to ``create_quantizer`` (``"FP32"``,
        ``"INT4"``, otherwise 8-bit). Defaults to the previously hard-coded
        ``"INT4"``.
    ste, momentum, percentile, sample_prop
        Forwarded to ``create_quantizer``. The defaults reproduce the
        previously hard-coded configuration. The disabled argument handling
        this replaces used a percentile of 0.01 for INT4 and 0.001 otherwise.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, dropout_ratio,
                 norm_type='none', qtype='INT4', ste=True, momentum=True,
                 percentile=0.01, sample_prop=None):
        super().__init__()
        self.dropout = nn.Dropout(dropout_ratio)
        self.norm_type = norm_type

        self.layers = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        # Inputs use an unsigned quantizer; weights and features a signed one.
        layer_quantizers = {
            "inputs": create_quantizer(
                qtype, ste, momentum, percentile, False, sample_prop
            ),
            "weights": create_quantizer(
                qtype, ste, momentum, percentile, True, sample_prop
            ),
            "features": create_quantizer(
                qtype, ste, momentum, percentile, True, sample_prop
            ),
        }

        def add_hidden_norm():
            # Register the normalization that follows a hidden layer, if any.
            if self.norm_type == 'batch':
                self.norms.append(nn.BatchNorm1d(hidden_dim))
            elif self.norm_type == 'layer':
                self.norms.append(nn.LayerNorm(hidden_dim))

        if num_layers == 1:
            self.layers.append(LinearQuantized(input_dim, output_dim, layer_quantizers))
        else:
            self.layers.append(LinearQuantized(input_dim, hidden_dim, layer_quantizers))
            add_hidden_norm()

            for _ in range(num_layers - 2):
                self.layers.append(LinearQuantized(hidden_dim, hidden_dim, layer_quantizers))
                add_hidden_norm()

            self.layers.append(LinearQuantized(hidden_dim, output_dim, layer_quantizers))

    def forward(self, feats):
        """Apply the MLP to ``feats`` and return the last layer's output.

        Every layer except the last is followed by the optional normalization,
        ReLU and dropout.
        """
        h = feats
        for i, layer in enumerate(self.layers[:-1]):
            h = layer(h)
            if self.norm_type != 'none':
                h = self.norms[i](h)
            h = F.relu(h)
            h = self.dropout(h)
        h = self.layers[-1](h)
        return h
