import copy
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg


@dataclass
class CustomLayerConfig:
    """Settings consumed by the custom GraphGym layer wrappers in this module.

    Every field is populated by ``Customnew_layer_config`` from its
    arguments and the GraphGym config; the defaults below only apply when
    the dataclass is constructed directly.
    """

    # batchnorm parameters.
    has_batchnorm: bool = False
    bn_eps: float = 1e-5
    bn_mom: float = 0.1

    # mem parameters.
    # Passed as ``inplace`` to nn.Dropout by CustomGeneralLayer.
    mem_inplace: bool = False

    # gnn parameters.
    dim_in: int = -1
    dim_out: int = -1
    edge_dim: int = -1
    # Width of the intermediate layers of a multi-layer stack; ``None`` means
    # "use dim_out" (see CustomGeneralMultiLayer).
    dim_inner: Optional[int] = None
    num_layers: int = 2
    has_bias: bool = True
    # regularizer parameters.
    has_l2norm: bool = True
    dropout: float = 0.0
    # activation parameters.
    has_act: bool = True
    final_act: bool = True
    act: str = "relu"

    # other parameters.
    # NOTE(review): not read anywhere in this file; presumably consumed by a
    # registered layer — confirm.
    keep_edge: float = 0.5


def Customnew_layer_config(
    dim_in, dim_out, num_layers, has_act, has_bias, cfg
):
    """Assemble a :class:`CustomLayerConfig` for one layer stack.

    Shape and stack arguments come from the caller. Batchnorm, memory,
    regularization, activation and miscellaneous settings are read from
    ``cfg``.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension.
        num_layers (int): Number of layers in the stack.
        has_act (bool): Whether the layer applies an activation.
        has_bias (bool): Whether the layer uses a bias term.
        cfg: Config object exposing the ``gnn``, ``bn``, ``mem`` and
            ``dataset`` sections read below.

    Returns:
        CustomLayerConfig: The populated config. ``final_act`` is always
        ``True``.
    """
    gnn_cfg = cfg.gnn
    bn_cfg = cfg.bn
    settings = dict(
        # Caller-provided arguments.
        dim_in=dim_in,
        dim_out=dim_out,
        num_layers=num_layers,
        has_act=has_act,
        has_bias=has_bias,
        final_act=True,
        # Values taken from the config.
        has_batchnorm=gnn_cfg.batchnorm,
        bn_eps=bn_cfg.eps,
        bn_mom=bn_cfg.mom,
        mem_inplace=cfg.mem.inplace,
        edge_dim=cfg.dataset.edge_dim,
        has_l2norm=gnn_cfg.l2norm,
        dropout=gnn_cfg.dropout,
        act=gnn_cfg.act,
        keep_edge=gnn_cfg.keep_edge,
        dim_inner=gnn_cfg.dim_inner,
    )
    return CustomLayerConfig(**settings)


# General classes
class CustomGeneralLayer(nn.Module):
    """
    General wrapper for layers.

    Builds the registered layer ``name`` and follows it with optional
    BatchNorm1d, Dropout and activation (in that order), then optional
    L2 normalization in ``forward``.

    Args:
        name (string): Name of the layer in registered :obj:`layer_dict`
        layer_config (CustomLayerConfig): Layer settings. This wrapper reads
            ``has_l2norm``, ``has_batchnorm``, ``dim_out``, ``bn_eps``,
            ``bn_mom``, ``dropout``, ``mem_inplace``, ``has_act`` and
            ``act``. The registered layer receives a copy of the config with
            ``has_bias`` forced to ``not has_batchnorm``. The caller's object
            is left unmodified.
        **kwargs (optional): Additional args forwarded to the registered
            layer's constructor.
    """

    def __init__(self, name, layer_config: CustomLayerConfig, **kwargs):
        super().__init__()
        # Work on a private copy so that overriding ``has_bias`` below does
        # not leak back into the caller's config object.
        layer_config = copy.copy(layer_config)
        self.has_l2norm = layer_config.has_l2norm
        has_bn = layer_config.has_batchnorm
        # Bias is disabled whenever BatchNorm follows (presumably because
        # BatchNorm's learned shift makes it redundant).
        layer_config.has_bias = not has_bn
        self.layer = register.layer_dict[name](layer_config, **kwargs)
        layer_wrapper = []
        if has_bn:
            layer_wrapper.append(
                nn.BatchNorm1d(
                    layer_config.dim_out,
                    eps=layer_config.bn_eps,
                    momentum=layer_config.bn_mom,
                )
            )
        if layer_config.dropout > 0:
            layer_wrapper.append(
                nn.Dropout(
                    p=layer_config.dropout, inplace=layer_config.mem_inplace
                )
            )
        if layer_config.has_act:
            layer_wrapper.append(register.act_dict[layer_config.act]())
        self.post_layer = nn.Sequential(*layer_wrapper)

    def forward(
        self,
        batch,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Apply the layer, then the post-processing stack.

        ``batch`` is either a tensor or an object with an ``x`` attribute.
        In the latter case the post-processing and normalization are applied
        to ``batch.x`` in place. The ``*_th`` arguments are accepted for
        signature compatibility and are not used.

        Returns:
            The processed tensor, or the same batch object with ``x`` updated.
        """
        batch = self.layer(batch)
        if isinstance(batch, torch.Tensor):
            batch = self.post_layer(batch)
            if self.has_l2norm:
                batch = F.normalize(batch, p=2, dim=1)
        else:
            batch.x = self.post_layer(batch.x)
            if self.has_l2norm:
                batch.x = F.normalize(batch.x, p=2, dim=1)
        return batch


def CustomGNNPreMP(dim_in, dim_out, num_layers):
    """Build the layer stack applied before GNN message passing.

    Returns a :class:`CustomGeneralMultiLayer` of registered ``"linear"``
    layers. Its remaining settings come from the module-level ``cfg``.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        num_layers (int): Number of layers

    NOTE(review): ``has_act=False`` and ``has_bias=False`` are passed here,
    but they do not take effect. CustomGeneralMultiLayer derives each
    layer's ``has_act`` from ``final_act``, which Customnew_layer_config
    always sets to ``True``. CustomGeneralLayer resets ``has_bias`` to
    ``not has_batchnorm``. Confirm this is intended.
    """
    pre_mp_config = Customnew_layer_config(
        dim_in,
        dim_out,
        num_layers,
        has_act=False,
        has_bias=False,
        cfg=cfg,
    )
    return CustomGeneralMultiLayer("linear", layer_config=pre_mp_config)


class CustomGeneralMultiLayer(nn.Module):
    """
    Stack of ``layer_config.num_layers`` :class:`CustomGeneralLayer` modules
    applied in sequence.

    - The first layer maps ``dim_in`` to the inner width.
    - Middle layers keep the inner width.
    - The last layer maps to ``dim_out``.
    - The inner width is ``layer_config.dim_inner``, falling back to
      ``dim_out`` when it is ``None``.
    - Every layer except the last gets an activation; the last one follows
      ``layer_config.final_act``.

    Args:
        name (string): Name of the layer in registered :obj:`layer_dict`
        layer_config (CustomLayerConfig): Settings for the whole stack. Each
            layer gets its own deep copy with ``dim_in``, ``dim_out`` and
            ``has_act`` overridden.
        **kwargs (optional): Additional args forwarded to every
            :class:`CustomGeneralLayer`.
    """

    def __init__(self, name, layer_config: CustomLayerConfig, **kwargs):
        super().__init__()
        width = layer_config.dim_inner
        if width is None:
            width = layer_config.dim_out
        last_idx = layer_config.num_layers - 1
        for idx in range(layer_config.num_layers):
            is_last = idx == last_idx
            # Fresh deep copy per layer so per-layer overrides never interfere.
            sub_config = copy.deepcopy(layer_config)
            sub_config.dim_in = width if idx else layer_config.dim_in
            sub_config.dim_out = layer_config.dim_out if is_last else width
            sub_config.has_act = layer_config.final_act if is_last else True
            # Submodule names determine state_dict keys ("Layer_<i>").
            self.add_module(
                f"Layer_{idx}", CustomGeneralLayer(name, sub_config, **kwargs)
            )

    def forward(
        self,
        batch,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Run ``batch`` through every layer in registration order.

        The ``*_th`` arguments are accepted for signature compatibility. They
        are neither used nor forwarded to the layers.
        """
        out = batch
        for stage in self.children():
            out = stage(out)
        return out
