"""
Quantized message-passing layers that support both SparseTensor and
edge-index (EdgeTensor) graph inputs, with scale absorption.
"""
import inspect
from collections import OrderedDict
from typing import Optional, Tuple, Union
from torch_geometric.typing import (
    Adj,
    OptTensor,
    PairTensor,
    OptPairTensor,
    SparseTensor,
    torch_sparse,
)
from torch import Tensor
from torch.nn import Parameter, Module, ModuleDict

import torch
import torch.nn.functional as F

from torch_geometric.utils import (
    softmax,
    add_self_loops,
    remove_self_loops,
    add_remaining_self_loops,
)
import torch_scatter
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros
import torch_sparse
from torch_sparse import SparseTensor, set_diag

from torch_geometric.utils import spmm
import numpy as np


def scatter_(name, src, index, dim=0, dim_size=None):
    """Scatter-reduce ``src`` by ``index`` (adapted from an earlier PyG release).

    Args:
        name: Reduction to apply; one of ``"add"``, ``"mean"``, ``"min"``,
            ``"max"``.  Selects ``torch_scatter.scatter_<name>``.
        src: Values to reduce.
        index: Destination index for every entry of ``src`` along ``dim``.
        dim: Axis along which to scatter.
        dim_size: Size of the output along ``dim`` (``None`` lets
            ``torch_scatter`` decide).

    Returns:
        The reduced tensor.  For ``"max"`` entries below ``-10000`` and for
        ``"min"`` entries above ``10000`` are overwritten with ``0``
        (presumably to clear fill values left in empty slots -- TODO confirm).
    """
    assert name in ("add", "mean", "min", "max")

    reduce_fn = getattr(torch_scatter, f"scatter_{name}")
    result = reduce_fn(src, index, dim, None, dim_size)
    # Some reductions hand back a tuple; only its first element is used.
    if isinstance(result, tuple):
        result = result[0]

    if name == "max":
        result[result < -10000] = 0
    elif name == "min":
        result[result > 10000] = 0

    return result

# Argument names that MessagePassingQuant.__collect__ fills in by itself;
# they are excluded from the user-supplied arguments gathered for message().
msg_special_args = {
    "edge_index",
    "edge_index_i",
    "edge_index_j",
    "size",
    "size_i",
    "size_j",
}

# Same idea for aggregate(): these are always provided by __collect__.
aggr_special_args = {
    "index",
    "dim_size",
}

# update() currently has no automatically provided arguments.
update_special_args = set()

# The update quantizer is keyed "update_q" due to a collision with pytorch
# using the key "update".
REQUIRED_QUANTIZER_KEYS = ["aggregate", "message", "update_q"]
# Layer-level quantizer keys each convolution instantiates.
REQUIRED_GCN_KEYS = ["weights", "inputs", "features", "norm"]
REQUIRED_GAT_KEYS = [
    "weights",
    "inputs",
    "features",
    "attention_l",
    "attention_r",
    "alpha",
]
REQUIRED_GS_KEYS = [
    "inputs",
    "inputs_r",
    "pool",
    "weights",
    "weights_l",
    "weights_r",
    "features",
    "root_features",
    "outputs",
]



class MessagePassingQuant(Module):
    """Quantization-aware message passing base class.

    Modified from the PyTorch Geometric message passing class.  The graph can
    be supplied either as a ``SparseTensor`` (node-wise mode: node features
    reach ``message`` un-gathered) or as an edge-index tensor (edge-wise
    mode: features are gathered per edge with ``index_select``).

    Args:
        aggr: Aggregation scheme, one of ``"add"``, ``"mean"``, ``"max"``.
        flow: ``"source_to_target"`` or ``"target_to_source"``.
        node_dim: Non-negative axis along which nodes are indexed.
        mp_quantizers: Mapping from quantizer key to a zero-argument factory;
            the keys in ``REQUIRED_QUANTIZER_KEYS`` are instantiated by
            ``reset_parameters`` / ``reset_quantizers``.  Must not be ``None``.

    Note:
        ``__init__`` does not create ``self.mp_quantizers``; call
        ``reset_parameters`` (the subclasses in this file do so) before
        ``propagate``.
    """

    def __init__(
        self, aggr="add", flow="source_to_target", node_dim=0, mp_quantizers=None,
    ):
        super(MessagePassingQuant, self).__init__()

        self.aggr = aggr
        assert self.aggr in ["add", "mean", "max"]

        self.flow = flow
        assert self.flow in ["source_to_target", "target_to_source"]

        self.node_dim = node_dim
        assert self.node_dim >= 0

        # Inspect the (possibly overridden) hooks so that propagate() can
        # route keyword arguments to them by parameter name.
        self.__msg_params__ = OrderedDict(
            inspect.signature(self.message).parameters
        )

        # The first parameter of aggregate()/update() is the tensor produced
        # by the previous stage, so it is not looked up by name.
        self.__aggr_params__ = OrderedDict(
            inspect.signature(self.aggregate).parameters
        )
        self.__aggr_params__.popitem(last=False)

        self.__update_params__ = OrderedDict(
            inspect.signature(self.update).parameters
        )
        self.__update_params__.popitem(last=False)

        msg_args = set(self.__msg_params__.keys()) - msg_special_args
        aggr_args = set(self.__aggr_params__.keys()) - aggr_special_args
        update_args = set(self.__update_params__.keys()) - update_special_args

        self.__args__ = set().union(msg_args, aggr_args, update_args)

        assert mp_quantizers is not None
        self.mp_quant_fns = mp_quantizers

    def _build_mp_quantizers(self):
        """Instantiate one quantizer per required key from ``mp_quant_fns``."""
        quantizers = ModuleDict()
        for key in REQUIRED_QUANTIZER_KEYS:
            quantizers[key] = self.mp_quant_fns[key]()
        return quantizers

    @staticmethod
    def _normalize_size(size):
        """Return ``size`` as a fresh two-element list.

        Accepts ``None`` (-> ``[None, None]``), an ``int`` (-> ``[n, n]``), a
        tensor, a tuple or a list.  A new list is always returned so that the
        in-place updates done while collecting arguments never leak into a
        list owned by the caller.
        """
        if size is None:
            size = [None, None]
        elif isinstance(size, int):
            size = [size, size]
        elif torch.is_tensor(size):
            size = size.tolist()
        elif isinstance(size, tuple):
            size = list(size)
        assert isinstance(size, list)
        assert len(size) == 2
        return list(size)

    def reset_parameters(self):
        """(Re)create the message passing quantizers from the stored factories."""
        self.mp_quantizers = self._build_mp_quantizers()

    def __set_size__(self, size, index, tensor):
        """Record, or validate, the node count of ``tensor`` in ``size[index]``.

        Non-tensor values are ignored.

        Raises:
            ValueError: If ``size[index]`` is already set and disagrees with
                ``tensor.size(self.node_dim)``.
        """
        if not torch.is_tensor(tensor):
            return

        observed = tensor.size(self.node_dim)
        if size[index] is None:
            size[index] = observed
        elif size[index] != observed:
            raise ValueError(
                (
                    f"Encountered node tensor with size "
                    f"{tensor.size(self.node_dim)} in dimension {self.node_dim}, "
                    f"but expected size {size[index]}."
                )
            )

    def __collect__(self, edge_index, size, kwargs):
        """Gather every argument the message/aggregate/update hooks may request.

        Arguments whose name ends in ``_i`` / ``_j`` are taken from the
        keyword without the suffix.  With an edge-index tensor they are
        gathered per edge; with a ``SparseTensor`` they are passed through
        unchanged.  ``size`` is updated in place with inferred node counts.

        Returns:
            dict: Collected arguments plus the special entries
            ``edge_index``, ``edge_index_i``, ``edge_index_j``, ``size``,
            ``size_i``, ``size_j``, ``node_wise``, ``index``, ``dim_size``,
            ``ptr`` and ``q_group``.  Missing user arguments are marked with
            ``inspect.Parameter.empty``.
        """
        i, j = (0, 1) if self.flow == "target_to_source" else (1, 0)
        ij = {"_i": i, "_j": j}
        is_sparse = isinstance(edge_index, SparseTensor)

        out = {}
        for arg in self.__args__:
            suffix = arg[-2:]
            if suffix not in ij:
                out[arg] = kwargs.get(arg, inspect.Parameter.empty)
                continue

            idx = ij[suffix]
            data = kwargs.get(arg[:-2], inspect.Parameter.empty)

            if data is inspect.Parameter.empty:
                out[arg] = data
                continue

            if isinstance(data, (tuple, list)):
                # Pair input: one entry per side of the graph.
                assert len(data) == 2
                self.__set_size__(size, 1 - idx, data[1 - idx])
                data = data[idx]

            if not torch.is_tensor(data):
                out[arg] = data
                continue

            self.__set_size__(size, idx, data)

            if is_sparse:
                out[arg] = data  # in its original form of N * F
            else:
                # modify into form of E * F
                out[arg] = data.index_select(self.node_dim, edge_index[idx])

        size[0] = size[1] if size[0] is None else size[0]
        size[1] = size[0] if size[1] is None else size[1]

        # Add special message arguments.
        out["edge_index"] = edge_index
        if is_sparse:
            # Only the coordinates are needed here; the values are ignored.
            row, col, _ = edge_index.t().coo()
            endpoints = (row, col)
            out["edge_index_i"] = endpoints[i]
            out["edge_index_j"] = endpoints[j]
        else:
            out["edge_index_i"] = edge_index[i]
            out["edge_index_j"] = edge_index[j]
        out["size"] = size
        out["size_i"] = size[i]
        out["size_j"] = size[j]
        out["node_wise"] = is_sparse

        # Add special aggregate arguments.
        out["index"] = out["edge_index_i"]
        out["dim_size"] = out["size_i"]
        out['ptr'] = None

        # Add TOPG arguments.
        out['q_group'] = kwargs.get('q_group', None)

        return out

    def __distribute__(self, params, kwargs):
        """Pick the values for ``params`` out of the collected ``kwargs``.

        Raises:
            TypeError: If a parameter without a default was not supplied.
        """
        out = {}
        for key, param in params.items():
            data = kwargs[key]
            if data is inspect.Parameter.empty:
                if param.default is inspect.Parameter.empty:
                    raise TypeError(f"Required parameter {key} is empty.")
                data = param.default
            out[key] = data
        return out

    def propagate(self, edge_index, size=None, **kwargs):  # q_group = (inf_mask, hop_num, indegree)
        """Run message -> aggregate -> update with quantization after each of
        the last two stages.

        Args:
            edge_index: ``SparseTensor`` adjacency or edge-index tensor.
            size: ``None``, an ``int``, or a 2-element tensor/tuple/list of
                node counts.  The caller's object is never modified.
            **kwargs: Tensors (or pairs of tensors) requested by name by the
                ``message``/``aggregate``/``update`` hooks; ``q_group`` is
                forwarded to the quantizers.

        Returns:
            The output of the ``"update_q"`` quantizer.
        """
        size = self._normalize_size(size)
        q_group = kwargs.get('q_group', None)
        kwargs = self.__collect__(edge_index, size, kwargs)

        # message
        msg_kwargs = self.__distribute__(self.__msg_params__, kwargs)
        msg = self.message(**msg_kwargs)
        if isinstance(msg, tuple):
            # A hook may hand back a replacement adjacency with its messages.
            msg, edge_index = msg

        # aggregation
        aggr_kwargs = self.__distribute__(self.__aggr_params__, kwargs)
        if isinstance(edge_index, (SparseTensor, tuple)):
            aggr_kwargs["index"] = edge_index

        agg = self.aggregate(msg, **aggr_kwargs)
        # The aggregate quantizer is fed a 2-D tensor; a 3-D result is
        # flattened for the call and restored afterwards.
        agg_size = agg.shape if agg.dim() == 3 else None
        agg = agg.view(agg.shape[0], -1)
        out = self.mp_quantizers["aggregate"](agg, True, False, q_group=q_group)
        if agg_size is not None:
            out = out.view(agg_size)

        # update
        update_kwargs = self.__distribute__(self.__update_params__, kwargs)
        up = self.update(out, **update_kwargs)
        out = self.mp_quantizers["update_q"](up, True, False, q_group=q_group)

        return out

    def message(self, x_j):  # pragma: no cover
        """Default message: pass the neighbour features through unchanged."""
        return x_j

    def aggregate(self, inputs, index, dim_size):  # pragma: no cover
        """Reduce messages with ``self.aggr``.

        Uses ``spmm`` when ``index`` is a ``SparseTensor`` and ``scatter_``
        along ``self.node_dim`` otherwise.
        """
        if isinstance(index, SparseTensor):
            return spmm(index, inputs, self.aggr)
        return scatter_(self.aggr, inputs, index, self.node_dim, dim_size)

    def update(self, inputs):  # pragma: no cover
        """Default update: return the aggregated tensor unchanged."""
        return inputs

    def freeze_quantization_parameters(self):
        """Freeze every required message passing quantizer."""
        for key in REQUIRED_QUANTIZER_KEYS:
            self.mp_quantizers[key].freeze_quantization_parameters()

    def unfreeze_quantization_parameters(self):
        """Unfreeze every required message passing quantizer."""
        for key in REQUIRED_QUANTIZER_KEYS:
            self.mp_quantizers[key].unfreeze_quantization_parameters()

    def reset_quantizers(self, mp_quantizers):
        """Replace the quantizer factories and rebuild the quantizers."""
        assert mp_quantizers is not None
        self.mp_quant_fns = mp_quantizers
        self.mp_quantizers = self._build_mp_quantizers()



class GCNConvQuant(MessagePassingQuant):
    """GCN convolution with quantized inputs, weights, features and norm.

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.
        bias: Whether to learn an additive bias.
        normalize: Whether ``forward`` applies the symmetric normalisation
            computed by ``norm`` (self-loops are added there).
        mp_quantizers: Factories for the message passing quantizers.
        layer_quantizers: Factories for the quantizers listed in
            ``REQUIRED_GCN_KEYS``.  Must not be ``None``.
        improved: Use a self-loop weight of 2 instead of 1.
        cached: Reuse the first computed ``(edge_index, norm)`` pair on
            later calls.
        heads: Accepted but not used by this layer.
        **kwargs: Forwarded to ``MessagePassingQuant``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bias=True,
        normalize=True,
        mp_quantizers=None,
        layer_quantizers=None,
        improved=False,
        cached=False,
        heads=None,
        **kwargs,
    ):
        super().__init__(aggr="add", mp_quantizers=mp_quantizers, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.normalize = normalize

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(torch.Tensor(out_channels))

        assert layer_quantizers is not None
        self.layer_quant_fns = layer_quantizers
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weights, clear the cache and rebuild all quantizers."""
        super().reset_parameters()
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None
        self.cached_num_edges = None
        # One fresh quantizer module per layer-level key.
        layer_q = ModuleDict()
        self.layer_quantizers = layer_q
        for name in REQUIRED_GCN_KEYS:
            layer_q[name] = self.layer_quant_fns[name]()

    def reset_quantizers(self, mp_quantizers, layer_quantizers):
        """Swap in new quantizer factories for both message passing and layer."""
        # Message passing quantizers first ...
        super().reset_quantizers(mp_quantizers=mp_quantizers)
        # ... then the layer-level ones.
        assert layer_quantizers is not None
        self.layer_quant_fns = layer_quantizers
        layer_q = ModuleDict()
        self.layer_quantizers = layer_q
        for name in REQUIRED_GCN_KEYS:
            layer_q[name] = self.layer_quant_fns[name]()

    def freeze_quantization_parameters(self):
        """Freeze message passing and layer quantizers."""
        super().freeze_quantization_parameters()
        for name in REQUIRED_GCN_KEYS:
            self.layer_quantizers[name].freeze_quantization_parameters()

    def unfreeze_quantization_parameters(self):
        """Unfreeze message passing and layer quantizers."""
        super().unfreeze_quantization_parameters()
        for name in REQUIRED_GCN_KEYS:
            self.layer_quantizers[name].unfreeze_quantization_parameters()

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None):
        """Add self-loops and compute ``D^-1/2 A D^-1/2`` edge weights.

        Args:
            edge_index: Square ``SparseTensor`` adjacency or edge-index tensor.
            num_nodes: Number of nodes (used only for the edge-index input).
            edge_weight: Optional edge weights (edge-index input only;
                defaults to ones).
            improved: Use a self-loop weight of 2 instead of 1.
            dtype: dtype for default / filled-in weights.

        Returns:
            tuple: ``(edge_index, weights, was_sparse)``.  ``edge_index`` is
            always returned as a stacked 2 x E tensor, also for sparse input.
        """
        loop_weight = 2 if improved else 1

        if isinstance(edge_index, SparseTensor):
            adj = edge_index
            assert adj.size(0) == adj.size(1)

            if not adj.has_value():
                adj = adj.fill_value(1, dtype=dtype)
            adj = torch_sparse.fill_diag(adj, loop_weight)

            degree = torch_sparse.sum(adj, dim=1)
            inv_sqrt = degree.to(torch.float).pow_(-0.5)
            # Isolated nodes produce inf; zero them out.
            inv_sqrt.masked_fill_(inv_sqrt == float('inf'), 0.)
            adj = torch_sparse.mul(adj, inv_sqrt.view(-1, 1))
            adj = torch_sparse.mul(adj, inv_sqrt.view(1, -1))

            # Split coordinates and values; coordinates go back to the
            # EdgeTensor (2 x E) format.
            row, col, values = adj.t().coo()
            return torch.stack([row, col], dim=0), values, True

        if edge_weight is None:
            num_edges = edge_index.size(1)
            edge_weight = torch.ones(
                (num_edges,), dtype=dtype, device=edge_index.device
            )

        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, loop_weight, num_nodes
        )

        row, col = edge_index
        degree = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
        inv_sqrt = degree.pow(-0.5)
        inv_sqrt.masked_fill_(inv_sqrt == float('inf'), 0)

        weights = inv_sqrt[row] * edge_weight * inv_sqrt[col]
        return edge_index, weights, False

    def forward(self, x, edge_index, q_group=None, edge_weight=None):
        """Quantize, transform and propagate ``x`` over the graph.

        Raises:
            RuntimeError: If ``cached`` is set and ``edge_index.size(1)``
                differs from the value seen when the cache was filled.
        """
        quant = self.layer_quantizers
        inputs_q = quant["inputs"](x, True, False, None, q_group)
        weight_q = quant["weights"](self.weight)
        x = torch.matmul(inputs_q, weight_q)

        x = quant["features"](x, True, False, None, q_group)

        cache_hit = self.cached and self.cached_result is not None
        if cache_hit:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    "Cached {} number of edges, but found {}. Please "
                    "disable the caching behavior of this layer by removing "
                    "the `cached=True` argument in its constructor.".format(
                        self.cached_num_edges, edge_index.size(1)
                    )
                )
        else:
            self.cached_num_edges = edge_index.size(1)
            if self.normalize:
                # The third return value (input-was-sparse flag) is not needed.
                edge_index, norm, _ = self.norm(
                    edge_index,
                    x.size(self.node_dim),
                    edge_weight,
                    self.improved,
                    x.dtype,
                )
            else:
                norm = edge_weight

            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm, q_group=q_group)

    def message(self, x_j, norm, edge_index=None, node_wise=None, q_group=None):
        """Quantize messages and edge weights; scale messages in edge-wise mode."""
        x_j_int, norm_ = self.mp_quantizers["message"](
            x_j, norm, node_wise, edge_index=edge_index, q_group=q_group
        )
        norm = self.layer_quantizers["norm"](
            norm_, False, True, edge_index=edge_index, q_group=q_group
        )

        if node_wise:
            # NOTE(review): the re-weighted adjacency built here is only bound
            # to a local name and is not returned, so the aggregation step
            # still sees the original adjacency -- confirm this is intended.
            edge_index = edge_index.t()
            edge_index = edge_index.set_value(norm, 'coo')
            edge_index = edge_index.t()
            return x_j_int

        return norm.view(-1, 1) * x_j_int

    def update(self, aggr_out):
        """Add the bias (if any) to the aggregated features."""
        if self.bias is None:
            return aggr_out
        return aggr_out + self.bias

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.in_channels}, {self.out_channels})"


class GATConvQuant(MessagePassingQuant):
    """Graph attention convolution with quantized features and attention.

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector per head.
        heads: Number of attention heads.
        concat: If ``True`` head outputs are concatenated, otherwise averaged.
        negative_slope: Slope of the LeakyReLU applied to attention logits.
        dropout: Dropout probability applied to the attention coefficients.
        bias: Whether to learn an additive bias.
        mp_quantizers: Factories for the message passing quantizers.
        layer_quantizers: Factories for the quantizers listed in
            ``REQUIRED_GAT_KEYS``.  Must not be ``None``.
        **kwargs: Forwarded to ``MessagePassingQuant``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        heads=1,
        concat=True,
        negative_slope=0.2,
        dropout=0,
        bias=True,
        mp_quantizers=None,
        layer_quantizers=None,
        **kwargs,
    ):
        super(GATConvQuant, self).__init__(aggr="add", mp_quantizers=mp_quantizers, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.weight = Parameter(torch.Tensor(in_channels, heads * out_channels))
        # Separate attention vectors for the two endpoints of an edge.
        self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_r = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)
        self._alpha = None

        assert layer_quantizers is not None
        self.layer_quant_fns = layer_quantizers
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weights and rebuild all quantizers."""
        super().reset_parameters()
        glorot(self.weight)
        glorot(self.att_l)
        glorot(self.att_r)
        zeros(self.bias)

        # create quantization modules for this layer
        self.layer_quantizers = ModuleDict()
        for key in REQUIRED_GAT_KEYS:
            self.layer_quantizers[key] = self.layer_quant_fns[key]()

    def reset_quantizers(self, mp_quantizers, layer_quantizers):
        """Swap in new quantizer factories for both message passing and layer."""
        # reset message_passing
        super().reset_quantizers(mp_quantizers=mp_quantizers)
        # reset layer
        assert layer_quantizers is not None
        self.layer_quant_fns = layer_quantizers
        self.layer_quantizers = ModuleDict()
        for key in REQUIRED_GAT_KEYS:
            self.layer_quantizers[key] = self.layer_quant_fns[key]()

    def freeze_quantization_parameters(self):
        """Freeze message passing and layer quantizers."""
        super().freeze_quantization_parameters()
        for key in REQUIRED_GAT_KEYS:
            self.layer_quantizers[key].freeze_quantization_parameters()

    def unfreeze_quantization_parameters(self):
        """Unfreeze message passing and layer quantizers."""
        super().unfreeze_quantization_parameters()
        for key in REQUIRED_GAT_KEYS:
            self.layer_quantizers[key].unfreeze_quantization_parameters()

    def forward(self, x, edge_index, q_group=None, size=None):
        """Quantize and project ``x``, then run attention-weighted propagation.

        When ``size`` is ``None`` and ``x`` is a single tensor, existing
        self-loops are removed and one self-loop per node is added.
        ``x`` may also be a pair of tensors (either entry may be ``None``).
        """
        x_q = self.layer_quantizers["inputs"](x, True, False, q_group=q_group)
        w_q = self.layer_quantizers["weights"](self.weight)

        if size is None and torch.is_tensor(x_q):
            if isinstance(edge_index, SparseTensor):
                edge_index = torch_sparse.remove_diag(edge_index)
                edge_index = torch_sparse.fill_diag(edge_index, fill_value=1)
            else:
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=x_q.size(0))

        if torch.is_tensor(x_q):
            x = torch.matmul(x_q, w_q)
            x_q = self.layer_quantizers["features"](x, True, False, q_group=q_group)
        else:
            x = (
                None if x_q[0] is None else torch.matmul(x_q[0], w_q),
                None if x_q[1] is None else torch.matmul(x_q[1], w_q),
            )

            x_q = (
                None if x[0] is None else self.layer_quantizers["features"](x[0]),
                None if x[1] is None else self.layer_quantizers["features"](x[1]),
            )

        return self.propagate(edge_index, size=size, x=x_q, q_group=q_group)

    def message(self, edge_index_i, edge_index_j, x_i, x_j, size_i, node_wise, edge_index, q_group):
        """Compute quantized attention coefficients and messages.

        Returns:
            Edge-index input: the messages scaled by their attention
            coefficients.  ``SparseTensor`` input: a tuple
            ``(messages, adjacencies)`` holding one attention-weighted
            adjacency per head, which ``propagate`` hands to ``aggregate``.
        """
        assert x_i is not None

        H, C = self.heads, self.out_channels

        x_i = x_i.view(-1, H, C)
        x_j = x_j.view(-1, H, C)
        att_l_q = self.layer_quantizers["attention_l"](self.att_l)
        att_r_q = self.layer_quantizers["attention_r"](self.att_r)

        alpha_l = (x_i * att_l_q).sum(dim=-1)
        alpha_r = (x_j * att_r_q).sum(dim=-1)
        if isinstance(edge_index, SparseTensor):
            # Node-wise mode: features are not gathered per edge, so the
            # per-node scores are expanded to edges here.
            alpha = alpha_l[edge_index_i] + alpha_r[edge_index_j]
        else:
            alpha = alpha_l + alpha_r
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, edge_index_i, num_nodes=size_i)

        # The message quantizer is fed a 2-D tensor; restore heads afterwards.
        msg_size = x_j.shape
        x_j = x_j.view(msg_size[0], -1)

        x_j_int, alpha_ = self.mp_quantizers["message"](
            x_j, alpha, node_wise, edge_index=edge_index, q_group=q_group
        )
        alpha = self.layer_quantizers["alpha"](
            alpha_,
            False,
            True,
            edge_index=torch.stack([edge_index_j, edge_index_i], dim=0),
            q_group=q_group,
        )
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        x_j_int = x_j_int.view(msg_size)

        if isinstance(edge_index, SparseTensor):
            adj = edge_index.t()
            # ``alpha[:, head]`` is already one value per edge.  It must not
            # be squeezed: with a single edge that would yield a 0-dim tensor.
            edges = tuple(
                adj.set_value(alpha[:, head], 'coo').t()
                for head in range(self.heads)
            )
            return (x_j_int, edges)

        return x_j_int * alpha.unsqueeze(-1)

    def aggregate(self, inputs, index, dim_size):  # pragma: no cover
        """Aggregate per head via ``spmm`` (tuple of adjacencies) or ``scatter_``."""
        if isinstance(index, tuple):
            # ``inputs[:, head, :]`` is already 2-D; squeezing it would drop
            # real dimensions when there is one node or ``out_channels == 1``.
            per_head = [
                spmm(adj, inputs[:, head, :], self.aggr).unsqueeze(1)
                for head, adj in enumerate(index)
            ]
            return torch.cat(per_head, dim=1)

        return scatter_(self.aggr, inputs, index, self.node_dim, dim_size)

    def update(self, aggr_out):
        """Merge the heads (concatenate or average) and add the bias."""
        if self.concat:
            aggr_out = aggr_out.view(-1, self.heads * self.out_channels)
        else:
            aggr_out = aggr_out.mean(dim=1)

        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return "{}({}, {}, heads={})".format(
            self.__class__.__name__, self.in_channels, self.out_channels, self.heads
        )


class GINConvQuant(MessagePassingQuant):
    """Reimplementation of GINConv that applies ``nn`` inside ``update``.

    Args:
        nn: Module applied to ``(1 + eps) * x + aggregated``; it is called
            with a ``q_group`` keyword and must provide ``reset_parameters``,
            ``reset_quantizers`` and the (un)freeze methods used below.
        eps: Initial value of ``eps``.
        train_eps: If ``True`` ``eps`` is a learnable parameter, otherwise a
            buffer.
        mp_quantizers: Factories for the message passing quantizers; must
            also provide a ``"merge"`` entry.
        layer_quantizers: Not used by the constructor (only consumed by
            ``reset_quantizers``).
        heads: Accepted but not used by this layer.
        **kwargs: Forwarded to ``MessagePassingQuant``.
    """

    def __init__(self, nn, eps=0, train_eps=True, mp_quantizers=None, layer_quantizers=None, heads=None, **kwargs):
        super().__init__(aggr="add", mp_quantizers=mp_quantizers, **kwargs)
        self.nn = nn
        self.initial_eps = eps
        eps_value = torch.Tensor([eps])
        if train_eps:
            self.eps = torch.nn.Parameter(eps_value)
        else:
            self.register_buffer("eps", eps_value)
        self.reset_parameters()

    def reset_parameters(self):
        """Rebuild the quantizers, reset ``nn`` and restore ``eps``."""
        super().reset_parameters()
        merge_factory = self.mp_quant_fns["merge"]
        self.mp_quantizers["merge"] = merge_factory()
        self.nn.reset_parameters()
        self.eps.data.fill_(self.initial_eps)

    def reset_quantizers(self, mp_quantizers, layer_quantizers):
        """Swap in new quantizer factories for message passing and for ``nn``."""
        # Message passing side, including the extra "merge" quantizer.
        super().reset_quantizers(mp_quantizers=mp_quantizers)
        merge_factory = self.mp_quant_fns["merge"]
        self.mp_quantizers["merge"] = merge_factory()
        # Layer side lives inside the wrapped module.
        assert layer_quantizers is not None
        self.nn.reset_quantizers(layer_quantizers=layer_quantizers)

    def freeze_quantization_parameters(self):
        """Freeze all quantizers; "merge" is switched via its ``update`` flag."""
        super().freeze_quantization_parameters()
        self.mp_quantizers['merge'].update = False
        self.nn.freeze_quantization_parameters()

    def unfreeze_quantization_parameters(self):
        """Unfreeze all quantizers; "merge" is switched via its ``update`` flag."""
        super().unfreeze_quantization_parameters()
        self.mp_quantizers['merge'].update = True
        self.nn.unfreeze_quantization_parameters()

    def forward(self, x, edge_index, q_group=None):
        """Strip self-loops from the graph and propagate ``x`` over it."""
        if x.dim() == 1:
            x = x.unsqueeze(-1)

        if isinstance(edge_index, SparseTensor):
            edge_index = torch_sparse.remove_diag(edge_index)
        else:
            edge_index, _ = remove_self_loops(edge_index)

        return self.propagate(edge_index, x=x, q_group=q_group)

    def message(self, x_j, node_wise, edge_index, q_group):
        """Quantize messages with all-ones edge weights; scale in edge-wise mode."""
        # One unit weight per edge.
        if isinstance(edge_index, SparseTensor):
            norm_shape = edge_index.t().coo()[0].shape
        else:
            norm_shape = x_j.shape[0]
        unit_weights = torch.ones(norm_shape, device=x_j.device)

        x_j_int, norm_ = self.mp_quantizers["message"](
            x_j, unit_weights, node_wise, edge_index=edge_index, q_group=q_group
        )
        norm = self.mp_quantizers["merge"](
            norm_, False, True, edge_index=edge_index, q_group=q_group
        )

        if node_wise:
            # NOTE(review): the re-weighted adjacency built here is only bound
            # to a local name and is not returned, so the aggregation step
            # still sees the original adjacency -- confirm this is intended.
            edge_index = edge_index.t()
            edge_index = edge_index.set_value(norm, 'coo')
            edge_index = edge_index.t()
            return x_j_int

        return norm.view(-1, 1) * x_j_int

    def update(self, aggr_out, x, q_group):
        """Combine root and aggregated features and run them through ``nn``."""
        combined = (1 + self.eps) * x + aggr_out
        return self.nn(combined, q_group=q_group)

    def __repr__(self):
        return f"{self.__class__.__name__}(nn={self.nn})"



class SAGEConvQuant(MessagePassingQuant):  # aggregation is fixed to "mean"
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, with quantized
    inputs, weights and features

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot
        \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized, *i.e.*,
            :math:`\frac{\mathbf{x}^{\prime}_i}
            {\| \mathbf{x}^{\prime}_i \|_2}`.
            (default: :obj:`False`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        projection (bool, optional): If set to :obj:`True`, source features
            pass through a learned linear map followed by ReLU before being
            aggregated. (default: :obj:`False`)
        mp_quantizers (optional): Factories for the message passing
            quantizers; must also provide a ``"merge"`` entry.
        layer_quantizers (optional): Factories for the quantizers listed in
            ``REQUIRED_GS_KEYS``. Must not be :obj:`None`.
        heads (optional): Accepted but not used by this layer.
        **kwargs (optional): Additional arguments of
            :class:`MessagePassingQuant`.
    """
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, normalize: bool = False,
                 root_weight: bool = True,
                 bias: bool = True,
                 projection: bool = False,
                 # special args
                 mp_quantizers=None,
                 layer_quantizers=None,
                 heads=None,
                 **kwargs):  # yapf: disable
        super(SAGEConvQuant, self).__init__(
            aggr="mean", mp_quantizers=mp_quantizers, **kwargs
        )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight
        # NOTE: ``self.bias`` is the boolean flag; the bias tensor itself is
        # ``self.lin_l_bias``.
        self.bias = bias
        self.projection = projection

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        lin_l_channels = in_channels[0]
        if self.projection:
            self.lin = Parameter(torch.Tensor(in_channels[0], out_channels))
            self.lin_bias = Parameter(torch.Tensor(out_channels))
            # After projection the aggregated features have out_channels dims.
            lin_l_channels = out_channels

        self.lin_l = Parameter(torch.Tensor(lin_l_channels, out_channels))
        if self.bias:
            self.lin_l_bias = Parameter(torch.Tensor(out_channels))
        if self.root_weight:
            self.lin_r = Parameter(torch.Tensor(in_channels[1], out_channels))

        assert layer_quantizers is not None
        self.layer_quant_fns = layer_quantizers
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the existing weights and rebuild all quantizers."""
        super().reset_parameters()
        self.mp_quantizers["merge"] = self.mp_quant_fns["merge"]()

        if self.projection:
            glorot(self.lin)
            zeros(self.lin_bias)
        glorot(self.lin_l)
        if self.bias:
            zeros(self.lin_l_bias)
        if self.root_weight:
            glorot(self.lin_r)
        # create quantization modules for this layer
        self.layer_quantizers = ModuleDict()
        for key in REQUIRED_GS_KEYS:
            self.layer_quantizers[key] = self.layer_quant_fns[key]()

    def reset_quantizers(self, mp_quantizers, layer_quantizers):
        """Swap in new quantizer factories for both message passing and layer."""
        # reset message_passing
        super().reset_quantizers(mp_quantizers=mp_quantizers)
        self.mp_quantizers["merge"] = self.mp_quant_fns["merge"]()
        # reset layer
        assert layer_quantizers is not None
        self.layer_quant_fns = layer_quantizers
        self.layer_quantizers = ModuleDict()
        for key in REQUIRED_GS_KEYS:
            self.layer_quantizers[key] = self.layer_quant_fns[key]()

    def freeze_quantization_parameters(self):
        """Freeze all quantizers; "merge" is switched via its ``update`` flag."""
        super().freeze_quantization_parameters()
        self.mp_quantizers['merge'].update = False
        for key in REQUIRED_GS_KEYS:
            self.layer_quantizers[key].freeze_quantization_parameters()

    def unfreeze_quantization_parameters(self):
        """Unfreeze all quantizers; "merge" is switched via its ``update`` flag."""
        super().unfreeze_quantization_parameters()
        self.mp_quantizers['merge'].update = True
        for key in REQUIRED_GS_KEYS:
            self.layer_quantizers[key].unfreeze_quantization_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, q_group=None,
                size=None) -> Tensor:
        """Aggregate neighbour features and combine them with the root features.

        Args:
            x: Node features, or a ``(source, target)`` pair whose second
                entry may be ``None``.
            edge_index: ``SparseTensor`` adjacency or edge-index tensor.
            q_group: Forwarded to every quantizer.
            size: Optional node counts forwarded to ``propagate``.

        Returns:
            The quantized output features.
        """
        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: OptPairTensor)
        x_q = self.layer_quantizers["inputs"](x[0], True, False, q_group=q_group)
        if self.projection:
            lin = self.layer_quantizers["pool"](self.lin)
            out = self.layer_quantizers["features"](
                F.relu(torch.matmul(x_q, lin) + self.lin_bias),
                True,
                False,
                q_group=q_group,
            )
        else:
            out = x_q
        out = self.propagate(edge_index, x=out, size=size, q_group=q_group)
        lin_l_q = self.layer_quantizers["weights_l"](self.lin_l)
        out = torch.matmul(out, lin_l_q)
        if self.bias:
            out += self.lin_l_bias

        x_r = x[1]
        if self.root_weight and x_r is not None:
            # ``lin_r`` only exists when root_weight is enabled and ``x_r``
            # may be None, so the root branch is quantized only in here.
            x_r_q = self.layer_quantizers["inputs_r"](x_r, True, False, q_group=q_group)
            lin_r_q = self.layer_quantizers["weights_r"](self.lin_r)
            out += self.layer_quantizers["root_features"](
                torch.matmul(x_r_q, lin_r_q), True, False, q_group=q_group
            )

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        out = self.layer_quantizers["outputs"](out, True, False, q_group=q_group)
        return out

    def message(self, x_j, node_wise, edge_index, q_group):
        """Quantize messages with all-ones edge weights; scale in edge-wise mode."""
        # One unit weight per edge.
        norm_shape = edge_index.t().coo()[0].shape if isinstance(edge_index, SparseTensor) else x_j.shape[0]
        x_j_int, norm_ = self.mp_quantizers["message"](
            x_j,
            torch.ones(norm_shape, device=x_j.device),
            node_wise,
            edge_index=edge_index,
            q_group=q_group,
        )
        norm = self.mp_quantizers["merge"](norm_, False, True, edge_index=edge_index, q_group=q_group)

        if not node_wise:
            x_all = norm.view(-1, 1) * x_j_int
        else:
            # NOTE(review): the re-weighted adjacency built here is only bound
            # to a local name and is not returned; ``aggregate`` additionally
            # strips all adjacency values -- confirm this is intended.
            edge_index = edge_index.t()
            edge_index = edge_index.set_value(norm, 'coo')
            edge_index = edge_index.t()
            x_all = x_j_int
        return x_all

    def aggregate(self, inputs, index, dim_size):  # pragma: no cover
        """Reduce messages with ``self.aggr``; sparse adjacency values are dropped."""
        if isinstance(index, SparseTensor):
            index = index.set_value(None, layout=None)
            return spmm(index, inputs, self.aggr)
        else:
            return scatter_(self.aggr, inputs, index, self.node_dim, dim_size)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
    


    

