import numpy as np
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
from torch.nn import functional as F
import torch_geometric
from torch_geometric.nn import SchNet, DimeNetPlusPlus, global_add_pool, global_mean_pool
import torch_scatter
from torch_scatter import scatter
from e3nn import o3

from src.modules.blocks import (
    AtomicEnergiesBlock,
    EquivariantProductBasisBlock,
    InteractionBlock,
    LinearNodeEmbeddingBlock,
    LinearReadoutBlock,
    NonLinearReadoutBlock,
    RadialEmbeddingBlock,
    ScaleShiftBlock,
)
from src.modules import (
    interaction_classes,
    gate_dict
)
from src.modules.irreps_tools import reshape_irreps

from src.layers import MPNNLayer, EGNNLayer, TensorProductConvLayer
import src.gvp_layers as gvp


class MACEModel(torch.nn.Module):
    """MACE-style equivariant GNN built from tensor-product convolutions.

    Every layer performs one ``TensorProductConvLayer`` message-passing step
    followed by an ``EquivariantProductBasisBlock``. Node features begin as
    ``emb_dim`` scalars (``0e``) and, from the first layer's output onward,
    live in ``emb_dim x (0e + 1o + 2e)``. Only the leading ``emb_dim``
    (scalar) channels are pooled per graph and passed to the predictor MLP.

    Args:
        r_max: cutoff radius forwarded to ``RadialEmbeddingBlock``.
        num_bessel: forwarded to ``RadialEmbeddingBlock``.
        num_polynomial_cutoff: forwarded to ``RadialEmbeddingBlock``.
        max_ell: highest spherical-harmonic degree used for edge directions.
        correlation: correlation order of each product block.
        num_layers: number of (conv, product) layer pairs.
        emb_dim: multiplicity of every irrep in the hidden features.
        in_dim: number of atom types (rows of the input embedding table).
        out_dim: output size of the predictor MLP.
        aggr: aggregation passed to every conv layer.
        pool: graph-level pooling, ``"sum"`` or ``"mean"``.
        residual: passed as ``use_sc`` to each product block.
    """

    def __init__(
        self,
        r_max=10.0,
        num_bessel=8,
        num_polynomial_cutoff=5,
        max_ell=2,
        correlation=3,
        num_layers=5,
        emb_dim=64,
        in_dim=1,
        out_dim=1,
        aggr="sum",
        pool="sum",
        residual=True
    ):
        super().__init__()
        self.r_max = r_max
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.residual = residual

        # Edge featurisation: radial basis on lengths, spherical harmonics on directions.
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
        )
        edge_sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
        self.spherical_harmonics = o3.SphericalHarmonics(
            edge_sh_irreps, normalize=True, normalization="component"
        )

        # Atom type -> emb_dim scalar features.
        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)

        self.convs = torch.nn.ModuleList()
        self.prods = torch.nn.ModuleList()
        self.reshapes = torch.nn.ModuleList()

        scalar_irreps = o3.Irreps(f'{emb_dim}x0e')
        hidden_irreps = o3.Irreps(f'{emb_dim}x0e + {emb_dim}x1o + {emb_dim}x2e')
        for layer_idx in range(num_layers):
            # Only the first layer sees scalar-only inputs; every layer emits the full irreps.
            layer_in = scalar_irreps if layer_idx == 0 else hidden_irreps
            layer_out = hidden_irreps
            self.convs.append(
                TensorProductConvLayer(
                    in_irreps=layer_in,
                    out_irreps=layer_out,
                    sh_irreps=edge_sh_irreps,
                    edge_feats_dim=self.radial_embedding.out_dim,
                    hidden_dim=emb_dim,
                    gate=False,
                    aggr=aggr,
                )
            )
            self.reshapes.append(reshape_irreps(layer_out))
            self.prods.append(
                EquivariantProductBasisBlock(
                    node_feats_irreps=layer_out,
                    target_irreps=layer_out,
                    correlation=correlation,
                    element_dependent=False,
                    num_elements=in_dim,
                    use_sc=residual
                )
            )

        pool_fns = {"mean": global_mean_pool, "sum": global_add_pool}
        self.pool = pool_fns[pool]

        self.pred = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, out_dim)
        )

    def forward(self, batch):
        """Return graph-level predictions of shape (batch_size, out_dim).

        ``batch`` must expose ``atoms``, ``pos``, ``edge_index`` and ``batch``.
        """
        src, dst = batch.edge_index
        edge_vec = batch.pos[src] - batch.pos[dst]  # [n_edges, 3]
        edge_len = torch.linalg.norm(edge_vec, dim=-1, keepdim=True)  # [n_edges, 1]
        edge_attrs = self.spherical_harmonics(edge_vec)
        edge_feats = self.radial_embedding(edge_len)

        h = self.emb_in(batch.atoms)  # (n,) -> (n, d)
        for conv, reshape, prod in zip(self.convs, self.reshapes, self.prods):
            msg = conv(h, batch.edge_index, edge_attrs, edge_feats)
            # Zero-pad the previous features up to the new irreps width so they
            # can serve as the product block's self-connection input.
            self_conn = F.pad(h, (0, msg.shape[-1] - h.shape[-1]))
            h = prod(reshape(msg), self_conn, None)

        # The leading emb_dim channels are the 0e (scalar) block.
        scalars = h[:, :self.emb_dim]
        pooled = self.pool(scalars, batch.batch)  # (n, d) -> (batch_size, d)
        return self.pred(pooled)  # (batch_size, out_dim)


class TFNModel(torch.nn.Module):
    """Tensor Field Network-style equivariant GNN.

    A stack of gated ``TensorProductConvLayer`` message-passing layers. Node
    features start as ``emb_dim`` scalars (``0e``) and every layer outputs
    ``emb_dim x (0e + 1o + 2e)``. With ``residual`` enabled, each layer's output
    is added to its input zero-padded to the same width. The leading
    ``emb_dim`` scalar channels are pooled per graph and fed to an MLP.

    Args:
        r_max: cutoff radius forwarded to ``RadialEmbeddingBlock``.
        num_bessel: forwarded to ``RadialEmbeddingBlock``.
        num_polynomial_cutoff: forwarded to ``RadialEmbeddingBlock``.
        max_ell: highest spherical-harmonic degree used for edge directions.
        num_layers: number of conv layers.
        emb_dim: multiplicity of every irrep in the hidden features.
        in_dim: number of atom types (rows of the input embedding table).
        out_dim: output size of the predictor MLP.
        aggr: aggregation passed to every conv layer.
        pool: graph-level pooling, ``"sum"`` or ``"mean"``.
        residual: add each layer's update to its (padded) input.
    """

    def __init__(
        self,
        r_max=10.0,
        num_bessel=8,
        num_polynomial_cutoff=5,
        max_ell=2,
        num_layers=5,
        emb_dim=64,
        in_dim=1,
        out_dim=1,
        aggr="sum",
        pool="sum",
        residual=True
    ):
        super().__init__()
        self.r_max = r_max
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.residual = residual

        # Edge featurisation: radial basis on lengths, spherical harmonics on directions.
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
        )
        edge_sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
        self.spherical_harmonics = o3.SphericalHarmonics(
            edge_sh_irreps, normalize=True, normalization="component"
        )

        # Atom type -> emb_dim scalar features.
        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)

        scalar_irreps = o3.Irreps(f'{emb_dim}x0e')
        hidden_irreps = o3.Irreps(f'{emb_dim}x0e + {emb_dim}x1o + {emb_dim}x2e')
        self.convs = torch.nn.ModuleList(
            [
                TensorProductConvLayer(
                    # Only the first layer consumes scalar-only features.
                    in_irreps=scalar_irreps if layer_idx == 0 else hidden_irreps,
                    out_irreps=hidden_irreps,
                    sh_irreps=edge_sh_irreps,
                    edge_feats_dim=self.radial_embedding.out_dim,
                    hidden_dim=emb_dim,
                    gate=True,
                    aggr=aggr,
                )
                for layer_idx in range(num_layers)
            ]
        )

        pool_fns = {"mean": global_mean_pool, "sum": global_add_pool}
        self.pool = pool_fns[pool]

        self.pred = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, out_dim)
        )

    def forward(self, batch):
        """Return graph-level predictions of shape (batch_size, out_dim).

        ``batch`` must expose ``atoms``, ``pos``, ``edge_index`` and ``batch``.
        """
        src, dst = batch.edge_index
        edge_vec = batch.pos[src] - batch.pos[dst]  # [n_edges, 3]
        edge_len = torch.linalg.norm(edge_vec, dim=-1, keepdim=True)  # [n_edges, 1]
        edge_attrs = self.spherical_harmonics(edge_vec)
        edge_feats = self.radial_embedding(edge_len)

        h = self.emb_in(batch.atoms)  # (n,) -> (n, d)
        for conv in self.convs:
            msg = conv(h, batch.edge_index, edge_attrs, edge_feats)
            if self.residual:
                # Pad the old features with zeros to the new irreps width before adding.
                h = msg + F.pad(h, (0, msg.shape[-1] - h.shape[-1]))
            else:
                h = msg

        # The leading emb_dim channels are the 0e (scalar) block.
        scalars = h[:, :self.emb_dim]
        pooled = self.pool(scalars, batch.batch)  # (n, d) -> (batch_size, d)
        return self.pred(pooled)  # (batch_size, out_dim)


class GVPGNNModel(torch.nn.Module):
    """Geometric Vector Perceptron (GVP) GNN.

    Nodes carry (scalar, vector) features with ``(emb_dim, emb_dim)``
    channels. Edges carry the radial embedding of their length as scalars
    plus the unit edge direction as a single vector channel. After the GVP
    conv stack, ``W_out`` maps node features back to ``emb_dim`` scalars,
    which are pooled per graph and passed to an MLP.

    Args:
        r_max: cutoff radius forwarded to ``RadialEmbeddingBlock``.
        num_bessel: forwarded to ``RadialEmbeddingBlock``.
        num_polynomial_cutoff: forwarded to ``RadialEmbeddingBlock``.
        num_layers: number of ``GVPConvLayer`` layers.
        emb_dim: scalar (and vector) channel count.
        in_dim: number of atom types (rows of the input embedding table).
        out_dim: output size of the predictor MLP.
        aggr: accepted for interface parity with the other models; not used here.
        pool: graph-level pooling, ``"sum"`` or ``"mean"``.
        residual: forwarded to every ``GVPConvLayer``.
    """

    def __init__(
        self,
        r_max=10.0,
        num_bessel=8,
        num_polynomial_cutoff=5,
        num_layers=5,
        emb_dim=64,
        in_dim=1,
        out_dim=1,
        aggr="sum",
        pool="sum",
        residual=True
    ):
        super().__init__()
        node_dims = (emb_dim, emb_dim)  # (scalar, vector) channels per node
        edge_dims = (emb_dim, 1)  # (scalar, vector) channels per edge
        scalar_only = (emb_dim, 0)
        conv_acts = (F.relu, None)

        self.r_max = r_max
        self.emb_dim = emb_dim
        self.num_layers = num_layers

        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
        )
        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)

        # Raw edge input: radial-embedding scalars + one direction vector.
        raw_edge_dims = (self.radial_embedding.out_dim, 1)
        self.W_e = torch.nn.Sequential(
            gvp.LayerNorm(raw_edge_dims),
            gvp.GVP(raw_edge_dims, edge_dims, activations=(None, None), vector_gate=True),
        )
        # Lift scalar-only atom embeddings to (scalar, vector) node features.
        self.W_v = torch.nn.Sequential(
            gvp.LayerNorm(scalar_only),
            gvp.GVP(scalar_only, node_dims, activations=(None, None), vector_gate=True),
        )

        self.layers = torch.nn.ModuleList(
            [
                gvp.GVPConvLayer(
                    node_dims,
                    edge_dims,
                    activations=conv_acts,
                    vector_gate=True,
                    residual=residual,
                )
                for _ in range(num_layers)
            ]
        )

        # Project back to scalar-only node features for the readout.
        self.W_out = torch.nn.Sequential(
            gvp.LayerNorm(node_dims),
            gvp.GVP(node_dims, scalar_only, activations=conv_acts, vector_gate=True),
        )

        pool_fns = {"mean": global_mean_pool, "sum": global_add_pool}
        self.pool = pool_fns[pool]

        self.pred = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, out_dim)
        )

    def forward(self, batch):
        """Return graph-level predictions of shape (batch_size, out_dim).

        ``batch`` must expose ``atoms``, ``pos``, ``edge_index`` and ``batch``.
        """
        src, dst = batch.edge_index
        edge_vec = batch.pos[src] - batch.pos[dst]  # [n_edges, 3]
        edge_len = torch.linalg.norm(edge_vec, dim=-1, keepdim=True)  # [n_edges, 1]

        h_V = self.W_v(self.emb_in(batch.atoms))

        # Unit directions as one vector channel: [n_edges, 1, 3].
        # A zero-length edge gives 0/0 = NaN, which nan_to_num maps to 0.
        edge_dir = torch.nan_to_num(edge_vec / edge_len).unsqueeze(-2)
        h_E = self.W_e((self.radial_embedding(edge_len), edge_dir))

        for layer in self.layers:
            h_V = layer(h_V, batch.edge_index, h_E)

        node_scalars = self.W_out(h_V)  # scalar node features consumed by self.pred
        pooled = self.pool(node_scalars, batch.batch)  # (n, d) -> (batch_size, d)
        return self.pred(pooled)  # (batch_size, out_dim)


class EGNNModel(torch.nn.Module):
    def __init__(
        self,
        num_layers=5,
        emb_dim=128,
        in_dim=1,
        out_dim=1,
        activation="relu",
        norm="layer",
        aggr="sum",
        pool="sum",
        residual=True
    ):
        """E(n) Equivariant GNN from Satorras et al.

        This model uses both node features and coordinates as inputs, and
        is invariant to 3D rotations, reflections and translations (constituent
        GNN layers are equivariant to 3D rotations, reflections and translations).

        Args:
            num_layers: (int) - number of message passing layers `L`
            emb_dim: (int) - hidden dimension `d`
            in_dim: (int) - number of atom types (rows of the input embedding table)
            out_dim: (int) - output dimension of the predictor MLP
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - neighbourhood aggregation function (sum/mean/max)
            pool: (str) - global pooling function (sum/mean)
            residual: (bool) - add each layer's feature update to its input
                instead of replacing the features
        """
        super().__init__()

        # Embedding lookup for initial node features
        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)

        # Stack of GNN layers
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(EGNNLayer(emb_dim, activation, norm, aggr))

        # Global pooling/readout function
        self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]

        # Predictor MLP
        self.pred = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, out_dim)
        )
        self.residual = residual

    def forward(self, batch):
        """Return graph-level predictions of shape (batch_size, out_dim).

        ``batch`` must expose ``atoms``, ``pos``, ``edge_index`` and ``batch``.
        """
        h = self.emb_in(batch.atoms)  # (n,) -> (n, d)
        pos = batch.pos  # (n, 3)

        for conv in self.convs:
            # Message passing layer
            h_update, pos_update = conv(h, pos, batch.edge_index)

            # Update node features (n, d) -> (n, d)
            h = h + h_update if self.residual else h_update

            # Update node coordinates (no residual) (n, 3) -> (n, 3)
            pos = pos_update

        out = self.pool(h, batch.batch)  # (n, d) -> (batch_size, d)
        return self.pred(out)  # (batch_size, out_dim)


class MPNNModel(torch.nn.Module):
    def __init__(
        self,
        num_layers=5,
        emb_dim=128,
        in_dim=1,
        out_dim=1,
        activation="relu",
        norm="layer",
        aggr="sum",
        pool="sum",
        residual=True
    ):
        """Vanilla Message Passing GNN model (uses graph connectivity only, not positions).

        Args:
            num_layers: (int) - number of message passing layers `L`
            emb_dim: (int) - hidden dimension `d`
            in_dim: (int) - number of atom types (rows of the input embedding table)
            out_dim: (int) - output dimension of the predictor MLP
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - neighbourhood aggregation function (sum/mean/max)
            pool: (str) - global pooling function (sum/mean)
            residual: (bool) - add each layer's output to its input
                instead of replacing the features
        """
        super().__init__()

        # Embedding lookup for initial node features
        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)

        # Stack of GNN layers
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(MPNNLayer(emb_dim, activation, norm, aggr))

        # Global pooling/readout function
        self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]

        # Predictor MLP
        self.pred = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, out_dim)
        )
        self.residual = residual

    def forward(self, batch):
        """Return graph-level predictions of shape (batch_size, out_dim).

        ``batch`` must expose ``atoms``, ``edge_index`` and ``batch``.
        """
        h = self.emb_in(batch.atoms)  # (n,) -> (n, d)

        for conv in self.convs:
            # Message passing layer, optionally with a residual connection
            h_update = conv(h, batch.edge_index)
            h = h + h_update if self.residual else h_update

        out = self.pool(h, batch.batch)  # (n, d) -> (batch_size, d)
        return self.pred(out)  # (batch_size, out_dim)


class SchNetModel(SchNet):
    """PyG ``SchNet`` adapted to this project's batch format.

    ``forward`` reads ``batch.atoms``, ``batch.pos``, ``batch.edge_index`` and
    ``batch.batch`` and uses the precomputed ``edge_index`` directly. It does
    not apply the parent's ``dipole``, ``mean``, ``std`` or ``atomref``
    post-processing.

    NOTE(review): ``in_dim`` is accepted for interface parity but is not used.
    The parent's atom embedding table is kept, so ``batch.atoms`` must index
    into it.
    """

    def __init__(
        self, 
        hidden_channels: int = 128, 
        in_dim: int = 1,
        out_dim: int = 1, 
        num_filters: int = 128, 
        num_layers: int = 6,
        num_gaussians: int = 50, 
        cutoff: float = 10, 
        max_num_neighbors: int = 32, 
        readout: str = 'add', 
        dipole: bool = False,
        mean: Optional[float] = None, 
        std: Optional[float] = None, 
        atomref: Optional[torch.Tensor] = None,
    ):
        # Pass by keyword: newer PyG releases add an ``interaction_graph``
        # parameter after ``cutoff``, which would silently shift positional args.
        super().__init__(
            hidden_channels=hidden_channels,
            num_filters=num_filters,
            num_interactions=num_layers,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
            max_num_neighbors=max_num_neighbors,
            readout=readout,
            dipole=dipole,
            mean=mean,
            std=std,
            atomref=atomref,
        )

        # torch_scatter reduction used in forward(). Stored as a plain string
        # because newer PyG keeps ``self.readout`` as an Aggregation module.
        # Mirrors the parent, which forces a sum readout for dipole models.
        self._readout_reduce = 'add' if dipole else readout

        # Overwrite the final predictor so it emits ``out_dim`` values per node
        self.lin2 = torch.nn.Linear(hidden_channels // 2, out_dim)

    def forward(self, batch):
        """Return graph-level predictions of shape (batch_size, out_dim)."""
        h = self.embedding(batch.atoms)

        row, col = batch.edge_index
        edge_weight = (batch.pos[row] - batch.pos[col]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        for interaction in self.interactions:
            h = h + interaction(h, batch.edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        out = scatter(h, batch.batch, dim=0, reduce=self._readout_reduce)
        return out


class DimeNetPPModel(DimeNetPlusPlus):
    """PyG ``DimeNetPlusPlus`` adapted to this project's batch format.

    ``forward`` reads ``batch.atoms``, ``batch.pos``, ``batch.edge_index`` and
    (optionally) ``batch.batch``, and uses the precomputed ``edge_index``
    rather than building a radius graph.

    NOTE(review): ``in_dim`` is accepted for interface parity but is not used;
    the parent's embedding block is kept.
    """

    def __init__(
        self, 
        hidden_channels: int = 128, 
        in_dim: int = 1,
        out_dim: int = 1, 
        num_layers: int = 4, 
        int_emb_size: int = 64, 
        basis_emb_size: int = 8, 
        out_emb_channels: int = 256, 
        num_spherical: int = 7, 
        num_radial: int = 6, 
        cutoff: float = 10, 
        max_num_neighbors: int = 32, 
        envelope_exponent: int = 5, 
        num_before_skip: int = 1, 
        num_after_skip: int = 2, 
        num_output_layers: int = 3, 
        act: Union[str, Callable] = 'swish'
    ):
        super().__init__(hidden_channels, out_dim, num_layers, int_emb_size, basis_emb_size, out_emb_channels, num_spherical, num_radial, cutoff, max_num_neighbors, envelope_exponent, num_before_skip, num_after_skip, num_output_layers, act)

    def forward(self, batch):
        """Return graph-level predictions of shape (batch_size, out_dim).

        If ``batch`` carries no ``batch`` vector, all nodes are treated as
        one graph and the per-node outputs are summed.
        """
        num_nodes = batch.pos.size(0)

        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
            batch.edge_index, num_nodes=batch.atoms.size(0))

        # Calculate distances.
        dist = (batch.pos[i] - batch.pos[j]).pow(2).sum(dim=-1).sqrt()

        # Calculate angles.
        pos_i = batch.pos[idx_i]
        pos_ji, pos_ki = batch.pos[idx_j] - pos_i, batch.pos[idx_k] - pos_i
        a = (pos_ji * pos_ki).sum(dim=-1)
        # Explicit dim: without it torch.cross uses the first size-3 dimension,
        # which is the triplet axis whenever there are exactly 3 triplets.
        b = torch.cross(pos_ji, pos_ki, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        # Embedding block.
        x = self.emb(batch.atoms, rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=num_nodes)

        # Interaction blocks. Every output block gets num_nodes so the per-node
        # outputs have identical shapes even when the highest-index nodes have
        # no incoming edges.
        for interaction_block, output_block in zip(self.interaction_blocks,
                                                   self.output_blocks[1:]):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            P = P + output_block(x, rbf, i, num_nodes=num_nodes)

        # `batch` itself is never None here; the graph-assignment vector may be.
        graph_index = getattr(batch, "batch", None)
        if graph_index is None:
            return P.sum(dim=0)
        return scatter(P, graph_index, dim=0)


class OriginalMACEModel(torch.nn.Module):
    """Reference MACE architecture: interaction + product-basis blocks with per-layer readouts.

    The graph-level output is the sum of an embedding-based ``e0`` term and
    one readout contribution per interaction, each summed over the nodes of
    a graph. Returns a tensor of shape ``[n_graphs, irreps_out.dim]``.

    NOTE(review): ``num_layers``, ``in_dim`` and ``out_dim`` are accepted but
    unused. Depth, number of atom types and output irreps are controlled by
    ``num_interactions``, ``num_elements`` and ``irreps_out``.
    """

    def __init__(
        self,
        r_max: float = 10.0,
        num_bessel: int = 8,
        num_polynomial_cutoff: int = 5,
        max_ell: int = 2,
        interaction_cls: Type[InteractionBlock] = interaction_classes["RealAgnosticResidualInteractionBlock"],
        interaction_cls_first: Type[InteractionBlock] = interaction_classes["RealAgnosticInteractionBlock"],
        num_interactions: int = 2,
        num_elements: int = 1,
        hidden_irreps: o3.Irreps = o3.Irreps("64x0e + 64x1o + 64x2e"),
        MLP_irreps: o3.Irreps = o3.Irreps("64x0e"),
        irreps_out: o3.Irreps = o3.Irreps("1x0e"),
        avg_num_neighbors: int = 1,
        correlation: int = 3,
        gate: Optional[Callable] = gate_dict["silu"],
        num_layers=2,
        in_dim=1,
        out_dim=1,
    ):
        super().__init__()
        self.r_max = r_max
        self.num_elements = num_elements
        # Embedding
        node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
        node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
        self.node_embedding = LinearNodeEmbeddingBlock(
            irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
        )
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
        )
        edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")

        sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
        num_features = hidden_irreps.count(o3.Irrep(0, 1))
        interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()
        self.spherical_harmonics = o3.SphericalHarmonics(
            sh_irreps, normalize=True, normalization="component"
        )

        # Interactions and readout
        self.atomic_energies_fn = LinearReadoutBlock(node_feats_irreps, irreps_out)

        inter = interaction_cls_first(
            node_attrs_irreps=node_attr_irreps,
            node_feats_irreps=node_feats_irreps,
            edge_attrs_irreps=sh_irreps,
            edge_feats_irreps=edge_feats_irreps,
            target_irreps=interaction_irreps,
            hidden_irreps=hidden_irreps,
            avg_num_neighbors=avg_num_neighbors,
        )
        self.interactions = torch.nn.ModuleList([inter])

        # Use the appropriate self connection at the first layer for proper E0
        use_sc_first = False
        if "Residual" in str(interaction_cls_first):
            use_sc_first = True

        node_feats_irreps_out = inter.target_irreps
        prod = EquivariantProductBasisBlock(
            node_feats_irreps=node_feats_irreps_out,
            target_irreps=hidden_irreps,
            correlation=correlation,
            element_dependent=True,
            num_elements=num_elements,
            use_sc=use_sc_first,
        )
        self.products = torch.nn.ModuleList([prod])

        self.readouts = torch.nn.ModuleList()
        self.readouts.append(LinearReadoutBlock(hidden_irreps, irreps_out))

        for i in range(num_interactions - 1):
            if i == num_interactions - 2:
                hidden_irreps_out = str(
                    hidden_irreps[0]
                )  # Select only scalars for last layer
            else:
                hidden_irreps_out = hidden_irreps
            inter = interaction_cls(
                node_attrs_irreps=node_attr_irreps,
                node_feats_irreps=hidden_irreps,
                edge_attrs_irreps=sh_irreps,
                edge_feats_irreps=edge_feats_irreps,
                target_irreps=interaction_irreps,
                hidden_irreps=hidden_irreps_out,
                avg_num_neighbors=avg_num_neighbors,
            )
            self.interactions.append(inter)
            prod = EquivariantProductBasisBlock(
                node_feats_irreps=interaction_irreps,
                target_irreps=hidden_irreps_out,
                correlation=correlation,
                element_dependent=True,
                num_elements=num_elements,
                use_sc=True
            )
            self.products.append(prod)
            if i == num_interactions - 2:
                self.readouts.append(
                    NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate, irreps_out)
                )
            else:
                self.readouts.append(LinearReadoutBlock(hidden_irreps, irreps_out))

    def forward(self, batch):
        """Return graph-level outputs of shape [n_graphs, irreps_out.dim].

        ``batch.atoms`` holds integer atom types in ``[0, num_elements)``; it is
        read but never modified.
        """
        # MACE expects one-hot node attributes. Use an out-of-place unsqueeze:
        # an in-place ``unsqueeze_`` would mutate the caller's batch and add one
        # more dimension every time the same batch is passed through the model.
        atom_index = batch.atoms.unsqueeze(-1)  # (n, 1)
        node_attrs = torch.zeros(
            atom_index.shape[:-1] + (self.num_elements,), device=atom_index.device
        )
        node_attrs.scatter_(dim=-1, index=atom_index, value=1)

        # Node embeddings
        node_feats = self.node_embedding(node_attrs)
        node_e0 = self.atomic_energies_fn(node_feats)
        e0 = scatter(node_e0, batch.batch, dim=0, reduce="sum")  # [n_graphs, irreps_out]

        # Edge features
        vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]]  # [n_edges, 3]
        lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True)  # [n_edges, 1]
        edge_attrs = self.spherical_harmonics(vectors)
        edge_feats = self.radial_embedding(lengths)

        # Interactions
        energies = [e0]
        for interaction, product, readout in zip(
            self.interactions, self.products, self.readouts
        ):
            node_feats, sc = interaction(
                node_attrs=node_attrs,
                node_feats=node_feats,
                edge_attrs=edge_attrs,
                edge_feats=edge_feats,
                edge_index=batch.edge_index,
            )
            node_feats = product(
                node_feats=node_feats, sc=sc, node_attrs=node_attrs
            )
            # Keep the trailing irreps_out dimension so every entry of
            # `energies` has e0's shape. Squeezing it made torch.stack fail for
            # a 1-dimensional irreps_out.
            node_energies = readout(node_feats)  # [n_nodes, irreps_out]
            energy = scatter(node_energies, batch.batch, dim=0, reduce="sum")  # [n_graphs, irreps_out]
            energies.append(energy)

        # Sum over energy contributions
        contributions = torch.stack(energies, dim=-1)  # [n_graphs, irreps_out, n_terms]
        total_energy = torch.sum(contributions, dim=-1)  # [n_graphs, irreps_out]

        return total_energy
