"""
Main structures of the full-batch and mini-batch models.
"""
import torch
from torch_geometric.nn import global_add_pool, global_mean_pool
from torch_sparse import SparseTensor
from torch import Tensor
from torch.nn import Identity, BatchNorm1d as BN
import torch.nn.functional as F

from .quantization.linear_quantized import LinearQuantized
from tqdm import tqdm
import copy

        

def channel_find(dataset, graph_level, hidden, heads=1):
    """Work out the channel sizes used to build the network.

    Args:
        dataset: either the literal string ``"ogbn-proteins"`` or an object
            that may expose ``num_classes`` and ``num_features``.
        graph_level: truthy for graph-level tasks.
        hidden: hidden width of the model.
        heads: multiplier applied to ``hidden`` for the first channel of
            graph-level models.

    Returns:
        A tuple ``(first_channel, num_features, last_channel, num_classes)``.
    """
    # Hard-coded sizes for this one dataset.
    if dataset == "ogbn-proteins":
        return 8, 8, 112, 112

    if not hasattr(dataset, "num_classes"):
        # Fallback for datasets without class metadata (IMDB-BINARY).
        final_width = hidden
        in_dim, out_dim = 1, 2
    else:
        final_width = hidden if graph_level else dataset.num_classes
        in_dim = dataset.num_features
        out_dim = dataset.num_classes

    # Graph-level models start from the embedded width, node-level models
    # start from the raw feature width.
    if graph_level:
        entry_width = hidden * heads
    else:
        entry_width = in_dim
    return entry_width, in_dim, final_width, out_dim



class BatchNet(torch.nn.Module):
    """Quantized GNN wrapper for node-level and graph-level tasks.

    The message-passing stack is registered as ``atts`` when ``model_name``
    contains ``"GAT"`` and as ``convs`` otherwise. Graph-level models also
    own an input embedding (``mlp``), one BatchNorm per layer (``bns``) and a
    two-layer readout (``mlp1`` / ``mlp2``). Models built with
    ``large_graph=True`` and ``sparse_check=False`` may own linear skip
    layers (``skips``).
    """

    def __init__(
        self,
        # model param
        model_name,
        dataset,
        num_layers,
        hidden,
        graph_level,
        # model arch
        layer,
        activation,
        # quantization params
        mp_quantizers,
        layer_quantizers,
        graph_layer_quantizers,
        mp_order=None,
        layer_order=None,
        # special params
        nn=None,
        heads=1,
        large_graph=False,
        sparse_check=False,  # this is for checking the sanity of sparse tensor.
        residual=False,
        dropout=0.5,
        skips=True,
    ):
        """Build the layer stack and the task-specific extra modules.

        Args:
            model_name: model identifier; substrings ``"GIN"`` and ``"GAT"``
                select the construction path and the attribute name of the
                layer stack.
            dataset: passed to ``channel_find``; for graph-level models
                ``dataset[0].pos`` is also inspected.
            num_layers: number of message-passing layers.
            hidden: hidden width.
            graph_level: truthy for graph-level tasks.
            layer: callable that constructs one message-passing layer.
            activation: callable applied after each activated layer.
            mp_quantizers: container indexed by the entries of ``mp_order``.
            layer_quantizers: container indexed by the entries of
                ``layer_order`` for non-GIN layers; passed whole to the
                embedding and skip linear layers.
            graph_layer_quantizers: passed to the graph-level readout layers.
            mp_order: required sequence of length ``num_layers``.
            layer_order: required sequence of length ``num_layers``.
            nn: for GIN models, an indexable with the networks of the first
                (``nn[0]``), intermediate (``nn[1]``, deep-copied per layer)
                and last (``nn[2]``) layers.
            heads: head count forwarded to non-GIN layers.
            large_graph: truthy when inputs arrive as sampled mini-batches.
            sparse_check: truthy to run the mini-batch path without skip
                layers.
            residual: add the previous activation to each activated output.
            dropout: dropout probability used in ``forward``.
            skips: request linear skip layers; they are only built when
                ``large_graph`` is set and ``sparse_check`` is not.

        Raises:
            TypeError: if ``mp_order`` or ``layer_order`` is ``None``.
            AssertionError: if either order has a length other than
                ``num_layers``.
        """
        super(BatchNet, self).__init__()

        # Both orders are indexed below; despite their ``None`` defaults they
        # are mandatory. Fail with a readable message instead of ``len(None)``.
        if mp_order is None or layer_order is None:
            raise TypeError(
                "mp_order and layer_order must be sequences of length num_layers"
            )
        assert len(mp_order) == num_layers
        assert len(layer_order) == num_layers

        self.model_name = model_name
        self.graph_level = graph_level
        self.num_layers = num_layers
        self.act = activation
        self.mp_order = mp_order
        self.layer_order = layer_order
        self.large_graph = large_graph
        self.sparse_check = sparse_check
        self.residual = residual
        self.dropout = dropout
        self.heads = heads
        # Holds the boolean flag until (and unless) skip layers are built
        # below, at which point it is replaced by a ModuleList.
        self.skips = skips

        first_channel, num_features, last_channel, num_classes = channel_find(
            dataset, graph_level, hidden, heads=heads)

        layers = torch.nn.ModuleList()
        if "GIN" in self.model_name:
            layers.append(layer(nn[0], train_eps=True,
                        mp_quantizers=mp_quantizers[self.mp_order[0]]))  # first layer
            for i in range(num_layers - 2):  # intermediate layers
                layers.append(layer(copy.deepcopy(nn[1]), train_eps=True,
                        mp_quantizers=mp_quantizers[self.mp_order[i+1]]))
            layers.append(layer(nn[2], train_eps=True,
                        mp_quantizers=mp_quantizers[self.mp_order[-1]]))  # last layer
        else:
            layers.append(layer(first_channel, hidden, heads=heads,
                        mp_quantizers=mp_quantizers[self.mp_order[0]],
                        layer_quantizers=layer_quantizers[self.layer_order[0]]))  # first layer
            for i in range(num_layers - 2):  # intermediate layers
                layers.append(layer(hidden*heads, hidden, heads=heads,
                        mp_quantizers=mp_quantizers[self.mp_order[i+1]],
                        layer_quantizers=layer_quantizers[self.layer_order[i+1]]))
            last_heads = heads if graph_level else 1
            layers.append(layer(hidden*heads, last_channel, heads=last_heads,
                        mp_quantizers=mp_quantizers[self.mp_order[-1]],
                        layer_quantizers=layer_quantizers[self.layer_order[-1]]))  # last layer

        # The attribute name is part of the state-dict keys, so keep it.
        if "GAT" in self.model_name:
            self.atts = layers
        else:
            self.convs = layers

        # other layer components in graph-level tasks
        if self.graph_level:
            # linear layers; positions (when present) are concatenated to the
            # node features in ``forward`` and widen the embedding input.
            first_graph = dataset[0]
            add_size = first_graph.pos.size(1) if first_graph.pos is not None else 0
            self.mlp = LinearQuantized(num_features+add_size, hidden*heads, signed=True, layer_quantizers=layer_quantizers)
            self.mlp1 = LinearQuantized(hidden*heads, hidden, signed=True, layer_quantizers=graph_layer_quantizers)
            self.mlp2 = LinearQuantized(hidden, num_classes, signed=False, layer_quantizers=graph_layer_quantizers)

            self.bns = torch.nn.ModuleList()
            for _ in range(self.num_layers):
                self.bns.append(BN(hidden*heads))

        if self.large_graph and not self.sparse_check:
            # skip - linear layers
            if self.skips:
                self.skips = torch.nn.ModuleList()
                self.skips.append(LinearQuantized(num_features, hidden*heads, signed=True, layer_quantizers=layer_quantizers))
                for _ in range(num_layers - 2):
                    self.skips.append(
                        LinearQuantized(hidden*heads, hidden*heads, signed=True, layer_quantizers=layer_quantizers))
                self.skips.append(LinearQuantized(hidden*heads, num_classes, signed=True, layer_quantizers=layer_quantizers))

    def _conv_layers(self):
        """Return the layer stack (``atts`` for GAT models, else ``convs``)."""
        return self.atts if "GAT" in self.model_name else self.convs

    def _skip_layers(self):
        """Return the skip-layer ModuleList, or ``None`` if none was built.

        ``self.skips`` is still the plain boolean flag whenever the
        constructor did not build skip layers, so truthiness alone does not
        tell whether it can be indexed.
        """
        skips = self.skips
        return skips if isinstance(skips, torch.nn.ModuleList) else None

    def reset_parameters(self):
        """Re-initialise every learnable sub-module, including skip layers."""
        for layer in self._conv_layers():
            layer.reset_parameters()
        if self.graph_level:
            self.mlp.reset_parameters()
            self.mlp1.reset_parameters()
            self.mlp2.reset_parameters()
            for bn in self.bns:
                bn.reset_parameters()
        # Skip layers are trained too; without this their weights would carry
        # over from one run to the next.
        skip_layers = self._skip_layers()
        if skip_layers is not None:
            for skip in skip_layers:
                skip.reset_parameters()

    def reset_quantizers(self, mp_quantizers, layer_quantizers, graph_layer_quantizers):
        """Hand new quantizers to the layer stack and graph-level linears.

        Args:
            mp_quantizers: container indexed by the entries of ``mp_order``.
            layer_quantizers: indexed by ``layer_order`` for non-GIN layers;
                passed whole to GIN layers and to ``mlp``.
            graph_layer_quantizers: passed to ``mlp1`` and ``mlp2``.
        """
        # NOTE(review): this tests ``model_name == "GIN"`` while __init__
        # tests ``"GIN" in model_name``; confirm which is intended.
        # NOTE(review): skip layers are not given new quantizers here;
        # confirm whether that is intended.
        for i, layer in enumerate(self._conv_layers()):
            if self.model_name == "GIN":
                layer.reset_quantizers(mp_quantizers[self.mp_order[i]], layer_quantizers)
            else:
                layer.reset_quantizers(mp_quantizers[self.mp_order[i]], layer_quantizers[self.layer_order[i]])
        if self.graph_level:
            self.mlp.reset_quantizers(layer_quantizers=layer_quantizers)
            self.mlp1.reset_quantizers(layer_quantizers=graph_layer_quantizers)
            self.mlp2.reset_quantizers(layer_quantizers=graph_layer_quantizers)

    def freeze_quantization_parameters(self):
        """Freeze quantization parameters of the layers and graph linears."""
        # NOTE(review): skip layers are not frozen here; confirm intent.
        for layer in self._conv_layers():
            layer.freeze_quantization_parameters()
        if self.graph_level:
            self.mlp.freeze_quantization_parameters()
            self.mlp1.freeze_quantization_parameters()
            self.mlp2.freeze_quantization_parameters()

    def unfreeze_quantization_parameters(self):
        """Unfreeze quantization parameters of the layers and graph linears."""
        # NOTE(review): skip layers are not unfrozen here; confirm intent.
        for layer in self._conv_layers():
            layer.unfreeze_quantization_parameters()
        if self.graph_level:
            self.mlp.unfreeze_quantization_parameters()
            self.mlp1.unfreeze_quantization_parameters()
            self.mlp2.unfreeze_quantization_parameters()

    def forward(self, data, q_group, sigmoid=False):
        """Run the model on one input.

        Args:
            data: one of
                * an object containing ``"batch"`` and exposing ``x``,
                  ``edge_index``, ``batch`` and ``pos`` (graph-level path);
                * an ``(x, adjs)`` pair when ``large_graph`` or
                  ``sparse_check`` is set, where ``adjs[i][0]`` is fed to
                  layer ``i``;
                * otherwise an object exposing ``x`` and ``edge_index``.
            q_group: quantization argument forwarded to the layers. If it is
                a tuple, element ``i`` goes to layer ``i`` and the last
                element to the input embedding; otherwise it is shared.
            sigmoid: when truthy the raw outputs are returned (no sigmoid is
                applied here); otherwise ``log_softmax`` over the last
                dimension is returned.

        Returns:
            The output tensor described under ``sigmoid``.
        """
        # PRE-PROCESSING
        # Graph-level models activate every layer; node-level models leave
        # the final layer un-activated.
        last_act = self.num_layers if self.graph_level else self.num_layers-1
        mini_batch = self.large_graph or self.sparse_check
        if not isinstance(data, Tensor) and "batch" in data:  # graph-level task
            x, edge_index, batch = data.x, data.edge_index, data.batch
            if data.pos is not None:
                x = torch.cat((x, data.pos), dim=1)
            q = q_group[-1] if isinstance(q_group, tuple) else q_group
            x = self.mlp(x, q_group=q)  # embedding
        else:  # node-level task
            if mini_batch:
                x, adjs = data  # adjs: (edge_index, e_id, size)
            else:  # full-batch
                x, edge_index, batch, adjs = data.x, data.edge_index, None, None

        layers = self._conv_layers()
        # Skip layers only exist when the constructor built them; the bare
        # ``skips=True`` flag must not be indexed.
        skip_layers = self._skip_layers() if self.large_graph else None
        x_res = x if self.graph_level else 0

        # FORWARD PASS
        for i, layer in enumerate(layers):
            edge = adjs[i][0] if mini_batch else edge_index
            q = q_group[i] if isinstance(q_group, tuple) else q_group
            if skip_layers is not None:
                x = layer(x, edge, q) + skip_layers[i](x, q_group=q)
            else:
                x = layer(x, edge, q)
            if self.graph_level:
                x = self.bns[i](x)
            if i < last_act:
                x = self.act(x)
                if self.residual:
                    x = x + x_res
                x = F.dropout(x, p=self.dropout, training=self.training)
                x_res = x if self.graph_level else 0

        if self.graph_level:
            x = global_mean_pool(x, batch)
            x = F.relu(self.mlp1(x))
            x = self.mlp2(x)

        return x if sigmoid else F.log_softmax(x, dim=-1)

    def inference(self, x_all, test_loader, q_group, device):
        """Compute outputs for all nodes, one layer at a time.

        Only called for the final evaluation of large-graph node property
        tasks. Representations are computed layer by layer over every batch
        of ``test_loader``, which is faster than computing the final
        representation of each batch immediately.

        Args:
            x_all: input features of all nodes; indexed by ``n_id``.
            test_loader: iterable yielding ``(batch_size, n_id, adjs)`` where
                ``adjs.to(device)`` unpacks to three items, the first being
                the edge input of the layer.
            q_group: quantization argument; a tuple is indexed per layer.
                When its first entry is not ``None`` its first three entries
                are sliced with ``n_id``.
            device: device the layers are evaluated on.

        Returns:
            The concatenated (CPU) outputs of the last layer.
        """
        layers = self._conv_layers()
        skip_layers = self._skip_layers()

        # The context manager closes the bar even if a batch raises.
        with tqdm(total=x_all.size(0) * self.num_layers) as pbar:
            pbar.set_description('Test')

            for i in range(self.num_layers):
                q = q_group[i] if isinstance(q_group, tuple) else q_group
                xs = []
                for batch_size, n_id, adjs in test_loader:
                    edge_index, _, _ = adjs.to(device)
                    x = x_all[n_id].to(device)
                    q_input = [q[0][n_id], q[1][n_id], q[2][n_id]] if q[0] is not None else q
                    if skip_layers is not None:
                        x = layers[i](x, edge_index, q_input) + skip_layers[i](x, q_group=q_input)
                    else:
                        x = layers[i](x, edge_index, q_input)
                    if i != self.num_layers - 1:
                        x = self.act(x)

                    # Only the first ``batch_size`` rows are kept per batch.
                    xs.append(x[:batch_size].cpu())
                    pbar.update(batch_size)
                x_all = torch.cat(xs, dim=0)

        return x_all
    