import torch
from graphgym.config import cfg
from graphgym.register import register_layer
from torch import Tensor
from torch_geometric.utils import degree

from .gatv3 import GeneralGATConv, GATv1Layer, GATv2Layer
from .gatv3.gat_ansatz import GATAnsatzLayer
from .gatv3.gatpna import GATPNAv1Layer, GATPNAv2Layer
from .gatv3.gingatv1 import GINGATv1Layer


class GATv1(GeneralGATConv):
    """GraphGym wrapper building a ``GeneralGATConv`` around ``GATv1Layer``.

    Every attention hyper-parameter is taken from the global ``cfg.gnn``
    section and written into ``kwargs`` (overwriting any caller value).
    The ``bias`` argument is accepted only for interface compatibility: the
    parent constructor always receives ``True``.
    """

    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        num_heads = cfg.gnn.heads
        shared = cfg.gnn.share_weights
        kwargs.update(
            share_weights_value=shared,
            share_weights_score=shared,
            convolve=cfg.gnn.convolve,
            lambda_policy=cfg.gnn.lambda_policy,
            heads=num_heads,
            gcn_mode=cfg.gnn.gcn_mode,
        )
        # Per-head output width (floor division; any remainder is dropped).
        # NOTE(review): presumably GeneralGATConv concatenates the heads back
        # to ~dim_out -- confirm.
        per_head_dim = dim_out // num_heads
        super().__init__(GATv1Layer, dim_in, per_head_dim, True, **kwargs)


register_layer('gatv1', GATv1)


class GINGATv1(GeneralGATConv):
    """GraphGym wrapper building a ``GeneralGATConv`` around ``GINGATv1Layer``.

    Attention settings and the GIN ``eps`` / ``train_eps`` options are read
    from ``cfg.gnn`` and written into ``kwargs`` (overwriting caller values).
    ``bias`` is accepted for interface compatibility but the parent always
    receives ``True``.
    """

    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        gnn_cfg = cfg.gnn
        overrides = {
            'share_weights_value': gnn_cfg.share_weights,
            'share_weights_score': gnn_cfg.share_weights,
            'convolve': gnn_cfg.convolve,
            'lambda_policy': gnn_cfg.lambda_policy,
            'heads': gnn_cfg.heads,
            'gcn_mode': gnn_cfg.gcn_mode,
            'eps': gnn_cfg.eps,
            'train_eps': gnn_cfg.train_eps,
        }
        kwargs.update(overrides)

        # Output width per attention head (floor division).
        head_width = dim_out // gnn_cfg.heads
        super().__init__(GINGATv1Layer, dim_in, head_width, True, **kwargs)


register_layer('gingatv1', GINGATv1)


class GATv2(GeneralGATConv):
    """GraphGym wrapper building a ``GeneralGATConv`` around ``GATv2Layer``.

    All attention hyper-parameters come from ``cfg.gnn`` and overwrite any
    matching entries in ``kwargs``.  The ``bias`` argument is ignored; the
    parent constructor is always called with ``bias=True``.
    """

    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        settings = cfg.gnn
        kwargs['share_weights_value'] = kwargs['share_weights_score'] = settings.share_weights
        kwargs['convolve'] = settings.convolve
        kwargs['lambda_policy'] = settings.lambda_policy
        kwargs['heads'] = settings.heads
        kwargs['gcn_mode'] = settings.gcn_mode

        super().__init__(
            GATv2Layer,
            dim_in,
            dim_out // settings.heads,  # width of each attention head
            True,  # bias is always enabled, regardless of the argument
            **kwargs,
        )


register_layer('gatv2', GATv2)


class GATAnsatz(GeneralGATConv):
    """GraphGym wrapper building a ``GeneralGATConv`` around ``GATAnsatzLayer``.

    Reads attention settings from ``cfg.gnn`` and CSBM-related parameters
    (``version``, ``negative_slope``, ``mu_norm``, ``d``, ``p``, ``q``) from
    ``cfg.bcsbm``.  Self-loops are always disabled and score weights are
    always shared.

    ``cfg.bcsbm.version`` selects a preset:

    * ``'v1'``: ``negative_slope=0.2``, ``bias=False``
    * ``'v2'``: ``negative_slope=0.01``, ``bias=True``
    * anything else: ``negative_slope=cfg.bcsbm.negative_slope`` and the
      caller-supplied ``bias`` is kept as-is.
    """

    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        gnn = cfg.gnn
        version = cfg.bcsbm.version

        kwargs.update(
            share_weights_value=gnn.share_weights,
            # Fixed to True; the cfg.gnn.share_weights lookup used for the
            # value weights is intentionally not applied here (it was
            # commented out in the original).
            share_weights_score=True,
            convolve=gnn.convolve,
            lambda_policy=gnn.lambda_policy,
            heads=gnn.heads,
            gcn_mode=gnn.gcn_mode,
            add_self_loops=False,
            version=version,
        )
        head_width = dim_out // gnn.heads

        if version == 'v1':
            slope, bias = 0.2, False
        elif version == 'v2':
            slope, bias = 0.01, True
        else:
            # Custom version: slope from config, bias left as passed in.
            slope = cfg.bcsbm.negative_slope
        kwargs['negative_slope'] = slope

        kwargs.update(
            use_ansatz=gnn.fixed_ansatz,
            use_partial_ansatz=gnn.partial_ansatz,
            add_attn_info=gnn.add_attn_info,
            mu_norm=cfg.bcsbm.mu_norm,
            d=cfg.bcsbm.d,
            p=cfg.bcsbm.p,
            q=cfg.bcsbm.q,
        )

        super().__init__(GATAnsatzLayer, dim_in, head_width, bias, **kwargs)


register_layer('gatansatz', GATAnsatz)


class GATPNAv1(GeneralGATConv):
    """GraphGym wrapper building a ``GeneralGATConv`` around ``GATPNAv1Layer``.

    In addition to the attention settings from ``cfg.gnn`` (including the PNA
    ``aggregators`` and ``scalers``), the layer receives ``deg``: the in-degree
    histogram computed from ``datasets[0]``.  ``bias`` is accepted for
    interface compatibility but the parent always receives ``True``.
    ``dim_out`` is forwarded unchanged (it is not divided by
    ``cfg.gnn.heads``).
    """

    def __init__(self, dim_in, dim_out, datasets, bias=False, **kwargs):
        # NOTE(review): datasets[0] is presumably the training loader -- confirm.
        kwargs['deg'] = GATPNAv1.get_degree_histogram(datasets[0])

        kwargs['aggregators'] = cfg.gnn.aggregators
        kwargs['scalers'] = cfg.gnn.scalers
        kwargs['share_weights_value'] = cfg.gnn.share_weights
        kwargs['share_weights_score'] = cfg.gnn.share_weights
        kwargs['convolve'] = cfg.gnn.convolve
        kwargs['lambda_policy'] = cfg.gnn.lambda_policy
        kwargs['heads'] = cfg.gnn.heads
        kwargs['gcn_mode'] = cfg.gnn.gcn_mode

        bias = True

        super(GATPNAv1, self).__init__(GATPNAv1Layer, dim_in, dim_out, bias, **kwargs)

    @staticmethod
    def get_degree_histogram(loader) -> Tensor:
        """Return the in-degree histogram over all graphs yielded by *loader*.

        Entry ``k`` counts the nodes that are the target (``edge_index[1]``)
        of exactly ``k`` edges.  The result is a ``torch.long`` tensor of length
        ``max_degree + 1`` (at least 1, so an empty loader gives ``tensor([0])``).

        The loader is traversed exactly once, and the histogram grows whenever
        a graph with a larger maximum degree appears.  This makes the method
        work with single-use iterables and with loaders whose graphs differ
        between iterations.  Graphs with zero nodes are handled without error.
        """
        deg_histogram = torch.zeros(1, dtype=torch.long)
        for data in loader:
            d = degree(data.edge_index[1], num_nodes=data.num_nodes,
                       dtype=torch.long)
            # minlength guarantees counts is at least as long as the histogram;
            # for an empty ``d`` this is simply zeros of that length.
            counts = torch.bincount(d, minlength=deg_histogram.numel())
            if counts.numel() > deg_histogram.numel():
                # New maximum degree: fold the running totals into the longer
                # tensor and adopt it as the histogram.
                counts[:deg_histogram.numel()] += deg_histogram
                deg_histogram = counts
            else:
                deg_histogram += counts

        return deg_histogram


register_layer('gatpnav1', GATPNAv1)


class GATPNAv2(GeneralGATConv):
    """GraphGym wrapper building a ``GeneralGATConv`` around ``GATPNAv2Layer``.

    In addition to the attention settings from ``cfg.gnn`` (including the PNA
    ``aggregators`` and ``scalers``), the layer receives ``deg``: the in-degree
    histogram computed from ``datasets[0]``.  ``bias`` is accepted for
    interface compatibility but the parent always receives ``True``.
    ``dim_out`` is forwarded unchanged (it is not divided by
    ``cfg.gnn.heads``).
    """

    def __init__(self, dim_in, dim_out, datasets, bias=False, **kwargs):
        # NOTE(review): datasets[0] is presumably the training loader -- confirm.
        kwargs['deg'] = GATPNAv2.get_degree_histogram(datasets[0])

        kwargs['aggregators'] = cfg.gnn.aggregators
        kwargs['scalers'] = cfg.gnn.scalers
        kwargs['share_weights_value'] = cfg.gnn.share_weights
        kwargs['share_weights_score'] = cfg.gnn.share_weights
        kwargs['convolve'] = cfg.gnn.convolve
        kwargs['lambda_policy'] = cfg.gnn.lambda_policy
        kwargs['heads'] = cfg.gnn.heads
        kwargs['gcn_mode'] = cfg.gnn.gcn_mode

        bias = True

        super(GATPNAv2, self).__init__(GATPNAv2Layer, dim_in, dim_out, bias, **kwargs)

    @staticmethod
    def get_degree_histogram(loader) -> Tensor:
        """Return the in-degree histogram over all graphs yielded by *loader*.

        Entry ``k`` counts the nodes that are the target (``edge_index[1]``)
        of exactly ``k`` edges.  The result is a ``torch.long`` tensor of length
        ``max_degree + 1`` (at least 1, so an empty loader gives ``tensor([0])``).

        The loader is traversed exactly once, and the histogram grows whenever
        a graph with a larger maximum degree appears.  This makes the method
        work with single-use iterables and with loaders whose graphs differ
        between iterations.  Graphs with zero nodes are handled without error.
        """
        deg_histogram = torch.zeros(1, dtype=torch.long)
        for data in loader:
            d = degree(data.edge_index[1], num_nodes=data.num_nodes,
                       dtype=torch.long)
            # minlength guarantees counts is at least as long as the histogram;
            # for an empty ``d`` this is simply zeros of that length.
            counts = torch.bincount(d, minlength=deg_histogram.numel())
            if counts.numel() > deg_histogram.numel():
                # New maximum degree: fold the running totals into the longer
                # tensor and adopt it as the histogram.
                counts[:deg_histogram.numel()] += deg_histogram
                deg_histogram = counts
            else:
                deg_histogram += counts

        return deg_histogram


register_layer('gatpnav2', GATPNAv2)
