from typing import Any, Callable, Dict, Type
from LorentzMACE.data import AtomicData
from LorentzMACE.data.utils import minkowski_norm
import numpy as np
import torch
from LorentzMACE.modules.blocks import (
    AttributesBlock,
    ComplexAgnosticNonLinearReshapeResidualInteractionBlock,
    EquivariantProductBasisBlock,
    InteractionBlock,
    LinearNodeEmbeddingBlock,
    LinearReadoutBlock,
    NonLinearReadoutBlock,
    ScaleShiftBlock,
    SimpleRadialEmbeddingBlock,
)
from LorentzMACE.modules.spherical_harmonics import SphericalHarmonics
from LorentzMACE.modules.utils import get_edge_vectors_and_lengths
from LorentzMACE.tools.torch_geometric import (
    GraphNorm,
    MeanSubtractionNorm,
    BatchNorm,
    SwitchNorm1d,
)
from LieCG import so13
from LieCG.CG_coefficients.CG_lorentz import CGDict
from LorentzMACE.tools.scatter import scatter_sum, scatter_mean, scatter_std


class LorentzBOTNet(torch.nn.Module):
    """Interaction stack in which every layer has its own readout.

    Each interaction block is paired with a readout block.  In ``forward`` the
    per-node readout values of every layer are summed per graph, and the
    per-layer results are added together to form the ``"energy"`` output.
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        max_ell: int,
        interaction_cls: Type[InteractionBlock],
        interaction_cls_first: Type[InteractionBlock],
        radial_basis_cls: Type[Callable],
        num_interactions: int,
        num_elements: int,
        scale: int,
        hidden_irreps: so13.Lorentz_Irreps,
        MLP_irreps: so13.Lorentz_Irreps,
        readout_irreps: so13.Lorentz_Irreps,
        gate: Callable,
        avg_num_neighbors: float,
        use_cutoff: bool,
        device: str,
    ):
        """Create the embeddings, the interaction stack and its readouts.

        Args:
            r_max: Forwarded to ``radial_basis_cls``.
            num_bessel: Forwarded to ``radial_basis_cls``.
            num_polynomial_cutoff: Forwarded to ``radial_basis_cls``.
            max_ell: Order used for the spherical harmonics; the CG table is
                built with ``max_ell + 1``.
            interaction_cls: Class used for every interaction after the first.
            interaction_cls_first: Class used for the first interaction.
            radial_basis_cls: Factory for the radial embedding; the result
                must expose ``out_dim``.
            num_interactions: Total number of interaction layers.
            num_elements: Multiplicity of the (0, 0) node-attribute irreps.
            scale: Forwarded to ``ScaleShiftBlock`` and ``AttributesBlock``.
            hidden_irreps: Target irreps of every interaction.
            MLP_irreps: Passed to the non-linear readout of the last layer.
            readout_irreps: Output irreps of the linear readouts.
            gate: Passed to the non-linear readout of the last layer.
            avg_num_neighbors: Forwarded to every interaction block.
            use_cutoff: Forwarded to ``radial_basis_cls``.
            device: Device on which the CG table is built.
        """
        super().__init__()

        # --- input rescaling -------------------------------------------------
        self.scaling = ScaleShiftBlock(scale=scale, shift=0)
        # Registered but not called in forward().
        self.graph_norm = GraphNorm()

        # --- node / edge embeddings -----------------------------------------
        self.attributes_embedding = AttributesBlock(scale=scale)
        attr_irreps = so13.Lorentz_Irreps([(num_elements, (0, 0))])
        # The initial node features keep only as many (0, 0) channels as
        # hidden_irreps contains.
        num_scalar_channels = hidden_irreps.count(so13.Lorentz_Irrep(0, 0))
        embed_irreps = so13.Lorentz_Irreps([(num_scalar_channels, (0, 0))])
        self.node_embedding = LinearNodeEmbeddingBlock(
            irreps_in=attr_irreps, irreps_out=embed_irreps
        )
        self.radial_embedding = radial_basis_cls(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
            use_cutoff=use_cutoff,
        )
        radial_dim = self.radial_embedding.out_dim
        radial_irreps = so13.Lorentz_Irreps(f"{radial_dim}x(0,0)")

        cg_table = CGDict(max_ell + 1, device=device)
        harmonics_irreps = so13.Lorentz_Irreps.spherical_harmonics(max_ell)
        self.spherical_harmonics = SphericalHarmonics(cg_table, max_ell)

        # --- interaction stack with one readout per layer -------------------
        self.interactions = torch.nn.ModuleList()
        self.readouts = torch.nn.ModuleList()

        # Arguments shared by every interaction block; only the irreps of the
        # incoming node features differ between layers.
        shared_kwargs = dict(
            node_attrs_irreps=attr_irreps,
            edge_attrs_irreps=harmonics_irreps,
            edge_feats_irreps=radial_irreps,
            target_irreps=hidden_irreps,
            avg_num_neighbors=avg_num_neighbors,
        )

        block = interaction_cls_first(node_feats_irreps=embed_irreps, **shared_kwargs)
        self.interactions.append(block)
        self.readouts.append(LinearReadoutBlock(block.irreps_out, readout_irreps))

        final_index = num_interactions - 2
        for layer in range(num_interactions - 1):
            block = interaction_cls(
                node_feats_irreps=block.irreps_out, **shared_kwargs
            )
            self.interactions.append(block)
            # Only the very last layer gets the gated, non-linear readout.
            if layer != final_index:
                head = LinearReadoutBlock(block.irreps_out, readout_irreps)
            else:
                head = NonLinearReadoutBlock(block.irreps_out, MLP_irreps, gate)
            self.readouts.append(head)

    def forward(self, data: AtomicData, training=False) -> Dict[str, Any]:
        """Run all interaction layers and accumulate their readouts.

        Side effect: sets ``data.positions.requires_grad = True`` before the
        positions are used.

        Args:
            data: Batch providing ``positions``, ``edge_index``, ``batch`` and
                ``num_graphs``.
            training: Unused.

        Returns:
            Dict with ``"contributions"`` (per-layer results stacked along a
            new leading axis) and ``"energy"`` (their sum over that axis).
        """
        # Enable gradients on the raw inputs before anything consumes them.
        data.positions.requires_grad = True
        scaled_positions = self.scaling(data.positions)

        # Node attributes are derived from the scaled positions.
        attrs = self.attributes_embedding(scaled_positions)
        feats = self.node_embedding(attrs)

        edge_index = data.edge_index
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=scaled_positions, edge_index=edge_index,
        )
        angular = self.spherical_harmonics(vectors)
        radial = self.radial_embedding(lengths)

        # The trailing axis on the graph index presumably lets it line up
        # with the feature axis of the readout output - TODO confirm against
        # scatter_sum.
        graph_index = data.batch.unsqueeze(-1)
        num_graphs = data.num_graphs

        per_layer = []
        for interaction, readout in zip(self.interactions, self.readouts):
            feats = interaction(
                node_attrs=attrs,
                node_feats=feats,
                edge_attrs=angular,
                edge_feats=radial,
                edge_index=edge_index,
            )
            node_values = readout(feats)
            per_layer.append(
                scatter_sum(
                    src=node_values, index=graph_index, dim=0, dim_size=num_graphs,
                )
            )

        # contributions: one slice per layer along axis 0; the total removes
        # that axis by summing over it.
        contributions = torch.stack(per_layer, dim=0)
        total_energy = contributions.sum(dim=0)

        return {"energy": total_energy, "contributions": contributions}


class SingleReadoutModel(torch.nn.Module):
    """Interaction stack with one readout after the last layer and an MLP head.

    Node features pass through every interaction, are read out once, summed
    per graph and then mapped to two outputs by a small fully connected
    network.
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        max_ell: int,
        interaction_cls: Type[InteractionBlock],
        interaction_cls_first: Type[InteractionBlock],
        radial_basis_cls: Type[Callable],
        num_interactions: int,
        num_elements: int,
        scale: int,
        hidden_irreps: so13.Lorentz_Irreps,
        MLP_irreps: so13.Lorentz_Irreps,
        readout_irreps: so13.Lorentz_Irreps,
        gate: Callable,
        avg_num_neighbors: float,
        use_cutoff: bool,
        device: str,
    ):
        """Create the embeddings, interactions, readout and final MLP.

        Args:
            r_max: Forwarded to ``radial_basis_cls``.
            num_bessel: Forwarded to ``radial_basis_cls``.
            num_polynomial_cutoff: Forwarded to ``radial_basis_cls``.
            max_ell: Order used for the spherical harmonics; the CG table is
                built with ``max_ell + 1``.
            interaction_cls: Class used for every interaction after the first.
            interaction_cls_first: Class used for the first interaction.
            radial_basis_cls: Factory for the radial embedding; the result
                must expose ``out_dim``.
            num_interactions: Total number of interaction layers.
            num_elements: Multiplicity of the (0, 0) node-attribute irreps.
            scale: Forwarded to ``ScaleShiftBlock`` and ``AttributesBlock``.
            hidden_irreps: Target irreps of every interaction.
            MLP_irreps: Its ``num_irreps`` sets the hidden width of the MLP.
            readout_irreps: Output irreps of the linear readout.
            gate: Unused by this model.
            avg_num_neighbors: Forwarded to every interaction block.
            use_cutoff: Forwarded to ``radial_basis_cls``.
            device: Device on which the CG table is built.
        """
        super().__init__()

        # --- input rescaling -------------------------------------------------
        self.scaling = ScaleShiftBlock(scale=scale, shift=0)

        # --- node / edge embeddings -----------------------------------------
        # Registered but not called in forward().
        self.attributes_embedding = AttributesBlock(scale=scale)
        attr_irreps = so13.Lorentz_Irreps([(num_elements, (0, 0))])
        num_scalar_channels = hidden_irreps.count(so13.Lorentz_Irrep(0, 0))
        embed_irreps = so13.Lorentz_Irreps([(num_scalar_channels, (0, 0))])
        self.node_embedding = LinearNodeEmbeddingBlock(
            irreps_in=attr_irreps, irreps_out=embed_irreps
        )
        self.radial_embedding = radial_basis_cls(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
            use_cutoff=use_cutoff,
        )
        radial_dim = self.radial_embedding.out_dim
        radial_irreps = so13.Lorentz_Irreps(f"{radial_dim}x(0,0)")

        cg_table = CGDict(max_ell + 1, device=device)
        harmonics_irreps = so13.Lorentz_Irreps.spherical_harmonics(max_ell)
        self.spherical_harmonics = SphericalHarmonics(cg_table, max_ell)

        # --- interaction stack ----------------------------------------------
        self.interactions = torch.nn.ModuleList()
        self.readouts = torch.nn.ModuleList()

        # Arguments shared by every interaction block; only the irreps of the
        # incoming node features differ between layers.
        shared_kwargs = dict(
            node_attrs_irreps=attr_irreps,
            edge_attrs_irreps=harmonics_irreps,
            edge_feats_irreps=radial_irreps,
            target_irreps=hidden_irreps,
            avg_num_neighbors=avg_num_neighbors,
        )

        block = interaction_cls_first(node_feats_irreps=embed_irreps, **shared_kwargs)
        self.interactions.append(block)

        for _ in range(num_interactions - 1):
            block = interaction_cls(
                node_feats_irreps=block.irreps_out, **shared_kwargs
            )
            self.interactions.append(block)

        # --- single readout on the output of the last interaction -----------
        head = LinearReadoutBlock(block.irreps_out, readout_irreps)
        self.readouts.append(head)

        # --- fully connected head producing two outputs ---------------------
        in_width = self.readouts[0].irreps_out.num_irreps
        hidden_width = MLP_irreps.num_irreps
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_width, hidden_width),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_width, 2),
        )

    def forward(self, data: AtomicData, training=False) -> Dict[str, Any]:
        """Compute the two-component output for every graph in ``data``.

        Side effect: sets ``data.positions.requires_grad = True``.

        Args:
            data: Batch providing ``positions``, ``node_attrs``,
                ``edge_index``, ``batch`` and ``num_graphs``.
            training: Unused.

        Returns:
            Dict with the single key ``"energy"`` holding the output of the
            fully connected head.
        """
        scaled_positions = self.scaling(data.positions)
        # NOTE(review): the flag is set only after the positions were scaled;
        # LorentzBOTNet sets it before scaling. Confirm which order is
        # intended.
        data.positions.requires_grad = True

        # Node attributes are taken from the batch as-is; the attributes
        # embedding block is not applied here.
        attrs = data.node_attrs
        feats = self.node_embedding(attrs)

        edge_index = data.edge_index
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=scaled_positions, edge_index=edge_index,
        )
        angular = self.spherical_harmonics(vectors)
        radial = self.radial_embedding(lengths)

        for interaction in self.interactions:
            feats = interaction(
                node_attrs=attrs,
                node_feats=feats,
                edge_attrs=angular,
                edge_feats=radial,
                edge_index=edge_index,
            )

        # Per-node readout, then a sum over the nodes of each graph.
        node_values = self.readouts[0](feats)
        pooled = scatter_sum(
            src=node_values,
            index=data.batch.unsqueeze(-1),
            dim=0,
            dim_size=data.num_graphs,
        )

        return {"energy": self.fc(pooled)}


class LorentzMACELayer(torch.nn.Module):
    """One interaction -> product basis -> linear readout layer.

    Unlike the full models in this file, ``forward`` takes plain tensors
    rather than a batch object and returns the per-node readout directly.
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        max_ell: int,
        num_elements: int,
        scale: int,
        hidden_irreps: str,
        target_irreps: str,
        readout_irreps: str,
        correlation: int,
        avg_num_neighbors: float,
        use_cutoff: bool,
        device: str,
    ):
        """Create the radial/angular embeddings and the three layer stages.

        Args:
            r_max: Forwarded to ``SimpleRadialEmbeddingBlock``.
            num_bessel: Forwarded to ``SimpleRadialEmbeddingBlock``; also the
                feature count given to the radial ``SwitchNorm1d``.
            num_polynomial_cutoff: Forwarded to ``SimpleRadialEmbeddingBlock``.
            max_ell: Order used for the spherical harmonics; the CG table is
                built with ``max_ell + 1``.
            num_elements: Multiplicity of the (0, 0) node-attribute irreps;
                also forwarded to the product block.
            scale: Forwarded to ``ScaleShiftBlock`` and ``AttributesBlock``.
            hidden_irreps: Irreps string describing the incoming node
                features of the interaction.
            target_irreps: Irreps string used as the interaction's
                ``hidden_irreps`` and as output irreps of the product block.
            readout_irreps: Irreps string for the output of the readout.
            correlation: Forwarded to ``EquivariantProductBasisBlock``.
            avg_num_neighbors: Forwarded to the interaction block.
            use_cutoff: Forwarded to ``SimpleRadialEmbeddingBlock``.
            device: Device on which the CG table is built.
        """
        super().__init__()

        # Rescaling
        self.scaling = ScaleShiftBlock(scale=scale, shift=0)
        # The irreps arrive as strings; parse them once up front.
        hidden_irreps = so13.Lorentz_Irreps(hidden_irreps)
        readout_irreps = so13.Lorentz_Irreps(readout_irreps)
        target_irreps = so13.Lorentz_Irreps(target_irreps)

        # Embedding
        # NOTE: attributes_embedding, node_embedding and node_attrs_norm are
        # registered here but never called in forward(); they are kept so the
        # set of registered submodules stays the same.
        self.attributes_embedding = AttributesBlock(scale=scale)
        node_attr_irreps = so13.Lorentz_Irreps([(num_elements, (0, 0))])
        node_feats_irreps = so13.Lorentz_Irreps(
            [(hidden_irreps.count(so13.Lorentz_Irrep(0, 0)), (0, 0))]
        )
        self.node_embedding = LinearNodeEmbeddingBlock(
            irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
        )

        self.radial_norm = SwitchNorm1d(num_bessel)

        self.node_attrs_norm = SwitchNorm1d(node_attr_irreps.num_irreps)

        self.radial_embedding = SimpleRadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
            use_cutoff=use_cutoff,
        )
        edge_feats_irreps = so13.Lorentz_Irreps(
            f"{self.radial_embedding.out_dim}x(0,0)"
        )

        cg_dict = CGDict(max_ell + 1, device=device)
        sh_irreps = so13.Lorentz_Irreps.spherical_harmonics(max_ell)
        self.spherical_harmonics = SphericalHarmonics(cg_dict, max_ell)
        # One copy of the spherical-harmonics irreps per (0, 0) channel of
        # hidden_irreps, sorted and simplified.
        num_features = hidden_irreps.count(so13.Lorentz_Irrep(0, 0))
        interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()

        self.interaction = ComplexAgnosticNonLinearReshapeResidualInteractionBlock(
            node_attrs_irreps=node_attr_irreps,
            node_feats_irreps=hidden_irreps,
            edge_attrs_irreps=sh_irreps,
            edge_feats_irreps=edge_feats_irreps,
            target_irreps=interaction_irreps,
            hidden_irreps=target_irreps,
            avg_num_neighbors=avg_num_neighbors,
        )
        node_feats_irreps_out = self.interaction.irreps_out
        self.product = EquivariantProductBasisBlock(
            node_feats_irreps=node_feats_irreps_out,
            target_irreps=target_irreps,
            correlation=correlation,
            element_dependent=True,
            num_elements=num_elements,
            use_sc=True,
        )

        # Readout
        self.readout = LinearReadoutBlock(target_irreps, readout_irreps)

    def forward(
        self,
        positions: torch.Tensor,
        node_feats: torch.Tensor,
        node_attrs: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the layer and return the per-node readout.

        Args:
            positions: Node positions; rescaled by ``self.scaling`` before
                edge vectors are computed.
            node_feats: Incoming node features for the interaction block.
            node_attrs: Node attributes passed to the interaction and to the
                product block.
            edge_index: Edge connectivity passed to
                ``get_edge_vectors_and_lengths`` and to the interaction.

        Returns:
            The output of the linear readout applied to the product-basis
            features (one row per node, presumably - TODO confirm against
            ``LinearReadoutBlock``).
        """

        positions = self.scaling(positions)

        vectors, lengths = get_edge_vectors_and_lengths(
            positions=positions, edge_index=edge_index,
        )
        edge_attrs = self.spherical_harmonics(vectors)

        # Radial features are normalised before entering the interaction.
        edge_feats = self.radial_embedding(lengths)
        edge_feats = self.radial_norm(edge_feats)

        # Interactions
        # The interaction returns the new features together with a second
        # value `sc` that is handed on to the product block.
        node_feats, sc = self.interaction(
            node_attrs=node_attrs,
            node_feats=node_feats,
            edge_attrs=edge_attrs,
            edge_feats=edge_feats,
            edge_index=edge_index,
        )
        node_feats = self.product(node_feats=node_feats, sc=sc, node_attrs=node_attrs)
        node_energies = self.readout(node_feats)
        return node_energies


class LorentzMACEModel(torch.nn.Module):
    """Stack of interaction + product-basis layers with an attention head.

    Node features pass through every (interaction, product) pair and are read
    out once.  The per-node readout is pooled per graph with mean, standard
    deviation and sum; the three statistics are mixed per channel, refined by
    a residual self-attention step and mapped to two outputs by a small MLP.
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        max_ell: int,
        interaction_cls: Type[InteractionBlock],
        interaction_cls_first: Type[InteractionBlock],
        radial_basis_cls: Type[Callable],
        num_interactions: int,
        num_elements: int,
        scale: int,
        hidden_irreps: so13.Lorentz_Irreps,
        MLP_irreps: so13.Lorentz_Irreps,
        readout_irreps: so13.Lorentz_Irreps,
        gate: Callable,
        correlation: int,
        avg_num_neighbors: float,
        use_cutoff: bool,
        device: str,
    ):
        """Create the embeddings, the layer stack, the readout and the head.

        Args:
            r_max: Forwarded to ``radial_basis_cls``.
            num_bessel: Forwarded to ``radial_basis_cls``; also the feature
                count given to the radial ``SwitchNorm1d``.
            num_polynomial_cutoff: Forwarded to ``radial_basis_cls``.
            max_ell: Order used for the spherical harmonics; the CG table is
                built with ``max_ell + 1``.
            interaction_cls: Unused; the interaction class is fixed to
                ``ComplexAgnosticNonLinearReshapeResidualInteractionBlock``.
            interaction_cls_first: Unused, see ``interaction_cls``.
            radial_basis_cls: Factory for the radial embedding; the result
                must expose ``out_dim``.
            num_interactions: Total number of interaction layers.  A value of
                1 builds a single layer whose readout consumes
                ``hidden_irreps``.
            num_elements: Multiplicity of the (0, 0) node-attribute irreps;
                also forwarded to the product blocks.
            scale: Forwarded to ``ScaleShiftBlock`` and ``AttributesBlock``.
            hidden_irreps: Hidden irreps of the layers; the last layer of a
                multi-layer stack keeps only ``hidden_irreps[0]``.
            MLP_irreps: Its ``num_irreps`` sets the hidden width of the MLP.
            readout_irreps: Output irreps of the linear readout.  Its
                ``num_irreps`` is used as the attention embedding size with
                8 heads, so it must be compatible with that head count.
            gate: Unused by this model.
            correlation: Forwarded to every product block.
            avg_num_neighbors: Forwarded to every interaction block.
            use_cutoff: Forwarded to ``radial_basis_cls``.
            device: Device on which the CG table is built.
        """
        super().__init__()

        # Rescaling
        # NOTE: _mean_norm and graph_norm are registered but never called in
        # forward(); they are kept so the set of registered submodules stays
        # the same.
        self._mean_norm = MeanSubtractionNorm()
        self.graph_norm = BatchNorm(4)
        self.scaling = ScaleShiftBlock(scale=scale, shift=0)

        # Embedding
        # attributes_embedding and node_attrs_norm are likewise registered
        # but unused in forward().
        self.attributes_embedding = AttributesBlock(scale=scale)
        node_attr_irreps = so13.Lorentz_Irreps([(num_elements, (0, 0))])
        node_feats_irreps = so13.Lorentz_Irreps(
            [(hidden_irreps.count(so13.Lorentz_Irrep(0, 0)), (0, 0))]
        )
        self.node_embedding = LinearNodeEmbeddingBlock(
            irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
        )

        self.radial_norm = SwitchNorm1d(num_bessel)

        self.node_attrs_norm = SwitchNorm1d(node_attr_irreps.num_irreps)

        self.radial_embedding = radial_basis_cls(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
            use_cutoff=use_cutoff,
        )
        edge_feats_irreps = so13.Lorentz_Irreps(
            f"{self.radial_embedding.out_dim}x(0,0)"
        )

        cg_dict = CGDict(max_ell + 1, device=device)
        sh_irreps = so13.Lorentz_Irreps.spherical_harmonics(max_ell)
        self.spherical_harmonics = SphericalHarmonics(cg_dict, max_ell)
        # One copy of the spherical-harmonics irreps per (0, 0) channel of
        # hidden_irreps, sorted and simplified.
        num_features = hidden_irreps.count(so13.Lorentz_Irrep(0, 0))
        interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()

        self.interactions = torch.nn.ModuleList()
        self.prods = torch.nn.ModuleList()
        self.readouts = torch.nn.ModuleList()

        # First layer: consumes the scalar node embedding.
        inter = ComplexAgnosticNonLinearReshapeResidualInteractionBlock(
            node_attrs_irreps=node_attr_irreps,
            node_feats_irreps=node_feats_irreps,
            edge_attrs_irreps=sh_irreps,
            edge_feats_irreps=edge_feats_irreps,
            target_irreps=interaction_irreps,
            hidden_irreps=hidden_irreps,
            avg_num_neighbors=avg_num_neighbors,
        )
        node_feats_irreps_out = inter.irreps_out
        prod = EquivariantProductBasisBlock(
            node_feats_irreps=node_feats_irreps_out,
            target_irreps=hidden_irreps,
            correlation=correlation,
            element_dependent=True,
            num_elements=num_elements,
            use_sc=True,
        )
        self.interactions.append(inter)
        self.prods.append(prod)

        # With a single interaction the loop below never runs, so the readout
        # must fall back to the irreps produced by the first product block.
        # (Previously this case raised UnboundLocalError at the readout.)
        hidden_irreps_out = hidden_irreps
        for i in range(num_interactions - 1):
            if i == num_interactions - 2:
                hidden_irreps_out = str(
                    hidden_irreps[0]
                )  # Select only scalars for last layer
            else:
                hidden_irreps_out = hidden_irreps
            inter = ComplexAgnosticNonLinearReshapeResidualInteractionBlock(
                node_attrs_irreps=node_attr_irreps,
                node_feats_irreps=hidden_irreps,
                edge_attrs_irreps=sh_irreps,
                edge_feats_irreps=edge_feats_irreps,
                target_irreps=interaction_irreps,
                hidden_irreps=hidden_irreps_out,
                avg_num_neighbors=avg_num_neighbors,
            )
            # NOTE(review): the product block reuses the output irreps of the
            # FIRST interaction rather than `inter.irreps_out` - presumably
            # identical for every layer; confirm.
            prod = EquivariantProductBasisBlock(
                node_feats_irreps=node_feats_irreps_out,
                target_irreps=hidden_irreps_out,
                correlation=correlation,
                element_dependent=True,
                num_elements=num_elements,
                use_sc=False,
            )
            self.interactions.append(inter)
            self.prods.append(prod)

        # Readout
        self.readouts.append(LinearReadoutBlock(hidden_irreps_out, readout_irreps))

        # Final fully connected layer
        input_dim = self.readouts[0].irreps_out.num_irreps

        # Mixes the three pooled statistics (mean, std, sum) of each readout
        # channel into a single value.
        self.mom_mapper = torch.nn.Sequential(
            torch.nn.Linear(3, 1), torch.nn.GELU(), torch.nn.Dropout(0.01),
        )

        mid_dim = MLP_irreps.num_irreps
        self.mom_attn = torch.nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=8, dropout=0.05, batch_first=True
        )

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(input_dim, mid_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(0.01),
            torch.nn.Linear(mid_dim, 2),
        )

    def forward(self, data: AtomicData, training=False) -> Dict[str, Any]:
        """Compute the two-component output for every graph in ``data``.

        Side effect: sets ``data.positions.requires_grad = True``.

        Args:
            data: Batch providing ``positions``, ``node_attrs``,
                ``edge_index``, ``batch`` and ``num_graphs``.
            training: Unused.

        Returns:
            Dict with the single key ``"energy"`` holding the output of the
            fully connected head.
        """

        positions = self.scaling(data.positions)
        # NOTE(review): the flag is set only after the positions were scaled;
        # LorentzBOTNet sets it before scaling. Confirm which order is
        # intended.
        data.positions.requires_grad = True

        node_feats = self.node_embedding(data.node_attrs)

        vectors, lengths = get_edge_vectors_and_lengths(
            positions=positions, edge_index=data.edge_index,
        )
        edge_attrs = self.spherical_harmonics(vectors)

        # Radial features are normalised before entering the interactions.
        edge_feats = self.radial_embedding(lengths)
        edge_feats = self.radial_norm(edge_feats)

        # Interactions
        for interaction, product in zip(self.interactions, self.prods):
            node_feats, sc = interaction(
                node_attrs=data.node_attrs,
                node_feats=node_feats,
                edge_attrs=edge_attrs,
                edge_feats=edge_feats,
                edge_index=data.edge_index,
            )
            node_feats = product(
                node_feats=node_feats, sc=sc, node_attrs=data.node_attrs
            )
        node_energies = self.readouts[0](node_feats)

        # Pool the per-node readout over each graph with three statistics.
        pool_kwargs = dict(
            src=node_energies,
            index=data.batch.unsqueeze(-1),
            dim=0,
            dim_size=data.num_graphs,
        )
        inter_e = scatter_mean(**pool_kwargs)
        inter_std = scatter_std(**pool_kwargs)
        inter_sum = scatter_sum(**pool_kwargs)

        # Stack the statistics on a new trailing axis of size 3 so that
        # mom_mapper (Linear(3, 1)) reduces them to one value per channel.
        stats = torch.stack([inter_e, inter_std, inter_sum], dim=2)
        momentums = self.mom_mapper(stats)
        # Move the channels to the last axis, giving a length-1 sequence for
        # the batch-first attention module.
        momentums = momentums.reshape(momentums.shape[0], 1, momentums.shape[1])
        att_momentums, _ = self.mom_attn(momentums, momentums, momentums)
        # Residual connection around the attention step.
        momentums = momentums + att_momentums
        momentums = momentums[:, 0, :]
        probabilities = self.fc(momentums)

        output = {
            "energy": probabilities,
        }
        return output
