from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from LieCG import so13

from LorentzMACE.modules.irreps_tools import (
    linear_out_irreps,
    reshape_irreps,
    tp_out_irreps_with_instructions,
)
from LorentzMACE.modules.symmetric_contraction import SymmetricContraction
from LorentzMACE.tools.torch_geometric.norms import SwitchNorm1d
from LorentzMACE.modules.utils import metric, minkowski_norm
from LorentzMACE.tools.scatter import scatter_sum, scatter_mean

from .radial import BesselBasis, LorentzianBasis, PolynomialCutoff, SimpleEncodeBasis


class AttributesBlock(torch.nn.Module):
    """Embed each node as sqrt(|minkowski_norm(position)|), returned as a complex tensor.

    The imaginary part of the result is always zero. ``scale`` is registered as
    a buffer but is not read by ``forward``.
    """

    def __init__(self, scale: float) -> None:
        super().__init__()
        scale_tensor = torch.tensor(scale, dtype=torch.get_default_dtype())
        self.register_buffer("scale", scale_tensor)

    def forward(
        self, positions: torch.Tensor,  # [n_nodes, 4]
    ):
        # One trailing feature axis holding the square root of |norm|.
        magnitude = torch.sqrt(torch.abs(minkowski_norm(positions))).unsqueeze(-1)
        # view_as_complex needs a trailing (real, imag) axis of size 2.
        real_imag = torch.stack((magnitude, torch.zeros_like(magnitude)), dim=-1)
        return torch.view_as_complex(real_imag)

    def __repr__(self):
        return f"{type(self).__name__}"


class LinearNodeEmbeddingBlock(torch.nn.Module):
    """Linearly embed node attributes, then mix each feature with a log envelope.

    For each so13-linear output feature ``e`` the pair ``(e, log(|e + 1| + 1e-5))``
    is combined by a shared ``torch.nn.Linear(2, 1)``. The result is returned as
    a complex tensor with zero imaginary part.
    """

    def __init__(self, irreps_in: so13.Lorentz_Irreps, irreps_out: so13.Lorentz_Irreps):
        super().__init__()
        # Construction order (linear, then fc) fixes parameter registration order.
        self.linear = so13.Linear(irreps_in=irreps_in, irreps_out=irreps_out)
        self.node_envelope = lambda x: torch.log((x + 1).abs() + 1e-5)

        self.fc = torch.nn.Linear(2, 1)

    def forward(
        self, node_attrs: torch.Tensor,  # [n_nodes, irreps]
    ):
        linear_out = self.linear(node_attrs).unsqueeze(2)
        # Pair every feature with its log envelope along a new axis of size 2.
        paired = torch.cat((linear_out, self.node_envelope(linear_out)), dim=2)
        mixed = self.fc(paired).squeeze(2)
        real_imag = torch.stack((mixed, torch.zeros_like(mixed)), dim=-1)
        return torch.view_as_complex(real_imag)


class NonLinearBlock(torch.nn.Module):
    """Apply a wrapped activation to its input.

    Args:
        gate: the activation to apply. It is stored as ``self.non_linearity`` and
            registered as a submodule when it is an ``nn.Module``.
    """

    def __init__(self, gate: torch.nn.Module):
        super().__init__()
        self.non_linearity = gate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The original called ``self.gate``, an attribute that is never set, so
        # every call raised AttributeError. The activation lives in non_linearity.
        return self.non_linearity(x)


class LinearReadoutBlock(torch.nn.Module):
    """Biased so13 linear readout applied to the real part of the features."""

    def __init__(
        self, irreps_in: so13.Lorentz_Irreps, readout_irreps: so13.Lorentz_Irreps
    ):
        super().__init__()
        self.irreps_out = readout_irreps
        self.linear = so13.Linear(
            irreps_in=irreps_in, irreps_out=readout_irreps, biases=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Any imaginary component of x is discarded before the readout.
        real_part = x.real
        return self.linear(real_part)


class NonLinearReadoutBlock(torch.nn.Module):
    """Two-layer so13 MLP readout on the real part of the features.

    The pipeline is ``linear_1`` → ``so13.Activation(gate)`` → ``linear_2``. The
    output irreps are fixed to ``2x(0,0)``.
    """

    def __init__(
        self,
        irreps_in: so13.Lorentz_Irreps,
        MLP_irreps: so13.Lorentz_Irreps,
        gate: Callable,
    ):
        super().__init__()
        hidden = MLP_irreps
        self.hidden_irreps = hidden
        # Submodules are built in the order linear_1, non_linearity, linear_2.
        self.linear_1 = so13.Linear(irreps_in=irreps_in, irreps_out=hidden)
        self.non_linearity = so13.Activation(irreps_in=hidden, acts=[gate])
        self.irreps_out = so13.Lorentz_Irreps("2x(0,0)")
        self.linear_2 = so13.Linear(irreps_in=hidden, irreps_out=self.irreps_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_feats = self.linear_1(x.real)
        activated = self.non_linearity(hidden_feats)
        return self.linear_2(activated)


class RadialBesselEmbeddingBlock(torch.nn.Module):
    """Bessel radial basis of edge lengths, optionally damped by a polynomial cutoff.

    ``out_dim`` equals ``num_bessel``. ``cutoff_fn`` exists only when
    ``use_cutoff`` is True.
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        use_cutoff: bool,
    ):
        super().__init__()
        self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel)
        self.use_cutoff = use_cutoff
        if use_cutoff:
            self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff)
        self.out_dim = num_bessel

    def forward(
        self, edge_lengths: torch.Tensor,  # [n_edges, 1]
    ):
        basis = self.bessel_fn(edge_lengths)
        if not self.use_cutoff:
            return basis
        return basis * self.cutoff_fn(edge_lengths)


class RadialLorentzianEmbeddingBlock(torch.nn.Module):
    """Lorentzian radial basis of edge lengths, optionally damped by a polynomial cutoff.

    The basis module is stored under the historical name ``bessel_fn`` (the
    attribute name is part of the state_dict). ``cutoff_fn`` exists only when
    ``use_cutoff`` is True.
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        use_cutoff: bool,
    ):
        super().__init__()
        self.bessel_fn = LorentzianBasis(r_max=r_max, num_basis=num_bessel)
        self.use_cutoff = use_cutoff
        if use_cutoff:
            self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff)
        self.out_dim = num_bessel

    def forward(
        self, edge_lengths: torch.Tensor,  # [n_edges, 1]
    ):
        basis = self.bessel_fn(edge_lengths)
        if not self.use_cutoff:
            return basis
        return basis * self.cutoff_fn(edge_lengths)


class SimpleRadialEmbeddingBlock(torch.nn.Module):
    def __init__(self, num_bessel: int, **kwargs):
        """Encode edge lengths with ``SimpleEncodeBasis(num_bessel)``.

        Adapted from
        https://github.com/abogatskiy/PELICAN/blob/main/src/layers/generic_layers.py
        Extra keyword arguments are accepted and ignored (presumably so this
        block can be constructed like the other radial blocks).
        """
        super().__init__()
        self.out_dim = num_bessel
        self.emb_f = SimpleEncodeBasis(num_bessel)

    def forward(
        self, edge_lengths: torch.Tensor,  # [n_edges, 1]
    ):
        return self.emb_f(edge_lengths)


class DotProductBlock(torch.nn.Module):
    """Contract two feature tensors through the metric of ``irreps_in``.

    ``forward`` computes ``sum_ab x[..., a] * g[a, b] * y[..., b]``, which
    collapses the last axis entirely.

    NOTE(review): ``irreps_out`` declares ``irreps_in.count()`` scalars, but the
    einsum yields a single value per leading index. Confirm which one is intended.
    NOTE(review): ``metric_tensor`` is a plain attribute, not a registered
    buffer, so ``.to(device)`` will not move it.
    """

    def __init__(self, irreps_in: so13.Lorentz_Irreps) -> None:
        super().__init__()
        self.irreps_in = irreps_in
        count = irreps_in.count()
        self.irreps_out = so13.Lorentz_Irreps(f"{count}x(0,0)")
        self.metric_tensor = metric(irreps_in)

    def forward(self, x, y):
        g = self.metric_tensor
        return torch.einsum("...a,ab,...b->...", x, g, y)

    def __repr__(self):
        name = type(self).__name__
        return f"{name}({self.irreps_in} x {self.irreps_in}-> {self.irreps_out})"


class EquivariantProductBasisBlock(torch.nn.Module):
    """Symmetric contraction followed by a shared so13 linear layer.

    The linear layer is applied separately to the real and imaginary parts of
    the contracted features. When ``use_sc`` is True, the matching part of
    ``sc`` is added to each. ``sc`` is not read when ``use_sc`` is False.
    """

    def __init__(
        self,
        node_feats_irreps: so13.Lorentz_Irreps,
        target_irreps: so13.Lorentz_Irreps,
        correlation: Union[int, Dict[str, int]],
        element_dependent: bool = True,
        use_sc: bool = True,
        num_elements: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.use_sc = use_sc
        # Keep this construction order: contraction first, then the linear.
        self.symmetric_contractions = SymmetricContraction(
            irreps_in=node_feats_irreps,
            irreps_out=target_irreps,
            correlation=correlation,
            element_dependent=element_dependent,
            num_elements=num_elements,
        )
        self.linear = so13.Linear(
            target_irreps, target_irreps, internal_weights=True, shared_weights=True,
        )

    def forward(
        self, node_feats: torch.Tensor, sc: torch.Tensor, node_attrs: torch.Tensor
    ) -> torch.Tensor:
        contracted = self.symmetric_contractions(node_feats, node_attrs)
        real_part = self.linear(contracted.real)
        imag_part = self.linear(contracted.imag)
        if self.use_sc:
            real_part = real_part + sc.real
            imag_part = imag_part + sc.imag
        return real_part + 1j * imag_part


class InteractionBlock(ABC, torch.nn.Module):
    """Abstract base class for message-passing interaction blocks.

    The constructor stores every irreps description and the neighbour-count
    normaliser on ``self``, then calls ``_setup()``. Subclasses build their
    submodules there and implement ``forward``.
    """

    def __init__(
        self,
        node_attrs_irreps: so13.Lorentz_Irreps,
        node_feats_irreps: so13.Lorentz_Irreps,
        edge_attrs_irreps: so13.Lorentz_Irreps,
        edge_feats_irreps: so13.Lorentz_Irreps,
        target_irreps: so13.Lorentz_Irreps,
        hidden_irreps: so13.Lorentz_Irreps,
        avg_num_neighbors: float,
    ) -> None:
        super().__init__()
        # Per-node inputs.
        self.node_attrs_irreps = node_attrs_irreps
        self.node_feats_irreps = node_feats_irreps
        # Per-edge inputs.
        self.edge_attrs_irreps = edge_attrs_irreps
        self.edge_feats_irreps = edge_feats_irreps
        # Requested output / hidden representations.
        self.target_irreps = target_irreps
        self.hidden_irreps = hidden_irreps
        # Normaliser available to subclasses for aggregated messages.
        self.avg_num_neighbors = avg_num_neighbors

        # Subclass hook. Every attribute above must already exist.
        self._setup()

    @abstractmethod
    def _setup(self) -> None:
        """Build the subclass's submodules from the stored irreps."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """Compute updated node features from node/edge inputs."""
        raise NotImplementedError


# Activation lookup keyed by +1 / -1 (presumably irrep parity; NOTE(review):
# not referenced anywhere in this chunk).
nonlinearities = {
    1: torch.nn.functional.gelu,
    -1: torch.tanh,
}


class ComplexAgnosticResidualInteractionBlock(InteractionBlock):
    """Linear interaction block on complex node features with a residual skip.

    Real and imaginary parts are processed with shared real-valued so13
    modules. Messages are summed over receivers, passed through ``self.linear``,
    divided by ``avg_num_neighbors`` and added to a skip tensor product of the
    input node features with the node attributes.
    """

    def _setup(self,) -> None:
        # NOTE: submodules register (and are randomly initialised) in the order
        # they are constructed below; reordering changes checkpoints/initialisation.

        # First linear
        self.linear_up = so13.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # TensorProduct
        # Edge-wise product of node features with edge attributes; its weights
        # are produced per edge by conv_tp_weights (shared/internal weights off).
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )
        self.conv_tp = so13.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Convolution weights
        # MLP: edge features -> 3 hidden layers of width 64 -> TP weight vector.
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = so13.FullyConnectedNet(
            [input_dim] + 3 * [64] + [self.conv_tp.weight_numel],
            torch.nn.functional.gelu,
        )
        # Linear
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps)
        self.irreps_out = self.irreps_out.simplify()
        self.linear = so13.Linear(
            irreps_mid,
            self.irreps_out,
            internal_weights=True,
            shared_weights=True,
            use_complex=False,
        )

        # Selector TensorProduct
        # Residual path: node features x node attributes -> irreps_out.
        self.skip_tp = so13.FullyConnectedTensorProduct(
            self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out
        )

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """Return the updated complex node features (message + skip)."""
        # edge_index rows are (sender, receiver); messages flow sender -> receiver.
        sender, receiver = edge_index
        num_nodes = node_feats.shape[0]
        sc_real = self.skip_tp(node_feats.real, node_attrs)
        sc_imag = self.skip_tp(node_feats.imag, node_attrs)
        node_feats_real = self.linear_up(node_feats.real)
        node_feats_imag = self.linear_up(node_feats.imag)
        tp_weights = self.conv_tp_weights(edge_feats)
        # Complex product (a + ib)(c + id) = (ac - bd) + i(ad + bc), expanded
        # into four real tensor products (assumes conv_tp is bilinear in its two
        # inputs for fixed weights).
        mji_real = self.conv_tp(
            node_feats_real[sender], edge_attrs.real, tp_weights
        ) - self.conv_tp(
            node_feats_imag[sender], edge_attrs.imag, tp_weights
        )  # [n_edges, irreps]
        mji_imag = self.conv_tp(
            node_feats_real[sender], edge_attrs.imag, tp_weights
        ) + self.conv_tp(node_feats_imag[sender], edge_attrs.real, tp_weights)
        # Sum incoming messages onto their receiver nodes.
        message_real = scatter_sum(
            src=mji_real, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message_imag = scatter_sum(
            src=mji_imag, index=receiver, dim=0, dim_size=num_nodes
        )
        message_real = self.linear(message_real) / self.avg_num_neighbors
        message_real = message_real + sc_real
        message_imag = self.linear(message_imag) / self.avg_num_neighbors
        message_imag = message_imag + sc_imag

        # Re-pack (real, imag) into a complex tensor.
        message = torch.view_as_complex(
            torch.stack((message_real, message_imag), dim=-1)
        )
        return message  # [n_nodes, n_channelx(lmax + 1)**2]


class ComplexAgnosticNonLinearResidualInteractionBlock(InteractionBlock):
    """Interaction block on complex node features with an equivariant gate.

    Real and imaginary parts are processed with shared real-valued so13
    modules. Averaged messages go through ``self.linear``, are divided by
    ``avg_num_neighbors``, added to a skip tensor product and then passed
    through an ``so13.Gate`` nonlinearity.
    """

    def _setup(self,) -> None:
        # NOTE: submodules register (and are randomly initialised) in the order
        # they are constructed below; reordering changes checkpoints/initialisation.

        # First linear
        self.linear_up = so13.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # TensorProduct
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )

        # Per-edge weights are supplied externally (shared/internal weights off).
        self.conv_tp = so13.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )

        irreps_mid = irreps_mid.simplify()

        # Convolution weights
        # MLP: edge features -> 3 hidden layers of width 64 -> TP weight vector.
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = so13.FullyConnectedNet(
            [input_dim] + 3 * [64] + [self.conv_tp.weight_numel],
            torch.nn.functional.gelu,
        )

        # equivariant non linearity
        # Scalars: (0,0) target irreps also present in irreps_mid.
        # Gated: target irreps with l > 0 present in irreps_mid, each with one
        # (0,0) gate per multiplicity.
        # NOTE(review): irreps with l == 0 but k != 0 fall into neither group and
        # are dropped from the gate — confirm this is intended.
        irreps_scalars = so13.Lorentz_Irreps(
            [
                (mul, ir)
                for mul, ir in self.target_irreps
                if ir.l == 0 and ir.k == 0 and ir in irreps_mid
            ]
        )
        irreps_gated = so13.Lorentz_Irreps(
            [
                (mul, ir)
                for mul, ir in self.target_irreps
                if ir.l > 0 and ir in irreps_mid
            ]
        )
        irreps_gates = so13.Lorentz_Irreps([mul, "(0,0)"] for mul, _ in irreps_gated)
        self.equivariant_nonlin = so13.Gate(
            irreps_scalars=irreps_scalars,
            act_scalars=[torch.nn.functional.gelu] * len(irreps_scalars),
            irreps_gates=irreps_gates,
            act_gates=[torch.nn.functional.gelu] * len(irreps_gates),
            irreps_gated=irreps_gated,
        )
        # Input irreps expected by the gate; the linear and skip paths target these.
        self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify()
        self.irreps_out = self.equivariant_nonlin.irreps_out.simplify()

        # Linear
        self.linear = so13.Linear(
            irreps_mid, self.irreps_nonlin, internal_weights=True, shared_weights=True
        )

        # Selector TensorProduct
        # Residual path: node features x node attributes -> gate input irreps.
        self.skip_tp = so13.FullyConnectedTensorProduct(
            self.node_feats_irreps, self.node_attrs_irreps, self.irreps_nonlin
        )

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """Return the gated, updated complex node features."""
        # edge_index rows are (sender, receiver); messages flow sender -> receiver.
        sender, receiver = edge_index
        num_nodes = node_feats.shape[0]
        sc_real = self.skip_tp(node_feats.real, node_attrs)
        sc_imag = self.skip_tp(node_feats.imag, node_attrs)
        node_feats_real = self.linear_up(node_feats.real)
        node_feats_imag = self.linear_up(node_feats.imag)

        tp_weights = self.conv_tp_weights(edge_feats)
        # Complex product (a + ib)(c + id) = (ac - bd) + i(ad + bc), expanded
        # into four real tensor products (assumes conv_tp is bilinear in its two
        # inputs for fixed weights).
        mji_real = self.conv_tp(
            node_feats_real[sender], edge_attrs.real, tp_weights
        ) - self.conv_tp(
            node_feats_imag[sender], edge_attrs.imag, tp_weights
        )  # [n_edges, irreps]
        mji_imag = self.conv_tp(
            node_feats_real[sender], edge_attrs.imag, tp_weights
        ) + self.conv_tp(node_feats_imag[sender], edge_attrs.real, tp_weights)
        # NOTE(review): messages are averaged here (scatter_mean) and then divided
        # again by avg_num_neighbors below — a double normalisation. Confirm intended.
        message_real = scatter_mean(
            src=mji_real, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message_imag = scatter_mean(
            src=mji_imag, index=receiver, dim=0, dim_size=num_nodes
        )

        # The gate is applied to real and imaginary parts independently.
        message_real = self.linear(message_real) / self.avg_num_neighbors
        message_real = self.equivariant_nonlin(message_real + sc_real)
        message_imag = self.linear(message_imag) / self.avg_num_neighbors
        message_imag = self.equivariant_nonlin(message_imag + sc_imag)

        # Re-pack (real, imag) into a complex tensor.
        message = torch.view_as_complex(
            torch.stack((message_real, message_imag), dim=-1)
        )
        return message  # [n_nodes, n_channelx(lmax + 1)**2]


class ComplexAgnosticNonLinearReshapeResidualInteractionBlock(InteractionBlock):
    """Gated complex interaction block that also returns a separate skip tensor.

    Real and imaginary parts are processed with shared real-valued so13
    modules. Messages are summed onto receivers, passed through ``self.linear``,
    scaled by a fixed ``sqrt(2) / 40``, added to a skip tensor product and gated.
    ``forward`` returns ``(reshape(message), sc)``, where ``sc`` is a second skip
    tensor product into ``hidden_irreps``.
    """

    def _setup(self,) -> None:
        # NOTE: submodules register (and are randomly initialised) in the order
        # they are constructed below; reordering changes checkpoints/initialisation.

        # First linear
        self.linear_up = so13.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # TensorProduct
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )

        # Per-edge weights are supplied externally (shared/internal weights off).
        self.conv_tp = so13.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )

        # NOTE(review): not used by forward(); kept so the module's attributes
        # and structure stay unchanged.
        self.dropout = torch.nn.Dropout(0.05)

        irreps_mid = irreps_mid.simplify()

        # Convolution weights
        # MLP: edge features -> 3 hidden layers of width 64 -> TP weight vector.
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = so13.FullyConnectedNet(
            [input_dim] + 3 * [64] + [self.conv_tp.weight_numel],
            torch.nn.functional.leaky_relu,
        )

        # equivariant non linearity
        # Scalars: (0,0) target irreps also present in irreps_mid.
        # Gated: target irreps with l > 0 present in irreps_mid, each with one
        # (0,0) gate per multiplicity.
        irreps_scalars = so13.Lorentz_Irreps(
            [
                (mul, ir)
                for mul, ir in self.target_irreps
                if ir.l == 0 and ir.k == 0 and ir in irreps_mid
            ]
        )
        irreps_gated = so13.Lorentz_Irreps(
            [
                (mul, ir)
                for mul, ir in self.target_irreps
                if ir.l > 0 and ir in irreps_mid
            ]
        )
        irreps_gates = so13.Lorentz_Irreps([mul, "(0,0)"] for mul, _ in irreps_gated)
        self.equivariant_nonlin = so13.Gate(
            irreps_scalars=irreps_scalars,
            act_scalars=[torch.nn.functional.gelu] * len(irreps_scalars),
            irreps_gates=irreps_gates,
            act_gates=[torch.nn.functional.gelu] * len(irreps_gates),
            irreps_gated=irreps_gated,
        )
        self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify()
        self.irreps_out = self.equivariant_nonlin.irreps_out.simplify()

        # Linear
        self.linear = so13.Linear(
            irreps_mid, self.irreps_nonlin, internal_weights=True, shared_weights=True
        )

        # Selector TensorProduct
        # skip_tp feeds the gate input; skip_tp_out is returned to the caller.
        self.skip_tp = so13.FullyConnectedTensorProduct(
            self.node_feats_irreps, self.node_attrs_irreps, self.irreps_nonlin
        )
        self.skip_tp_out = so13.FullyConnectedTensorProduct(
            self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
        )
        self.reshape = reshape_irreps(self.irreps_out)
        # NOTE(review): alpha, real_norm and imag_norm are not used by forward();
        # they are retained so existing state_dict keys remain loadable.
        self.alpha = torch.nn.Parameter(torch.tensor(1.0, requires_grad=True))

        self.real_norm = SwitchNorm1d(irreps_mid.dim)
        self.imag_norm = SwitchNorm1d(irreps_mid.dim)

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(reshaped gated message, complex skip into hidden_irreps)``."""
        # edge_index rows are (sender, receiver); messages flow sender -> receiver.
        sender, receiver = edge_index
        num_nodes = node_feats.shape[0]

        sc_real = self.skip_tp(node_feats.real, node_attrs)
        sc_imag = self.skip_tp(node_feats.imag, node_attrs)
        sc_real_out = self.skip_tp_out(node_feats.real, node_attrs)
        sc_imag_out = self.skip_tp_out(node_feats.imag, node_attrs)
        node_feats_real = self.linear_up(node_feats.real)
        node_feats_imag = self.linear_up(node_feats.imag)

        tp_weights = self.conv_tp_weights(edge_feats)

        # Complex product (a + ib)(c + id) = (ac - bd) + i(ad + bc), expanded
        # into four real tensor products (assumes conv_tp is bilinear in its two
        # inputs for fixed weights).
        mji_real = self.conv_tp(
            node_feats_real[sender], edge_attrs.real, tp_weights
        ) - self.conv_tp(
            node_feats_imag[sender], edge_attrs.imag, tp_weights
        )  # [n_edges, irreps]
        mji_imag = self.conv_tp(
            node_feats_real[sender], edge_attrs.imag, tp_weights
        ) + self.conv_tp(node_feats_imag[sender], edge_attrs.real, tp_weights)

        # Sum incoming messages onto receivers. A scatter_mean of the same
        # messages was previously computed here and immediately overwritten;
        # that dead computation has been removed.
        message_real = scatter_sum(
            src=mji_real, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message_imag = scatter_sum(
            src=mji_imag, index=receiver, dim=0, dim_size=num_nodes
        )

        # NOTE(review): a fixed normaliser sqrt(2) / 40 is used instead of
        # self.avg_num_neighbors; the expression is kept verbatim to preserve
        # numerics of trained models.
        message_real = self.linear(message_real) / 40 * 2 ** 0.5
        message_real = self.equivariant_nonlin(message_real + sc_real)
        message_imag = self.linear(message_imag) / 40 * 2 ** 0.5
        message_imag = self.equivariant_nonlin(message_imag + sc_imag)

        # Re-pack (real, imag) pairs into complex tensors.
        message = torch.view_as_complex(
            torch.stack((message_real, message_imag), dim=-1)
        )
        sc = torch.view_as_complex(torch.stack((sc_real_out, sc_imag_out), dim=-1))
        return (self.reshape(message), sc)


class ComplexActBock(torch.nn.Module):
    """Apply a real-valued activation to a complex tensor.

    Args:
        act: either a function (e.g. from ``torch.nn.functional``) or an
            ``nn.Module`` when the activation has learnable parameters. A module
            is registered as a submodule.
        use_phase: if True, apply ``act`` to the modulus and keep the phase,
            ``act(|x|) * exp(i * angle(x))``. Otherwise apply ``act`` to the real
            and imaginary parts separately.
    """

    def __init__(self, act, use_phase=False):
        # The original skipped this call. Assigning an nn.Module activation then
        # raised AttributeError, and calling the block failed because the Module
        # internals (hooks, _modules, ...) were never initialised.
        super().__init__()
        self.act = act
        self.use_phase = use_phase

    def forward(self, x: torch.Tensor):
        if self.use_phase:
            return self.act(torch.abs(x)) * torch.exp(1.0j * torch.angle(x))
        return self.act(x.real) + 1.0j * self.act(x.imag)


class ScaleShiftBlock(torch.nn.Module):
    """Affine map ``x -> scale * x + shift``.

    ``scale`` and ``shift`` are stored as buffers in the default dtype, so they
    follow the module across devices but are not trainable parameters.
    """

    def __init__(self, scale: float, shift: float):
        super().__init__()
        dtype = torch.get_default_dtype()
        # Registration order (scale, then shift) is preserved.
        for buffer_name, value in (("scale", scale), ("shift", shift)):
            self.register_buffer(buffer_name, torch.tensor(value, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = self.scale * x
        return scaled + self.shift

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}(scale={self.scale:.6f}, shift={self.shift:.6f})"
