import torch.nn as nn
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_head
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from graphgps.layer.fsw_conv import FSW_readout
import torch


@register_head('mlp_fsw')
class MLPSWHead(nn.Module):
    """
    Graph-level prediction head: an FSW readout followed by an MLP.

    Node embeddings in ``batch.x`` are pooled per graph by ``FSW_readout``
    (configured from ``cfg.gnn``, with ``out_channels=cfg.gnn.dim_inner``).
    The pooled vector is then passed through an MLP of
    ``max(cfg.gnn.layers_post_mp, 1)`` linear layers. Every linear layer is
    preceded by ``nn.Dropout(cfg.gnn.dropout)``, and every layer except the
    last is followed by the ``cfg.gnn.act`` activation.

    Args:
        dim_in (int): Width of the node embeddings fed to the readout.
        dim_out (int): Output dimension. For binary prediction, dim_out=1.

    Raises:
        ValueError: If ``cfg.model.graph_pooling`` is not ``'fsw_readout'``.
        RuntimeError: If ``cfg.gnn.mlp_init`` names an unsupported initializer.
    """

    # Weight initializers selectable via cfg.gnn.mlp_init; None keeps the
    # default initialization performed by nn.Linear's constructor.
    _WEIGHT_INITS = {
        'xavier_uniform': nn.init.xavier_uniform_,
        'xavier_normal': nn.init.xavier_normal_,
        'kaiming_uniform': nn.init.kaiming_uniform_,
        'kaiming_normal': nn.init.kaiming_normal_,
    }

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if cfg.model.graph_pooling != 'fsw_readout':
            raise ValueError(
                "MLPSWHead ('mlp_fsw') requires cfg.model.graph_pooling == "
                f"'fsw_readout', got {cfg.model.graph_pooling!r}")

        gnn_cfg = cfg.gnn
        dim_inner = gnn_cfg.dim_inner
        mlp_init = gnn_cfg.mlp_init
        # Stored for external consumers; forward() does not use it.
        self.skip_connections = gnn_cfg.residual

        # Readout: pools dim_in-wide node embeddings into a per-graph vector,
        # constructed with out_channels=dim_inner. The readout is built before
        # the MLP so parameter initialization order (and RNG usage) is fixed.
        self.global_sw = FSW_readout(
            in_channels=dim_in,
            out_channels=dim_inner,
            edgefeat_dim=0,
            embed_dim=gnn_cfg.fsw_embed_dim,
            learnable_embedding=gnn_cfg.learnable_embedding,
            encode_vertex_degrees=gnn_cfg.encode_vertex_degrees,
            vertex_degree_encoding_function=gnn_cfg.vertex_degree_encoding_function,
            homog_degree_encoding=gnn_cfg.homog_degree_encoding,
            concat_self=False,
            bias=gnn_cfg.bias,
            mlp_layers=0,
            mlp_hidden_dim=gnn_cfg.conv_mlp_hidden_dim,
            mlp_activation_final=register.act_dict[gnn_cfg.conv_mlp_activation_final](),
            mlp_activation_hidden=register.act_dict[gnn_cfg.conv_mlp_activation_hidden](),
            mlp_init=mlp_init,
            batchNorm_final=gnn_cfg.batchnorm_final,
            batchNorm_hidden=gnn_cfg.batchnorm_hidden,
            dropout_final=gnn_cfg.dropout_final,
            dropout_hidden=gnn_cfg.dropout_hidden,
            self_loop_weight=0.0,
            edge_weighting='unit')

        # The MLP consumes the readout's output, so its width must match the
        # readout's out_channels (dim_inner), not dim_in. Previously it used
        # dim_in, which only worked when dim_in == dim_inner.
        dropout = gnn_cfg.dropout
        num_layers = gnn_cfg.layers_post_mp

        layers = []
        for _ in range(num_layers - 1):
            layers.append(nn.Dropout(dropout))
            layers.append(self._make_linear(dim_inner, dim_inner, mlp_init))
            layers.append(register.act_dict[gnn_cfg.act]())
        layers.append(nn.Dropout(dropout))
        layers.append(self._make_linear(dim_inner, dim_out, mlp_init))
        self.mlp = nn.Sequential(*layers)

    @classmethod
    def _make_linear(cls, in_features, out_features, mlp_init):
        """Create a biased ``nn.Linear`` and apply the weight init named by ``mlp_init``.

        Raises:
            RuntimeError: If ``mlp_init`` is neither None nor a key of ``_WEIGHT_INITS``.
        """
        lin = nn.Linear(in_features, out_features, bias=True)
        if mlp_init is not None:
            try:
                init_fn = cls._WEIGHT_INITS[mlp_init]
            except KeyError:
                raise RuntimeError('Invalid value passed at argument mlp_init') from None
            init_fn(lin.weight)
        return lin

    def _scale_and_shift(self, x):
        """Return ``x`` unchanged (identity post-processing hook)."""
        return x

    def _apply_index(self, batch):
        """Return ``(batch.graph_feature, batch.y)``; not called by ``forward``."""
        return batch.graph_feature, batch.y

    def forward(self, batch):
        """Pool node features per graph and run the MLP.

        Args:
            batch: Object exposing ``x`` (node features), ``batch`` (passed
                straight to ``FSW_readout``; presumably the node-to-graph
                index vector) and ``y`` (targets).

        Returns:
            tuple: ``(pred, batch.y)``.
        """
        graph_emb = self.global_sw(batch.x, batch.batch)
        pred = self._scale_and_shift(self.mlp(graph_emb))
        return pred, batch.y

