import torch
import torch.nn as nn

from magnetic_edge_gnn.models.gnn.gnn_layers import (
    FusionLayer,
)


class DualBlockFlat(nn.Module):
    """Dual-architecture layer that updates both feature streams in parallel.

    The block keeps two edge-feature streams, an "equivariant" one and an
    "invariant" one. Each stream is updated by its own convolution and,
    optionally, by a cross convolution fed from the other stream, a fusion of
    both convolution outputs, and a linear skip connection. All available
    contributions of a stream are summed.

    Args:
        equi_in_dim: Input width of the equivariant stream (0 disables it).
        equi_out_dim: Requested output width of the equivariant stream.
        inv_in_dim: Input width of the invariant stream (0 disables it).
        inv_out_dim: Requested output width of the invariant stream.
        use_fusion: Whether to add fused terms. Only honoured when both
            input widths are positive.
        inv_to_equi: Whether the invariant stream feeds the equivariant one.
        equi_to_inv: Whether the equivariant stream feeds the invariant one.
        init_equi_conv_fn: Factory for the equivariant -> equivariant
            convolution; called with ``in_channels``, ``out_channels`` and
            ``**kwargs`` (with ``skip_connection=False``).
        init_inv_conv_fn: Factory for the invariant -> invariant convolution.
        init_equi_inv_conv_fn: Factory for the equivariant -> invariant
            convolution (needed only if ``equi_to_inv``).
        init_inv_equi_conv_fn: Factory for the invariant -> equivariant
            convolution (needed only if ``inv_to_equi``).
        skip_connection: Whether to add bias-free linear skip connections.
        bias: Whether to add a learnable bias to the invariant output.
        **kwargs: Forwarded to ``FusionLayer`` and to every conv factory.

    Raises:
        ValueError: If a convolution factory that the configuration requires
            was not provided.
    """

    def __init__(
        self,
        equi_in_dim: int,
        equi_out_dim: int,
        inv_in_dim: int,
        inv_out_dim: int,
        use_fusion: bool,
        inv_to_equi: bool,
        equi_to_inv: bool,
        init_equi_conv_fn=None,
        init_inv_conv_fn=None,
        init_equi_inv_conv_fn=None,
        init_inv_equi_conv_fn=None,
        skip_connection: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__()
        self._equi_in_dim = equi_in_dim
        self._equi_out_dim = equi_out_dim
        self._inv_in_dim = inv_in_dim
        self._inv_out_dim = inv_out_dim
        self.inv_to_equi = inv_to_equi
        self.equi_to_inv = equi_to_inv
        self.skip_connection = skip_connection

        # Fusion can only be used when both equi and invariant inputs are available
        self.use_fusion = use_fusion and (equi_in_dim > 0) and (inv_in_dim > 0)

        # The inner convolutions never apply their own skip connection; this
        # block adds the residual terms itself.
        conv_kwargs = kwargs | dict(skip_connection=False)

        # NOTE: the construction order below is kept stable on purpose so that
        # seeded parameter initialisation and state_dict ordering do not change.
        if self.use_fusion:
            self.equi_fusion_layer = FusionLayer(
                in1_channels=equi_out_dim,
                in2_channels=inv_out_dim,
                out_channels=equi_out_dim,
                bias=False,
                **kwargs,
            )
            self.inv_fusion_layer = FusionLayer(
                in1_channels=equi_out_dim,
                in2_channels=inv_out_dim,
                out_channels=inv_out_dim,
                **kwargs,
            )

        if inv_in_dim > 0:
            self.inv_conv = self._require(init_inv_conv_fn, "init_inv_conv_fn")(
                in_channels=inv_in_dim,
                out_channels=inv_out_dim,
                **conv_kwargs,
            )
            if self.inv_to_equi:
                self.inv_equi_conv = self._require(
                    init_inv_equi_conv_fn, "init_inv_equi_conv_fn"
                )(
                    in_channels=inv_in_dim,
                    out_channels=equi_out_dim,
                    **conv_kwargs,
                )
        if equi_in_dim > 0:
            self.equi_conv = self._require(init_equi_conv_fn, "init_equi_conv_fn")(
                in_channels=equi_in_dim,
                out_channels=equi_out_dim,
                **conv_kwargs,
            )
            if self.equi_to_inv:
                self.equi_inv_conv = self._require(
                    init_equi_inv_conv_fn, "init_equi_inv_conv_fn"
                )(
                    in_channels=equi_in_dim,
                    out_channels=inv_out_dim,
                    **conv_kwargs,
                )
        if self.skip_connection:
            # Created even for zero-width inputs so that existing checkpoints
            # (which contain these, possibly empty, weights) keep loading.
            self.skip_connection_equi = nn.Linear(equi_in_dim, equi_out_dim, bias=False)
            self.skip_connection_inv = nn.Linear(inv_in_dim, inv_out_dim, bias=False)

        if bias:
            self.bias = nn.Parameter(torch.zeros(inv_out_dim))
        else:
            self.register_parameter("bias", None)

    @staticmethod
    def _require(factory, name: str):
        """Return ``factory``, raising a descriptive error if it is missing."""
        if factory is None:
            raise ValueError(
                f"`{name}` must be provided for the requested block configuration."
            )
        return factory

    @staticmethod
    def _aggregate(terms, reference: torch.Tensor) -> torch.Tensor:
        """Sum the given terms, or return an empty tensor if there are none.

        The empty result reuses the leading dimensions, dtype and device of
        ``reference`` and has a last dimension of 0, so the output is always a
        tensor that a subsequent block can inspect with ``.size(-1)``.
        """
        if not terms:
            return reference.new_zeros((*reference.shape[:-1], 0))
        return sum(terms)

    @property
    def equi_out_dim(self) -> int:
        """The effective output dimension of the block for equivariant features."""
        if self._equi_in_dim > 0 or (self.inv_to_equi and self._inv_in_dim > 0):
            return self._equi_out_dim
        else:
            return 0

    @property
    def inv_out_dim(self) -> int:
        """The effective output dimension of the block for invariant features."""
        if self._inv_in_dim > 0 or (self.equi_to_inv and self._equi_in_dim > 0):
            return self._inv_out_dim
        else:
            return 0

    def fusion(self, equi_edge_attr, inv_edge_attr):
        """Fuse both streams and return ``(fused_equi, fused_inv)``."""
        new_equi_edge_attr = self.equi_fusion_layer(
            equi_edge_attr,
            inv_edge_attr,
        )
        # Absolute value to keep representations orientation-invariant.
        new_inv_edge_attr = self.inv_fusion_layer(
            torch.abs(equi_edge_attr),
            inv_edge_attr,
        )
        return new_equi_edge_attr, new_inv_edge_attr

    def forward(self, edge_index, equi_edge_attr, inv_edge_attr, undirected_mask):
        """Update both feature streams.

        Args:
            edge_index: Passed through to every convolution.
            equi_edge_attr: Equivariant features; a last dimension of size 0
                means the stream is absent.
            inv_edge_attr: Invariant features; a last dimension of size 0
                means the stream is absent.
            undirected_mask: Passed through to every convolution.

        Returns:
            Tuple ``(h_equi_new, h_inv_new)``. A stream that receives no
            contribution is returned as an empty tensor whose last dimension
            is 0, matching ``equi_out_dim`` / ``inv_out_dim``.
        """
        (
            h_equi_equi,
            h_inv_equi,
            h_equi_inv,
            h_inv_inv,
            h_equi_fused,
            h_inv_fused,
            h_equi_residual,
            h_inv_residual,
        ) = (None,) * 8

        has_equi = equi_edge_attr.size(-1) > 0
        has_inv = inv_edge_attr.size(-1) > 0

        if has_equi:
            h_equi_equi = self.equi_conv(
                edge_index=edge_index,
                edge_attr=equi_edge_attr,
                undirected_mask=undirected_mask,
            )
            if self.skip_connection:
                h_equi_residual = self.skip_connection_equi(equi_edge_attr)
            if self.equi_to_inv:
                h_equi_inv = self.equi_inv_conv(
                    edge_index=edge_index,
                    edge_attr=equi_edge_attr,
                    undirected_mask=undirected_mask,
                )

        if has_inv:
            h_inv_inv = self.inv_conv(
                edge_index=edge_index,
                edge_attr=inv_edge_attr,
                undirected_mask=undirected_mask,
            )
            if self.skip_connection:
                h_inv_residual = self.skip_connection_inv(inv_edge_attr)
            if self.inv_to_equi:
                h_inv_equi = self.inv_equi_conv(
                    edge_index=edge_index,
                    edge_attr=inv_edge_attr,
                    undirected_mask=undirected_mask,
                )

        if self.use_fusion and has_inv and has_equi:
            h_equi_fused, h_inv_fused = self.fusion(h_equi_equi, h_inv_inv)

        # Aggregate
        equi_terms = [
            h
            for h in (h_equi_equi, h_inv_equi, h_equi_fused, h_equi_residual)
            if h is not None
        ]
        inv_terms = [
            h
            for h in (h_inv_inv, h_equi_inv, h_inv_fused, h_inv_residual)
            if h is not None
        ]
        h_equi_new = self._aggregate(equi_terms, equi_edge_attr)
        h_inv_new = self._aggregate(inv_terms, inv_edge_attr)

        # Only bias a stream that actually exists; otherwise the empty output
        # would turn into a bare bias vector of the wrong shape.
        if self.bias is not None and inv_terms:
            # `sum` always yields a fresh tensor, so the in-place add is safe.
            h_inv_new += self.bias
        return h_equi_new, h_inv_new
