"""
Model structures for full-batch node-level and graph-level classification tasks.
Example datasets: Cora, CiteSeer, PubMed / IMDB-BINARY, PROTEINS, NCI1, MUTAG / CIFAR10, MNIST.
Weights are quantized column-wise by default (``use_wc=True``); otherwise a single quantizer is used.
"""

from .quantization.quantizer_table_module import (
    GCNActivationTable,
    GINActivationTable,
    GATActivationTable,
    MergeQuantize
)

from .layers.mp_absorption import GCNConvQuant, GATConvQuant, GINConvQuant, SAGEConvQuant

from .quantization.linear_quantized import LinearQuantized
from ._sequental_table_input import SequentialTableInput
from ._main_model_structures import BatchNet
from .quantization.define_quantizers import define_single, define_columnwise, define_table_default, no_quantization

import torch.nn.functional as F
from torch.nn import Identity, ReLU, BatchNorm1d as BN


class GCN(BatchNet):
    """Quantized GCN built on ``BatchNet`` (``GCNConvQuant`` layers, ReLU activation).

    Quantizers are produced by :meth:`new_quantizers`. ``self.order`` stores the
    ``(mp_order, layer_order)`` lists handed to ``BatchNet``, which presumably
    select an entry of the message-passing / layer quantizer lists per layer.
    """

    def __init__(
        self,
        # model param
        dataset,
        num_layers,
        hidden,
        graph_level,
        # quantization param
        device,
        qtype: int = 32,
        ste: bool = False,
        momentum: bool = False,
        # special param
        large_graph=False,
        use_wc=True,
        sparse_check=False,
        q_group=None,
        # clip_param
        single=True,
        bothside=True,
        percentile=-1,
        **kwargs
    ):
        """Create the quantizers and hand everything to ``BatchNet``.

        Args:
            dataset, num_layers, hidden, graph_level: forwarded positionally to
                ``BatchNet`` after the model name ``"GCN"``.
            device: passed to every table-based quantizer.
            qtype, ste, momentum, q_group: shared quantizer options forwarded
                to the ``define_*`` helpers.
            large_graph, sparse_check: forwarded to ``BatchNet``.
            use_wc: weights use ``define_columnwise`` when True, otherwise
                ``define_single``.
            single, bothside: stored on the instance; not read within this class.
            percentile: clipping percentile forwarded to activation quantizers.
            **kwargs: extra keyword arguments forwarded to ``BatchNet``.
        """
        # new_quantizers() reads use_wc / percentile, so these come first.
        self.single, self.bothside = single, bothside
        self.use_wc = use_wc
        self.percentile = percentile

        mp_q, layer_q, graph_q = self.new_quantizers(qtype, ste, momentum, device, q_group)

        # Every layer shares the single message-passing table (index 0). Only
        # the first layer uses the signed layer quantizers (index 0); the rest
        # use the unsigned ones (index 1). For num_layers <= 0 this still
        # yields [0], exactly like ``[0] + [1] * (num_layers - 1)``.
        mp_order = [0] * num_layers
        layer_order = [0, *([1] * (num_layers - 1))]
        self.order = (mp_order, layer_order)

        super().__init__(
            "GCN", dataset, num_layers, hidden, graph_level,
            layer=GCNConvQuant,
            activation=F.relu,
            mp_quantizers=mp_q,
            layer_quantizers=layer_q,
            graph_layer_quantizers=graph_q,
            mp_order=mp_order,
            layer_order=layer_order,
            large_graph=large_graph,
            sparse_check=sparse_check,
            **kwargs
        )

    def new_quantizers(self, qtype, ste, momentum, device, q_group):
        """Build a fresh set of quantizer dictionaries for this model.

        Returns:
            A 3-tuple ``(mp_quantizers, layer_quantizers, graph_layer_quantizers)``:

            - ``mp_quantizers``: a one-element list holding the message-passing dict.
            - ``layer_quantizers``: the ``[signed, unsigned]`` conv-layer dicts.
            - ``graph_layer_quantizers``: the ``[signed, unsigned]`` graph-level
              layer dicts.
        """
        group = (q_group, qtype, ste, momentum)
        weight_q = define_columnwise if self.use_wc else define_single
        clip = self.percentile

        def table(table_cls, signed, symmetric, **extra):
            # Table quantizer with the options shared by every table in this model;
            # ``extra`` (e.g. norm_name) is placed before ``percentile`` as in the
            # original call sites.
            return define_table_default(
                table_cls, *group, signed=signed, symmetric=symmetric,
                dependency=False, device=device, **extra, percentile=clip)

        def weights():
            # Weights are always signed and symmetric.
            return weight_q(*group, signed=True, symmetric=True)

        message_passing = {
            "message": table(MergeQuantize, True, False, norm_name="GCN"),
            "aggregate": no_quantization(),
            "update_q": no_quantization(),
        }

        def conv_layer(signed_inputs):
            # Only the input quantizer's signedness differs between the two variants.
            return {
                "inputs": table(GCNActivationTable, signed_inputs, False),
                "weights": weights(),
                "features": no_quantization(),
                "norm": table(GCNActivationTable, True, True),
            }

        def graph_layer(signed_inputs):
            return {
                "inputs": define_single(*group, signed=signed_inputs, symmetric=False, percentile=clip),
                "weights": weights(),
                "features": no_quantization(),
            }

        conv_tables = [conv_layer(True), conv_layer(False)]
        graph_tables = [graph_layer(True), graph_layer(False)]
        return [message_passing], conv_tables, graph_tables

    def reset_quantizers(self, qtype, ste, momentum, device, q_group=None):
        """Rebuild all quantizers with new settings and install them via ``BatchNet``."""
        mp_q, layer_q, graph_q = self.new_quantizers(qtype, ste, momentum, device, q_group)
        super().reset_quantizers(
            mp_quantizers=mp_q,
            layer_quantizers=layer_q,
            graph_layer_quantizers=graph_q,
        )


class GAT(BatchNet):
    """Quantized GAT built on ``BatchNet`` (``GATConvQuant`` layers, ELU activation).

    Quantizers are produced by :meth:`new_quantizers`. ``self.order`` holds the
    order list handed to ``BatchNet`` as both ``mp_order`` and ``layer_order``.
    """

    def __init__(
        self,
        # model param
        dataset,
        num_layers,
        hidden,
        graph_level,
        # quantization param
        device,
        qtype: int = 32,
        ste: bool = False,
        momentum: bool = False,
        heads: int = 8,
        # special param
        large_graph=False,
        use_wc=True,
        sparse_check=False,
        q_group=None,
        # clip_param
        single=True,
        bothside=True,
        percentile=-1,
        **kwargs
    ):
        """Create the quantizers and hand everything to ``BatchNet``.

        Args:
            dataset, num_layers, hidden, graph_level: forwarded positionally to
                ``BatchNet`` after the model name ``"GAT"``.
            device: passed to every table-based quantizer.
            qtype, ste, momentum, q_group: shared quantizer options forwarded
                to the ``define_*`` helpers.
            heads: forwarded to ``BatchNet`` as ``heads``.
            large_graph, sparse_check: forwarded to ``BatchNet``.
            use_wc: weights use ``define_columnwise`` when True, otherwise
                ``define_single``.
            single, bothside: stored on the instance; not read within this class.
            percentile: clipping percentile forwarded to activation quantizers.
            **kwargs: extra keyword arguments forwarded to ``BatchNet``.
        """
        # new_quantizers() reads use_wc / percentile, so these come first.
        self.single, self.bothside = single, bothside
        self.use_wc = use_wc
        self.percentile = percentile

        mp_q, layer_q, graph_q = self.new_quantizers(qtype, ste, momentum, device, q_group)

        # Every layer uses entry 0 of both quantizer lists. The same list object
        # is deliberately reused for mp_order and layer_order.
        shared_order = [0] * num_layers
        self.order = (shared_order, shared_order)

        super().__init__(
            "GAT", dataset, num_layers, hidden, graph_level,
            layer=GATConvQuant,
            activation=F.elu,
            heads=heads,
            mp_quantizers=mp_q,
            layer_quantizers=layer_q,
            graph_layer_quantizers=graph_q,
            mp_order=shared_order,
            layer_order=shared_order,
            large_graph=large_graph,
            sparse_check=sparse_check,
            **kwargs
        )

    def new_quantizers(self, qtype, ste, momentum, device, q_group):
        """Build a fresh set of quantizer dictionaries for this model.

        Returns:
            A 3-tuple ``(mp_quantizers, layer_quantizers, graph_layer_quantizers)``:

            - ``mp_quantizers``: a one-element list holding the message-passing dict.
            - ``layer_quantizers``: ``[signed, {}]``. The unsigned slot is empty
              because every layer uses index 0.
            - ``graph_layer_quantizers``: the ``[signed, unsigned]`` graph-level
              layer dicts.
        """
        group = (q_group, qtype, ste, momentum)
        weight_q = define_columnwise if self.use_wc else define_single
        clip = self.percentile

        def table(table_cls, signed, symmetric, **extra):
            # Table quantizer with the options shared by every table in this model;
            # ``extra`` (e.g. norm_name) is placed before ``percentile`` as in the
            # original call sites.
            return define_table_default(
                table_cls, *group, signed=signed, symmetric=symmetric,
                dependency=False, device=device, **extra, percentile=clip)

        def weights():
            # Weights and attention vectors are always signed and symmetric.
            return weight_q(*group, signed=True, symmetric=True)

        message_passing = {
            "message": table(MergeQuantize, True, False, norm_name="GAT"),
            "aggregate": no_quantization(),
            "update_q": no_quantization(),
        }
        conv_signed = {
            "inputs": table(GATActivationTable, True, False),
            "weights": weights(),
            "features": no_quantization(),
            "attention_l": weights(),
            "attention_r": weights(),
            "alpha": table(GATActivationTable, True, True),
        }
        conv_unsigned = {}

        def graph_layer(signed_inputs):
            return {
                "inputs": define_single(*group, signed=signed_inputs, symmetric=False, percentile=clip),
                "weights": weights(),
                "features": no_quantization(),
            }

        graph_signed = graph_layer(True)
        graph_unsigned = graph_layer(False)
        return [message_passing], [conv_signed, conv_unsigned], [graph_signed, graph_unsigned]

    def reset_quantizers(self, qtype, ste, momentum, device, q_group=None):
        """Rebuild all quantizers with new settings and install them via ``BatchNet``."""
        mp_q, layer_q, graph_q = self.new_quantizers(qtype, ste, momentum, device, q_group)
        super().reset_quantizers(
            mp_quantizers=mp_q,
            layer_quantizers=layer_q,
            graph_layer_quantizers=graph_q,
        )



class GIN(BatchNet):
    """Quantized GIN built on ``BatchNet`` (``GINConvQuant`` layers, ReLU activation).

    ``__init__`` builds a three-entry list of ``SequentialTableInput`` MLPs
    (labelled first / intermediate / last layer by the inline comments) and
    passes it to ``BatchNet`` as ``nn``. Quantizers are produced by
    :meth:`new_quantizers`. ``self.order`` stores the ``(mp_order, layer_order)``
    lists handed to ``BatchNet``.
    """
    def __init__(
        self,
        # model param
        dataset,
        num_layers,
        hidden,
        graph_level,
        # quantization param
        device,
        qtype: int = 32,
        ste: bool = False,
        momentum: bool = False,
        # special param
        large_graph = False,
        use_wc=True,
        sparse_check=False,
        q_group=None,
        # clip_param
        single=True, 
        bothside=True,
        percentile=-1, 
        **kwargs
        ):
        """Build quantizers and per-layer MLPs, then hand everything to ``BatchNet``.

        Args:
            dataset: forwarded to ``BatchNet``. In the node-level branch it also
                supplies the MLP input/output sizes (``num_features`` /
                ``num_classes``), unless it equals the string ``"ogbn-proteins"``.
            num_layers, hidden, graph_level: forwarded to ``BatchNet``.
                ``hidden`` sizes the MLPs and ``graph_level`` selects which MLP
                layout is built.
            device: passed to every table-based quantizer.
            qtype, ste, momentum, q_group: shared quantizer options forwarded to
                the ``define_*`` helpers.
            large_graph, sparse_check: forwarded to ``BatchNet``.
            use_wc: weights use ``define_columnwise`` when True, otherwise
                ``define_single``.
            single, bothside: stored on the instance; not read within this class.
            percentile: clipping percentile forwarded to activation quantizers.
            **kwargs: extra keyword arguments forwarded to ``BatchNet``.
        """
        
        # new_quantizers() reads use_wc / percentile, so these must be set first.
        self.single = single
        self.bothside = bothside
        self.use_wc = use_wc
        self.percentile = percentile
        mp_quantizers, layer_quantizers, graph_layer_quantizers = self.new_quantizers(qtype, ste, momentum, device, q_group)
        
        # Index 0 (signed variant) for the first layer, index 1 (unsigned) for
        # the rest, for both message passing and layer quantizers.
        mp_order = [0] + [1] * (num_layers-1)
        layer_order = [0] + [1] * (num_layers-1)
        self.order = (mp_order, layer_order)       
        # NOTE(review): presumably BatchNet picks nn[0] / nn[1] / nn[2] for the
        # first / intermediate / last layer (per the inline comments). Confirm
        # whether intermediate layers share the single nn[1] module instance
        # (i.e. weight sharing) or BatchNet copies it per layer.
        if graph_level:
            # Each MLP: Linear -> BatchNorm -> ReLU -> Linear -> ReLU.
            # NOTE(review): the first MLP takes `hidden` inputs rather than
            # dataset.num_features; presumably BatchNet projects inputs to
            # `hidden` beforehand. Confirm.
            nn = [
                # batchnorms? https://github.com/pyg-team/pytorch_geometric/blob/2.0.2/examples/mutag_gin.py
                SequentialTableInput([LinearQuantized(hidden, hidden, signed=True, layer_quantizers=layer_quantizers), BN(hidden), 
                                      ReLU(), LinearQuantized(hidden, hidden, signed=False, layer_quantizers=layer_quantizers), ReLU()]), # first layer
                SequentialTableInput([LinearQuantized(hidden, hidden, signed=False, layer_quantizers=layer_quantizers), BN(hidden), ReLU(), 
                                      LinearQuantized(hidden, hidden, signed=False, layer_quantizers=layer_quantizers), ReLU()]), # intermediate layers
                SequentialTableInput([LinearQuantized(hidden, hidden, signed=False, layer_quantizers=layer_quantizers), BN(hidden), ReLU(), 
                                      LinearQuantized(hidden, hidden, signed=False, layer_quantizers=layer_quantizers), ReLU()]), # last layer
            ]
        else:
            # Single Linear -> ReLU per layer. Hard-coded sizes 8 (input) and
            # 112 (output) are used when dataset == "ogbn-proteins"; presumably
            # that benchmark's feature width and task count. Confirm.
            # NOTE(review): the last MLP applies ReLU after projecting to the
            # output size, clamping outputs to >= 0. Confirm this is intended.
            nn = [
                SequentialTableInput([LinearQuantized(8 if dataset =="ogbn-proteins" else dataset.num_features, hidden, signed=True, layer_quantizers=layer_quantizers), ReLU()]), # first layer
                SequentialTableInput([LinearQuantized(hidden, hidden, signed=False, layer_quantizers=layer_quantizers), ReLU()]), # intermediate layers
                SequentialTableInput([LinearQuantized(hidden, 112 if dataset =="ogbn-proteins" else dataset.num_classes, signed=False, layer_quantizers=layer_quantizers), ReLU()]), # last layer
            ]

        super(GIN, self).__init__(
            *("GIN", dataset, num_layers, hidden, graph_level),
            nn = nn,
            layer=GINConvQuant,
            activation=F.relu,
            mp_quantizers=mp_quantizers, 
            layer_quantizers=layer_quantizers,
            graph_layer_quantizers=graph_layer_quantizers,
            mp_order=mp_order,
            layer_order=layer_order,
            large_graph=large_graph,
            sparse_check=sparse_check,
            **kwargs)
        
    
    def new_quantizers(self, qtype, ste, momentum, device, q_group):
        """Build a fresh set of quantizer dictionaries for this model.

        Returns:
            A 3-tuple ``(mp_quantizers, layer_quantizers, graph_layer_quantizers)``.
            Each element is a ``[signed, unsigned]`` pair of dicts, where the
            variant refers to the signedness of the input/message quantizer.
        """
        new_group = (q_group, qtype, ste, momentum)
        # Column-wise weight quantization when use_wc is set, else a single quantizer.
        define_weight_q = define_columnwise if self.use_wc else define_single
        
        # Message-passing quantizers; the "merge" table is signed/symmetric in both variants.
        signed_mp = {
            "message": define_table_default(MergeQuantize, *(new_group), signed=True, symmetric=False, 
                                            dependency=False, device=device, norm_name="GIN", percentile=self.percentile),
            "aggregate": no_quantization(),
            "update_q": no_quantization(),
            "merge": define_table_default(GINActivationTable, *(new_group), signed=True, symmetric=True, 
                                          dependency=False, device=device, percentile=self.percentile) 
        }
        unsigned_mp = {
            "message": define_table_default(MergeQuantize, *(new_group), signed=False, symmetric=False, 
                                            dependency=False, device=device, norm_name="GIN", percentile=self.percentile),
            "aggregate": no_quantization(),
            "update_q": no_quantization(),
            "merge": define_table_default(GINActivationTable, *(new_group), signed=True, symmetric=True, 
                                          dependency=False, device=device, percentile=self.percentile) 
        }
        # Quantizers for the LinearQuantized layers inside the GIN MLPs.
        signed_lq = {
            "inputs": define_table_default(GINActivationTable, *(new_group), signed=True, symmetric=False, 
                                           dependency=False, device=device, percentile=self.percentile),
            "weights": define_weight_q(*(new_group), signed=True, symmetric=True),
            "features": no_quantization(),
        }
        unsigned_lq = {
            "inputs": define_table_default(GINActivationTable, *(new_group), signed=False, symmetric=False, 
                                           dependency=False, device=device, percentile=self.percentile), # try symmetric=True
            "weights": define_weight_q(*(new_group), signed=True, symmetric=True),
            "features": no_quantization(),
        }
        # Graph-level layer quantizers use a single (non-table) input quantizer.
        signed_glq = {
            "inputs": define_single(*(new_group), signed=True, symmetric=False, percentile=self.percentile),
            "weights": define_weight_q(*(new_group), signed=True, symmetric=True),
            "features": no_quantization(),
        }
        unsigned_glq = {
            "inputs": define_single(*(new_group), signed=False, symmetric=False, percentile=self.percentile),
            "weights": define_weight_q(*(new_group), signed=True, symmetric=True),
            "features": no_quantization(),
        }
        mp_quantizers = [signed_mp, unsigned_mp]
        layer_quantizers = [signed_lq, unsigned_lq]
        graph_layer_quantizers = [signed_glq, unsigned_glq]
        return mp_quantizers, layer_quantizers, graph_layer_quantizers
    
    def reset_quantizers(self, qtype, ste, momentum, device, q_group=None):
        """Rebuild all quantizers with new settings and install them via ``BatchNet``."""
        mp_quantizers, layer_quantizers, graph_layer_quantizers = self.new_quantizers(qtype, ste, momentum, device, q_group)
        super().reset_quantizers(mp_quantizers=mp_quantizers, layer_quantizers=layer_quantizers, graph_layer_quantizers=graph_layer_quantizers)


   
class GraphSAGE(BatchNet):
    """Quantized GraphSAGE built on ``BatchNet`` (``SAGEConvQuant`` layers, ReLU activation).

    Quantizers are produced by :meth:`new_quantizers`. ``self.order`` stores the
    ``(mp_order, layer_order)`` lists handed to ``BatchNet``.
    """

    def __init__(
        self,
        # model param
        dataset,
        num_layers,
        hidden,
        graph_level,
        # quantization param
        device,
        qtype: int = 32,
        ste: bool = False,
        momentum: bool = False,
        # special param
        large_graph=False,
        use_wc=True,
        sparse_check=False,
        q_group=None,
        # clip_param
        single=True,
        bothside=True,
        percentile=-1,
        **kwargs
        ):
        """Create the quantizers and hand everything to ``BatchNet``.

        Args:
            dataset, num_layers, hidden, graph_level: forwarded positionally to
                ``BatchNet`` after the model name ``"GraphSAGE"``.
            device: passed to every table-based quantizer.
            qtype, ste, momentum, q_group: shared quantizer options forwarded
                to the ``define_*`` helpers.
            large_graph, sparse_check: forwarded to ``BatchNet``.
            use_wc: weights use ``define_columnwise`` when True, otherwise
                ``define_single``.
            single, bothside: stored on the instance; not read within this class.
            percentile: clipping percentile forwarded to activation quantizers.
            **kwargs: extra keyword arguments forwarded to ``BatchNet``.
        """
        # new_quantizers() reads use_wc / percentile, so these must be set first.
        self.single = single
        self.bothside = bothside
        self.use_wc = use_wc
        self.percentile = percentile

        mp_quantizers, layer_quantizers, graph_layer_quantizers = self.new_quantizers(
            qtype, ste, momentum, device, q_group)

        # All layers share the single message-passing dict (index 0) and use the
        # unsigned layer quantizers (index 1), the first layer included.
        # NOTE(review): the first layer's input features therefore go through the
        # signed=False input quantizer; confirm the inputs are non-negative.
        mp_order = [0] * num_layers
        layer_order = [1] * (num_layers)
        self.order = (mp_order, layer_order)

        super(GraphSAGE, self).__init__(
            *("GraphSAGE", dataset, num_layers, hidden, graph_level),
            layer=SAGEConvQuant,
            activation=F.relu,
            mp_quantizers=mp_quantizers,
            layer_quantizers=layer_quantizers,
            graph_layer_quantizers=graph_layer_quantizers,
            mp_order=mp_order,
            layer_order=layer_order,
            large_graph=large_graph,
            sparse_check=sparse_check,
            **kwargs)

    def new_quantizers(self, qtype, ste, momentum, device, q_group):
        """Build a fresh set of quantizer dictionaries for this model.

        Returns:
            A 3-tuple ``(mp_quantizers, layer_quantizers, graph_layer_quantizers)``:

            - ``mp_quantizers``: a one-element list holding the message-passing dict.
            - ``layer_quantizers``: the ``[signed, unsigned]`` conv-layer dicts.
            - ``graph_layer_quantizers``: the ``[signed, unsigned]`` graph-level
              layer dicts.
        """
        new_group = (q_group, qtype, ste, momentum)
        # Column-wise weight quantization when use_wc is set, else a single quantizer.
        define_weight_q = define_columnwise if self.use_wc else define_single

        # NOTE(review): SAGE reuses GATActivationTable and norm_name="GAT" here;
        # confirm no SAGE-specific table/normalization was intended.
        mp_quantizers = {
            "message": define_table_default(MergeQuantize, *new_group, signed=False, symmetric=False,
                                            dependency=False, device=device, norm_name="GAT", percentile=self.percentile),
            "aggregate": no_quantization(),
            "update_q": define_table_default(GATActivationTable, *new_group, signed=True, symmetric=False,
                                             dependency=False, device=device, percentile=self.percentile),
            "merge": define_table_default(GATActivationTable, *new_group, signed=True, symmetric=True,
                                          dependency=False, device=device, percentile=self.percentile)
        }
        # Conv-layer quantizers. "inputs"/"inputs_r" and "weights_l"/"weights_r"
        # presumably cover the neighbour and root branches of the SAGE update.
        signed_lq = {
            "inputs": define_table_default(GATActivationTable, *new_group, signed=True, symmetric=False,
                                           dependency=False, device=device, percentile=self.percentile),
            "inputs_r": define_table_default(GATActivationTable, *new_group, signed=True, symmetric=False,
                                             dependency=False, device=device, percentile=self.percentile),
            "pool": define_weight_q(*new_group, signed=True, symmetric=True),
            "weights": define_weight_q(*new_group, signed=True, symmetric=True),
            "weights_r": define_weight_q(*new_group, signed=True, symmetric=True),
            "weights_l": define_weight_q(*new_group, signed=True, symmetric=True),
            "features": no_quantization(),
            "root_features": no_quantization(),
            "outputs": no_quantization(),
        }
        unsigned_lq = {
            "inputs": define_table_default(GATActivationTable, *new_group, signed=False, symmetric=False,
                                           dependency=False, device=device, percentile=self.percentile),
            "inputs_r": define_table_default(GATActivationTable, *new_group, signed=False, symmetric=False,
                                             dependency=False, device=device, percentile=self.percentile),
            "pool": define_weight_q(*new_group, signed=True, symmetric=True),
            "weights": define_weight_q(*new_group, signed=True, symmetric=True),
            "weights_r": define_weight_q(*new_group, signed=True, symmetric=True),
            "weights_l": define_weight_q(*new_group, signed=True, symmetric=True),
            "features": no_quantization(),
            "root_features": no_quantization(),
            "outputs": no_quantization(),
        }
        # Graph-level layer quantizers use a single (non-table) input quantizer.
        signed_glq = {
            "inputs": define_single(*new_group, signed=True, symmetric=False, percentile=self.percentile),
            "weights": define_weight_q(*new_group, signed=True, symmetric=True),
            "features": no_quantization(),
        }
        unsigned_glq = {
            "inputs": define_single(*new_group, signed=False, symmetric=False, percentile=self.percentile),
            "weights": define_weight_q(*new_group, signed=True, symmetric=True),
            "features": no_quantization(),
        }
        layer_quantizers = [signed_lq, unsigned_lq]
        graph_layer_quantizers = [signed_glq, unsigned_glq]
        return [mp_quantizers], layer_quantizers, graph_layer_quantizers

    def reset_quantizers(self, qtype, ste, momentum, device, q_group=None):
        """Rebuild all quantizers with new settings and install them via ``BatchNet``.

        Args:
            qtype, ste, momentum, device, q_group: same meaning as in ``__init__``.
        """
        # Forward the caller's q_group. It used to be hard-coded to None here,
        # so any q_group passed to reset_quantizers was silently ignored.
        mp_quantizers, layer_quantizers, graph_layer_quantizers = self.new_quantizers(
            qtype, ste, momentum, device, q_group)
        super().reset_quantizers(mp_quantizers=mp_quantizers, layer_quantizers=layer_quantizers,
                                 graph_layer_quantizers=graph_layer_quantizers)

