# from graphgym.config import cfg

from typing import Union, Tuple

import torch
# from graphgym.config import cfg
from torch import Tensor
from torch_geometric.typing import Adj, PairTensor
from torch_geometric.utils import remove_self_loops
from torch_sparse import SparseTensor

from .gatv1 import GATv1Layer


class GINGATv1Layer(GATv1Layer):
    """GATv1 attention layer with a GIN-style ``(1 + eps)`` self update.

    Neighbour messages are aggregated by the parent ``GATv1Layer``; the
    node's own representation is then added back with weight ``(1 + eps)`` in
    :meth:`update_fn`, following the GIN update rule. Because the self term is
    added explicitly, the parent is built with ``add_self_loops=False`` and
    any self-loops already present in ``edge_index`` are stripped in
    :meth:`forward`. The parent is also built with ``bias=False``.

    Args:
        in_channels: Input feature size, or a ``(source, target)`` pair of
            sizes. Forwarded to ``GATv1Layer``.
        out_channels: Output feature size. Forwarded to ``GATv1Layer``.
        negative_slope: Forwarded to ``GATv1Layer``.
        heads: Number of attention heads. Forwarded to ``GATv1Layer``.
        convolve, gcn_mode, share_weights_score, share_weights_value:
            Forwarded unchanged to ``GATv1Layer``.
        lambda_policy: Forwarded to ``GATv1Layer``; one of
            ``[None, 'learn1', 'learn2', 'learn12', 'gcn_gat', 'individual']``.
        eps: Initial value of the GIN ``eps`` coefficient.
        train_eps: If ``True``, ``eps`` is a learnable ``nn.Parameter``;
            otherwise it is registered as a (non-trainable) buffer.
        **kwargs: Extra keyword arguments forwarded to ``GATv1Layer``. They
            must not contain ``add_self_loops`` or ``bias``, which this class
            passes explicitly (repeating them raises ``TypeError``).
    """

    def __init__(
            self,
            in_channels: Union[int, Tuple[int, int]],
            out_channels: int,
            negative_slope: float = 0.2,
            heads: int = 1,
            convolve: bool = True,
            lambda_policy: Union[str, None] = None,  # [None, 'learn1', 'learn2', 'learn12', 'gcn_gat', 'individual']
            gcn_mode: bool = False,
            share_weights_score: bool = False,
            share_weights_value: bool = False,
            eps: float = 0.,
            train_eps: bool = False,
            **kwargs,
    ):

        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         negative_slope=negative_slope,
                         # The self term is added explicitly in update_fn.
                         add_self_loops=False,
                         heads=heads,
                         bias=False,
                         convolve=convolve,
                         lambda_policy=lambda_policy,  # [None, 'learn1', 'learn2', 'learn12', 'gcn_gat', 'individual']
                         gcn_mode=gcn_mode,
                         share_weights_score=share_weights_score,
                         share_weights_value=share_weights_value,
                         **kwargs)

        # Record the initial value before ``eps`` exists so that
        # ``reset_parameters`` never sees ``eps`` without ``initial_eps``.
        self.initial_eps = eps

        # ``float(...)`` keeps the default floating dtype even if an int is
        # passed (an integer tensor could not be a trainable Parameter).
        eps_init = torch.tensor([float(eps)])
        if train_eps:
            self.eps = torch.nn.Parameter(eps_init)
        else:
            self.register_buffer('eps', eps_init)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parent's parameters and restore ``eps`` to its initial value."""
        super().reset_parameters()
        # This may run from the parent constructor before ``eps`` has been
        # registered, so only touch ``eps`` once it exists.
        if hasattr(self, 'eps'):
            with torch.no_grad():
                self.eps.fill_(self.initial_eps)

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                size_target: Union[int, None] = None,
                return_attention_info: bool = False):
        """Remove self-loops from ``edge_index`` and run the parent forward pass.

        Self-loops are stripped because :meth:`update_fn` already adds each
        node's own features with weight ``(1 + eps)``; otherwise the node
        would also enter the neighbour aggregation through its self-loop.

        Args:
            x: Node features, or a pair of (source, target) features.
            edge_index: Connectivity as a ``torch.Tensor``.
            size_target: Forwarded to ``GATv1Layer.forward``.
            return_attention_info: Forwarded to ``GATv1Layer.forward``.

        Returns:
            Whatever ``GATv1Layer.forward`` returns.

        Raises:
            NotImplementedError: If ``edge_index`` is a ``SparseTensor``.
        """
        if isinstance(edge_index, Tensor):
            edge_index, _ = remove_self_loops(edge_index)
        elif isinstance(edge_index, SparseTensor):
            raise NotImplementedError(
                f"{type(self).__name__} does not support SparseTensor "
                "adjacency; pass edge_index as a torch.Tensor instead.")

        return super().forward(x=x,
                               edge_index=edge_index,
                               size_target=size_target,
                               return_attention_info=return_attention_info)

    def update_fn(self, x_agg, x_i):
        """Combine aggregated messages with the node's own features, GIN-style.

        Returns ``merge_heads(x_agg) + (1 + eps) * merge_heads(x_i)``.
        ``merge_heads`` is inherited from ``GATv1Layer`` (presumably it
        combines the per-head outputs -- TODO confirm), and ``x_i`` is
        presumably the target nodes' own features; verify against the caller.
        """
        return self.merge_heads(x_agg) + (self.eps + 1.0) * self.merge_heads(x_i)
