 
"""
Model structures for full-batch node-level and graph-level classification tasks.
Example datasets: Cora, CiteSeer, PubMed / IMDB-BINARY, PROTEINS, NCI1, MUTAG / CIFAR10, MNIST.
"""



import torch
import torch.nn.functional as F
from torch.nn import Identity, ReLU, BatchNorm1d as BN

from .layers.mp_FP32 import (
    GCNConvQuant, GATConvQuant, GINConvQuant, SAGEConvQuant, 
    REQUIRED_QUANTIZER_KEYS, REQUIRED_GCN_KEYS, REQUIRED_GAT_KEYS, REQUIRED_GS_KEYS
    )
from ._sequental_table_input import SequentialTableInput
from ._main_model_structures import BatchNet

from .quantization.define_quantizers import no_quantization
from .quantization.linear_quantized import LinearQuantized 

#REQUIRED_QUANTIZER_KEYS = ["aggregate", "message", "update_q"]
#REQUIRED_GCN_KEYS = ["weights", "inputs", "features", "norm"]
#REQUIRED_GAT_KEYS = ["weights", "inputs", "features", "attention_l", "attention_r", "alpha"]


class GCN(BatchNet):
    """Graph Convolutional Network assembled by ``BatchNet`` from ``GCNConvQuant`` layers.

    Every quantizer slot built here is ``no_quantization()``. The quantization
    arguments (``qtype``, ``ste``, ``momentum``, ``device``, ``q_group``) are
    accepted for interface compatibility but do not influence the quantizers
    produced in this class. ``use_wc`` is accepted but not used here.
    """

    def __init__(
        self,
        # model param
        dataset,
        num_layers,
        hidden,
        graph_level,
        # quantization param
        device,
        qtype: int = 32,
        ste: bool = False,
        momentum: bool = False,
        # special param
        large_graph=False,
        use_wc=False,
        sparse_check=False,
        q_group=None,
        **kwargs,
    ):
        quantizer_sets = self.new_quantizers(qtype, ste, momentum, device, q_group)
        mp_quantizers, layer_quantizers, graph_layer_quantizers = quantizer_sets

        # Index lists selecting a quantizer set per layer: all message-passing
        # layers use set 0; the first layer uses layer set 0 (signed) and the
        # remaining layers set 1 (unsigned). Presumably BatchNet indexes the
        # quantizer lists with these values -- confirm against BatchNet.
        mp_order = [0] * num_layers
        layer_order = [0, *([1] * (num_layers - 1))]
        self.order = (mp_order, layer_order)

        super().__init__(
            "GCN",
            dataset,
            num_layers,
            hidden,
            graph_level,
            layer=GCNConvQuant,
            activation=F.relu,
            mp_quantizers=mp_quantizers,
            layer_quantizers=layer_quantizers,
            graph_layer_quantizers=graph_layer_quantizers,
            mp_order=mp_order,
            layer_order=layer_order,
            large_graph=large_graph,
            sparse_check=sparse_check,
            **kwargs,
        )

    def new_quantizers(self, qtype, ste, momentum, device, q_group):
        """Build a fresh set of (non-)quantizers for a GCN model.

        Returns:
            ``(mp_quantizers, layer_quantizers, graph_layer_quantizers)`` where
            ``mp_quantizers`` is a one-element list of message-passing
            quantizers and the other two are ``[signed, unsigned]`` pairs.
            Every dict is a distinct object; every value is ``no_quantization()``.
            The arguments are currently unused.
        """

        def fresh(*keys):
            # One new no_quantization() per key, in key order.
            return {key: no_quantization() for key in keys}

        message_passing = fresh("message", "aggregate", "update_q")
        layer_sets = [
            fresh("inputs", "weights", "features", "norm"),  # signed
            fresh("inputs", "weights", "features", "norm"),  # unsigned
        ]
        graph_layer_sets = [
            fresh("inputs", "weights", "features"),  # signed
            fresh("inputs", "weights", "features"),  # unsigned
        ]
        return [message_passing], layer_sets, graph_layer_sets

    def reset_quantizers(self, qtype, ste, momentum, device, q_group=None):
        """Rebuild all quantizers and hand them to ``BatchNet.reset_quantizers``."""
        mp_q, layer_q, graph_layer_q = self.new_quantizers(qtype, ste, momentum, device, q_group)
        super().reset_quantizers(
            mp_quantizers=mp_q,
            layer_quantizers=layer_q,
            graph_layer_quantizers=graph_layer_q,
        )


class GAT(BatchNet):
    """Graph Attention Network assembled by ``BatchNet`` from ``GATConvQuant`` layers.

    Uses ``F.elu`` as the activation and forwards ``heads`` to ``BatchNet``.
    Every quantizer slot built here is ``no_quantization()``. The quantization
    arguments are accepted for interface compatibility but do not influence
    the quantizers produced in this class. ``use_wc`` is accepted but unused.
    """

    def __init__(
        self,
        # model param
        dataset,
        num_layers,
        hidden,
        graph_level,
        # quantization param
        device,
        qtype: int = 32,
        ste: bool = False,
        momentum: bool = False,
        heads: int = 8,
        # special param
        large_graph = False,
        use_wc=False,
        sparse_check=False,
        q_group=None,
        **kwargs
        ):
        mp_quantizers, layer_quantizers, graph_layer_quantizers = self.new_quantizers(
            qtype, ste, momentum, device, q_group
        )
        # Every layer uses quantizer set 0 for both message passing and the
        # layer itself (GAT only populates the signed layer set).
        order = [0] * num_layers
        self.order = (order, order)

        super(GAT, self).__init__(
            *("GAT", dataset, num_layers, hidden, graph_level),
            layer=GATConvQuant,
            activation=F.elu,
            heads=heads,
            mp_quantizers=mp_quantizers,
            layer_quantizers=layer_quantizers,
            graph_layer_quantizers=graph_layer_quantizers,
            mp_order=order,
            layer_order=order,
            large_graph=large_graph,
            sparse_check=sparse_check,
            **kwargs)

    def new_quantizers(self, qtype, ste, momentum, device, q_group=None):
        """Build a fresh set of (non-)quantizers for a GAT model.

        ``q_group`` is accepted (defaulting to ``None``) so that
        ``reset_quantizers`` can call this method with the same arguments as
        the sibling model classes. Previously the missing parameter made every
        ``reset_quantizers`` call raise ``TypeError``.

        Returns:
            ``(mp_quantizers, layer_quantizers, graph_layer_quantizers)``:
            a one-element list of message-passing quantizers, a
            ``[signed, unsigned]`` layer pair (the unsigned dict is empty), and
            a ``[signed, unsigned]`` graph-layer pair. All arguments are
            currently unused.
        """
        mp_quantizers = {
            "message": no_quantization(),
            "aggregate": no_quantization(),
            "update_q": no_quantization(),
        }
        signed_lq = {
            "inputs": no_quantization(),
            "weights": no_quantization(),
            "features": no_quantization(),
            "attention_l": no_quantization(),
            "attention_r": no_quantization(),
            "alpha": no_quantization(),
        }
        # Empty: the layer order above never selects index 1 for GAT.
        unsigned_lq = {}
        signed_glq = {
            "inputs": no_quantization(),
            "weights": no_quantization(),
            "features": no_quantization(),
        }
        unsigned_glq = {
            "inputs": no_quantization(),
            "weights": no_quantization(),
            "features": no_quantization(),
        }
        layer_quantizers = [signed_lq, unsigned_lq]
        graph_layer_quantizers = [signed_glq, unsigned_glq]
        return [mp_quantizers], layer_quantizers, graph_layer_quantizers

    def reset_quantizers(self, qtype, ste, momentum, device, q_group=None):
        """Rebuild all quantizers and hand them to ``BatchNet.reset_quantizers``."""
        mp_quantizers, layer_quantizers, graph_layer_quantizers = self.new_quantizers(
            qtype, ste, momentum, device, q_group
        )
        super().reset_quantizers(
            mp_quantizers=mp_quantizers,
            layer_quantizers=layer_quantizers,
            graph_layer_quantizers=graph_layer_quantizers,
        )

class GIN(BatchNet):
    """Graph Isomorphism Network assembled by ``BatchNet`` from ``GINConvQuant`` layers.

    Three MLP templates (first / intermediate / last layer) are built here and
    passed to ``BatchNet`` as ``nn``. Each template is a
    ``SequentialTableInput`` of ``LinearQuantized`` blocks.
    Every quantizer slot is ``no_quantization()``. The quantization arguments
    and ``use_wc`` are accepted but do not influence what is built here.
    """

    def __init__(
        self,
        # model param
        dataset,
        num_layers,
        hidden,
        graph_level,
        # quantization param
        device,
        qtype: int = 32,
        ste: bool = False,
        momentum: bool = False,
        # special param
        large_graph=False,
        use_wc=False,
        sparse_check=False,
        q_group=None,
        **kwargs
    ):
        mp_quantizers, layer_quantizers, graph_layer_quantizers = self.new_quantizers(
            qtype, ste, momentum, device, q_group
        )

        # Quantizer-set index per layer: 0 for the first layer, 1 afterwards.
        first_then_rest = [0] + [1] * (num_layers - 1)
        mp_order = list(first_then_rest)
        layer_order = list(first_then_rest)
        self.order = (mp_order, layer_order)

        def linear(in_dim, out_dim, signed):
            return LinearQuantized(in_dim, out_dim, signed=signed, layer_quantizers=layer_quantizers)

        # Modules are constructed in exactly the same order as before, so
        # parameter initialisation consumes the RNG identically.
        if graph_level:
            # BatchNorm inside the MLP (the original author left this as an open
            # question), cf.
            # https://github.com/pyg-team/pytorch_geometric/blob/2.0.2/examples/mutag_gin.py
            def graph_mlp(first_signed):
                return SequentialTableInput([
                    linear(hidden, hidden, first_signed),
                    BN(hidden),
                    ReLU(),
                    linear(hidden, hidden, False),
                    ReLU(),
                ])

            nn = [
                graph_mlp(True),   # first layer
                graph_mlp(False),  # intermediate layers
                graph_mlp(False),  # last layer
            ]
        else:
            def node_mlp(in_dim, out_dim, signed):
                return SequentialTableInput([linear(in_dim, out_dim, signed), ReLU()])

            # ogbn-proteins is special-cased with hard-coded 8 inputs / 112 outputs.
            is_proteins = dataset == "ogbn-proteins"
            nn = [
                node_mlp(8 if is_proteins else dataset.num_features, hidden, True),  # first layer
                node_mlp(hidden, hidden, False),  # intermediate layers
                # NOTE(review): the output projection is followed by ReLU, which
                # clamps negative class scores to 0 -- confirm this is intended.
                node_mlp(hidden, 112 if is_proteins else dataset.num_classes, False),  # last layer
            ]

        super().__init__(
            "GIN",
            dataset,
            num_layers,
            hidden,
            graph_level,
            nn=nn,
            layer=GINConvQuant,
            activation=F.relu,
            mp_quantizers=mp_quantizers,
            layer_quantizers=layer_quantizers,
            graph_layer_quantizers=graph_layer_quantizers,
            mp_order=mp_order,
            layer_order=layer_order,
            large_graph=large_graph,
            sparse_check=sparse_check,
            **kwargs,
        )

    def new_quantizers(self, qtype, ste, momentum, device, q_group):
        """Build a fresh set of (non-)quantizers for a GIN model.

        Returns ``(mp_quantizers, layer_quantizers, graph_layer_quantizers)``,
        each a two-element list whose two entries are the *same* dict object.
        All values are ``no_quantization()``; the arguments are unused.
        """
        mp = {key: no_quantization() for key in ("message", "aggregate", "update_q")}
        lq = {key: no_quantization() for key in ("inputs", "weights", "features")}
        glq = {key: no_quantization() for key in ("inputs", "weights", "features")}
        # List repetition aliases the dict, matching the original [d, d] lists.
        return [mp] * 2, [lq] * 2, [glq] * 2

    def reset_quantizers(self, qtype, ste, momentum, device, q_group=None):
        """Rebuild all quantizers and hand them to ``BatchNet.reset_quantizers``."""
        fresh_mp, fresh_layer, fresh_graph_layer = self.new_quantizers(qtype, ste, momentum, device, q_group)
        super().reset_quantizers(
            mp_quantizers=fresh_mp,
            layer_quantizers=fresh_layer,
            graph_layer_quantizers=fresh_graph_layer,
        )

class GraphSAGE(BatchNet):
    """GraphSAGE network assembled by ``BatchNet`` from ``SAGEConvQuant`` layers.

    Every quantizer slot built here is ``no_quantization()``. The quantization
    arguments are accepted for interface compatibility but do not influence
    the quantizers produced in this class. ``use_wc`` is accepted but unused.
    """

    def __init__(
        self,
        # model param
        dataset,
        num_layers,
        hidden,
        graph_level,
        # quantization param
        device,
        qtype: int = 32,
        ste: bool = False,
        momentum: bool = False,
        # special param
        large_graph = False,
        use_wc=False,
        sparse_check=False,
        q_group=None,
        **kwargs
        ):
        mp_quantizers, layer_quantizers, graph_layer_quantizers = self.new_quantizers(
            qtype, ste, momentum, device, q_group
        )

        # Message passing always uses set 0; every layer uses layer set 1
        # (both layer slots hold the same dict, see new_quantizers).
        mp_order = [0] * num_layers
        layer_order = [1] * num_layers
        self.order = (mp_order, layer_order)

        super(GraphSAGE, self).__init__(
            *("GraphSAGE", dataset, num_layers, hidden, graph_level),
            layer=SAGEConvQuant,
            activation=F.relu,
            mp_quantizers=mp_quantizers,
            layer_quantizers=layer_quantizers,
            graph_layer_quantizers=graph_layer_quantizers,
            mp_order=mp_order,
            layer_order=layer_order,
            large_graph=large_graph,
            sparse_check=sparse_check,
            **kwargs)

    def new_quantizers(self, qtype, ste, momentum, device, q_group):
        """Build a fresh set of (non-)quantizers for a GraphSAGE model.

        Returns:
            ``(mp_quantizers, layer_quantizers, graph_layer_quantizers)``:
            a one-element list of message-passing quantizers, and two-element
            layer / graph-layer lists whose entries are the same dict object.
            All values are ``no_quantization()``; the arguments are unused.
        """
        mp_quantizers = {
            "message": no_quantization(),
            "aggregate": no_quantization(),
            "update_q": no_quantization(),
        }
        lq = {
            "inputs": no_quantization(),
            "inputs_r": no_quantization(),
            "pool": no_quantization(),
            "weights": no_quantization(),
            "weights_r": no_quantization(),
            "features": no_quantization(),
            "root_features": no_quantization(),
            "outputs": no_quantization(),
        }
        glq = {
            "inputs": no_quantization(),
            "weights": no_quantization(),
            "features": no_quantization(),
        }
        layer_quantizers = [lq, lq]
        graph_layer_quantizers = [glq, glq]
        return [mp_quantizers], layer_quantizers, graph_layer_quantizers

    def reset_quantizers(self, qtype, ste, momentum, device, q_group=None):
        """Rebuild all quantizers and hand them to ``BatchNet.reset_quantizers``.

        ``q_group`` now defaults to ``None``, like every other model in this
        module, so callers can reset any model with the same positional call.
        """
        mp_quantizers, layer_quantizers, graph_layer_quantizers = self.new_quantizers(
            qtype, ste, momentum, device, q_group
        )
        super().reset_quantizers(
            mp_quantizers=mp_quantizers,
            layer_quantizers=layer_quantizers,
            graph_layer_quantizers=graph_layer_quantizers,
        )

