from graphgym.contrib.layer.gatv3 import GeneralGATConv, GATv1Layer, GATv2Layer, GATAnsatzLayer
from graphgym.config import cfg
from graphgym.register import register_layer


class GATv1(GeneralGATConv):
    """GraphGym layer wrapping ``GATv1Layer``, configured from ``cfg.gnn``.

    Args:
        dim_in: Input feature dimension.
        dim_out: Requested output dimension. It is floor-divided by
            ``cfg.gnn.heads`` and the quotient is passed to the base class
            as the per-head width. A non-divisible value is silently
            truncated.
        bias: Accepted for interface compatibility but ignored. The layer
            is always constructed with ``bias=True``.
        **kwargs: Forwarded to ``GeneralGATConv``. The config-driven keys
            set below overwrite any same-named entries supplied by the caller.
    """

    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        gnn_cfg = cfg.gnn
        num_heads = gnn_cfg.heads

        # Score and value weight sharing both follow the single
        # cfg.gnn.share_weights flag.
        kwargs.update(
            share_weights_value=gnn_cfg.share_weights,
            share_weights_score=gnn_cfg.share_weights,
            convolve=gnn_cfg.convolve,
            lambda_policy=gnn_cfg.lambda_policy,
            heads=num_heads,
            gcn_mode=gnn_cfg.gcn_mode,
        )

        # Per-head width. Presumably the heads are concatenated in
        # GeneralGATConv so the total returns to ~dim_out -- TODO confirm.
        per_head_dim = dim_out // num_heads

        # NOTE(review): the caller's ``bias`` is discarded; bias is forced on.
        super().__init__(GATv1Layer, dim_in, per_head_dim, True, **kwargs)


# Register GATv1 in GraphGym's layer registry under the key 'gatv1'.
register_layer('gatv1', GATv1)


class GATv2(GeneralGATConv):
    """GraphGym layer wrapping ``GATv2Layer``, configured from ``cfg.gnn``.

    Args:
        dim_in: Input feature dimension.
        dim_out: Requested output dimension. The base class receives
            ``dim_out // cfg.gnn.heads`` as its per-head width, so any
            remainder is dropped.
        bias: Ignored. ``GeneralGATConv`` is always given ``bias=True``.
        **kwargs: Extra options for ``GeneralGATConv``. Config-derived
            entries take precedence over caller-supplied ones with the
            same key.
    """

    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        gnn = cfg.gnn

        # Options read from the global GraphGym config.
        overrides = {
            'share_weights_value': gnn.share_weights,
            'share_weights_score': gnn.share_weights,
            'convolve': gnn.convolve,
            'lambda_policy': gnn.lambda_policy,
            'heads': gnn.heads,
            'gcn_mode': gnn.gcn_mode,
        }
        kwargs.update(overrides)

        # Width of each attention head (floor division).
        head_dim = dim_out // gnn.heads

        # NOTE(review): the incoming ``bias`` argument is not used here.
        super().__init__(GATv2Layer, dim_in, head_dim, True, **kwargs)


# Register GATv2 in GraphGym's layer registry under the key 'gatv2'.
register_layer('gatv2', GATv2)


class GATAnsatz(GeneralGATConv):
    """GraphGym layer wrapping ``GATAnsatzLayer``.

    Options are read from ``cfg.gnn`` (attention/GNN settings) and
    ``cfg.bcsbm`` (version preset and the ``mu_norm``/``d``/``p``/``q``
    parameters forwarded to the layer).

    Version presets (``cfg.bcsbm.version``):
        * ``'v1'``: ``negative_slope=0.2`` and ``bias=False``.
        * ``'v2'``: ``negative_slope=0.01`` and ``bias=True``.
        * anything else: ``negative_slope=cfg.bcsbm.negative_slope``, with
          ``bias`` left as passed by the caller.

    Args:
        dim_in: Input feature dimension.
        dim_out: Requested output dimension. It is floor-divided by
            ``cfg.gnn.heads`` to give the per-head width.
        bias: Used only when the version is neither ``'v1'`` nor ``'v2'``.
        **kwargs: Forwarded to ``GeneralGATConv``. Config-derived keys
            overwrite caller-supplied ones with the same name.
    """

    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        gnn_cfg = cfg.gnn
        sbm_cfg = cfg.bcsbm

        # Score weights are always shared here, independent of
        # cfg.gnn.share_weights. Value-weight sharing still follows the config.
        kwargs.update(
            share_weights_value=gnn_cfg.share_weights,
            share_weights_score=True,
            convolve=gnn_cfg.convolve,
            lambda_policy=gnn_cfg.lambda_policy,
            heads=gnn_cfg.heads,
            gcn_mode=gnn_cfg.gcn_mode,
        )
        per_head_dim = dim_out // gnn_cfg.heads

        kwargs['add_self_loops'] = False
        version = kwargs['version'] = sbm_cfg.version

        # Choose the LeakyReLU-style slope (and for known versions, the bias).
        if version == 'v1':
            slope, bias = 0.2, False
        elif version == 'v2':
            slope, bias = 0.01, True
        else:
            # NOTE(review): unlike v1/v2, ``bias`` keeps the caller's value
            # here -- confirm this asymmetry is intended.
            slope = sbm_cfg.negative_slope
        kwargs['negative_slope'] = slope

        kwargs.update(
            use_ansatz=gnn_cfg.fixed_ansatz,
            use_partial_ansatz=gnn_cfg.partial_ansatz,
            add_attn_info=gnn_cfg.add_attn_info,
            mu_norm=sbm_cfg.mu_norm,
            d=sbm_cfg.d,
            p=sbm_cfg.p,
            q=sbm_cfg.q,
        )

        super().__init__(GATAnsatzLayer, dim_in, per_head_dim, bias, **kwargs)


# Register GATAnsatz in GraphGym's layer registry under the key 'gatansatz'.
register_layer('gatansatz', GATAnsatz)



