import torch
import torch.nn as nn
import torch_geometric as pyg
from torch_scatter import scatter

import torch_geometric.graphgym.register as register
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_head


@register_head('san_graph')
class SANGraphHead(nn.Module):
    """
    SAN prediction head for graph prediction tasks.

    Node features are pooled into one vector per graph with the pooling
    function selected by ``cfg.model.graph_pooling``. The result goes through
    an MLP whose hidden width halves at every layer. The final output is
    reshaped to ``(num_rows, dim_spec_out, dim_bins_out)``.

    Args:
        dim_in (int): Input feature dimension (width of ``batch.x``).
        dim_spec_out (int): Size of the second axis of the output tensor.
        dim_bins_out (int): Size of the third axis of the output tensor. The
            final linear layer produces ``dim_spec_out * dim_bins_out``
            features.
        L (int): Number of hidden layers. Hidden layer ``l`` maps
            ``dim_in // 2**l`` to ``dim_in // 2**(l + 1)`` features and is
            followed by SiLU.
        dual_head (bool): If True, build a second, parallel MLP of the same
            shape over pooled ``batch.x_features``. That branch is
            layer-normalised before its final projection. Both branch outputs
            are passed through SiLU, concatenated, and fused back to
            ``dim_spec_out * dim_bins_out`` features by a two-layer MLP.
    """

    def __init__(self, dim_in, dim_spec_out, dim_bins_out, L=2, dual_head=False):
        super().__init__()
        dim_out = dim_spec_out * dim_bins_out
        self.n_spec = dim_spec_out
        self.dual_head = dual_head
        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]

        self.FC_layers = self._build_fc_layers(dim_in, dim_out, L)
        self.L = L
        # NOTE: the activation is hard-coded to SiLU rather than taken from
        # cfg.gnn.act (i.e. register.act_dict[cfg.gnn.act]()).
        self.activation = nn.SiLU()

        if self.dual_head:
            # Built after the main stack; keep this order so seeded parameter
            # initialisation and state_dict keys stay stable.
            self.FC_layers_dual = self._build_fc_layers(dim_in, dim_out, L)
            self.layer_norm_dual = nn.LayerNorm(dim_in // 2 ** L)
            # Fuses [main_out, dual_out] (2 * dim_out) back to dim_out.
            self.ff_out = nn.Sequential(
                nn.Linear(dim_out * 2, dim_out * 4),
                nn.SiLU(),
                nn.Linear(dim_out * 4, dim_out),
            )

    @staticmethod
    def _build_fc_layers(dim_in, dim_out, L):
        """Return L width-halving hidden Linear layers plus a final
        projection from ``dim_in // 2**L`` to ``dim_out``, as a ModuleList."""
        layers = [
            nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True)
            for l in range(L)
        ]
        layers.append(nn.Linear(dim_in // 2 ** L, dim_out, bias=True))
        return nn.ModuleList(layers)

    def _apply_index(self, batch):
        """Return ``(batch.graph_feature, batch.y)`` as (prediction, label)."""
        return batch.graph_feature, batch.y

    def forward(self, batch, return_embedding=False):
        """
        Pool node features and produce graph-level predictions.

        Args:
            batch: Mini-batch exposing ``x``, ``batch`` (passed to the pooling
                function as the node-to-graph assignment) and ``y``. It must
                also expose ``x_features`` when ``dual_head`` is enabled.
            return_embedding (bool): If True, also return a copy of the main
                branch's hidden representation after the last hidden layer
                and activation, before the final projection.

        Returns:
            ``(pred, label)``, or ``(pred, label, emb)`` when
            ``return_embedding`` is True. ``pred`` has one row per pooled
            graph and is shaped ``(num_rows, dim_spec_out, dim_bins_out)``.
            It is also stored on ``batch.graph_feature``. ``label`` is
            ``batch.y``.
        """
        # Pool node features into one vector per graph.
        graph_emb = self.pooling_fun(batch.x, batch.batch)
        if self.dual_head:
            graph_emb_dual = self.pooling_fun(batch.x_features, batch.batch)

        # Hidden layers (Linear + SiLU), applied to both branches in lockstep.
        for l in range(self.L):
            graph_emb = self.FC_layers[l](graph_emb)
            graph_emb = self.activation(graph_emb)
            if self.dual_head:
                graph_emb_dual = self.FC_layers_dual[l](graph_emb_dual)
                graph_emb_dual = self.activation(graph_emb_dual)
        if return_embedding:
            # Clone so the returned embedding is a separate tensor.
            emb = graph_emb.clone()

        # Final projection; in dual mode, fuse both branches.
        graph_emb = self.FC_layers[self.L](graph_emb)
        if self.dual_head:
            graph_emb = self.activation(graph_emb)
            graph_emb_dual = self.layer_norm_dual(graph_emb_dual)
            graph_emb_dual = self.FC_layers_dual[self.L](graph_emb_dual)
            graph_emb_dual = self.activation(graph_emb_dual)
            graph_emb = self.ff_out(torch.cat([graph_emb, graph_emb_dual], dim=-1))

        # (num_rows, dim_spec_out * dim_bins_out) -> (num_rows, dim_spec_out, dim_bins_out)
        batch.graph_feature = graph_emb.view(len(graph_emb), self.n_spec, -1)
        pred, label = self._apply_index(batch)
        if return_embedding:
            return pred, label, emb
        return pred, label

