import torch
from torch import nn
import numpy as np
from e3nn import o3

from coarsebind_public.mol_encoder.util.periodic_table import XY_ONE_HOT_FULL


class IrrepWeightLayer(torch.nn.Module):
    """
    Scale each irrep entry of an equivariant feature by one weight.

    One factor is shared by all components of an irrep entry, i.e.
    ``1x0e + 1x0o + 1x1o`` => 3 weights (not 1 + 1 + 3 = 5).
    For an entry with multiplicity > 1 (e.g. ``2x1o``), its single weight
    scales all ``mul * ir.dim`` components of that entry.

    This is a Hadamard-like product per irrep.
    """

    def __init__(self, irreps_in, dim_equivariant):
        """
        Args:
            irreps_in: irreps of the features being scaled; iterated as
                ``(mul, ir)`` pairs and queried for ``.dim`` and ``len()``.
            dim_equivariant: stored for reference; not used in ``forward``.
        """
        super().__init__()
        self.irreps_in = irreps_in
        self.dim_equivariant = dim_equivariant
        # Width of each (mul, ir) entry in the flattened feature vector. Using
        # ``mul * ir.dim`` (rather than ``ir.dim``) keeps this consistent with
        # ``irreps_in.dim`` when multiplicities are greater than one.
        self.irrep_dims = [mul * ir.dim for mul, ir in irreps_in]
        # _component_to_irrep[c] = index of the irrep entry that flattened
        # component c belongs to. Non-persistent so the state_dict is unchanged.
        self.register_buffer(
            "_component_to_irrep",
            torch.repeat_interleave(
                torch.arange(len(self.irrep_dims)),
                torch.tensor(self.irrep_dims, dtype=torch.long),
            ),
            persistent=False,
        )

    def forward(self, w, Y):
        """
        Args:
            w: edges X dim_equivariant X n_irreps (one weight per irrep entry, NOT dim_irreps)
            Y: edges X dim_irreps (flattened irrep components, NOT num irreps)

        Returns:
            edges X dim_equivariant X dim_irreps tensor with
            ``out[e, j, c] = w[e, j, irrep_of(c)] * Y[e, c]``.
        """
        assert Y.shape[-1] == self.irreps_in.dim
        assert w.shape[-1] == len(self.irreps_in)
        assert sum(self.irrep_dims) == Y.shape[-1]
        # Expand every per-irrep weight onto all of that irrep's components, then
        # broadcast over dim_equivariant. This is the same per-irrep outer product
        # as einsum("ij,il->ijl", ...) + concatenation, done in a single op.
        index = self._component_to_irrep.to(w.device)
        return w[:, :, index] * Y.unsqueeze(1)


class allegro_sph_edges(torch.nn.Module):
    """Spherical-harmonic features of edge vectors, via e3nn's ``o3.SphericalHarmonics``."""

    irreps_edge_sh: str

    def __init__(self, irreps_edge_sh="1x0e+1x1o+1x2e"):
        """
        Args:
            irreps_edge_sh: irreps specification of the spherical-harmonic
                output, e.g. ``"1x0e+1x1o+1x2e"``.
        """
        super().__init__()
        self.irreps_edge_sh = irreps_edge_sh
        # Inputs are normalized before evaluation (normalize=True) and the
        # harmonics use e3nn's "component" normalization convention.
        self.sh = o3.SphericalHarmonics(
            self.irreps_edge_sh,
            normalize=True,
            normalization="component",
        )
        # Total number of components across all requested irreps.
        self.dim = o3.Irreps(self.irreps_edge_sh).dim

    def forward(self, edge_vectors):
        """
        Evaluate the spherical harmonics on a set of edge vectors.

        Args:
            edge_vectors: batch X edges X 3 tensor of edge vectors

        Returns:
            The output of ``self.sh``; presumably its last dimension equals
            ``self.dim`` (TODO confirm against e3nn's output irreps).
        """
        sh_features = self.sh(edge_vectors)
        return sh_features


class allegro_xy_oh(torch.nn.Module):
    """
    Learned linear embedding of per-atom ``XY_ONE_HOT_FULL`` feature vectors.

    The vectors for atomic numbers ``0 .. 63`` are precomputed into the
    ``periodic_lookup`` buffer, so the forward pass is a single gather followed
    by an ``nn.Linear`` projection.
    """

    def __init__(self, out_features=16, dtype=torch.float, device=torch.device("cuda")):
        """
        Args:
            out_features: size of the output embedding.
            dtype: dtype of both the lookup buffer and the linear layer.
            device: device the lookup buffer is created on. The ``nn.Linear``
                layer is created without an explicit device, so move the whole
                module with ``.to(...)`` before calling it.
        """
        super().__init__()
        in_node_nf = len(XY_ONE_HOT_FULL(1))
        self.out_features = out_features
        self.embedding = nn.Linear(in_node_nf, self.out_features, dtype=dtype)

        # Precreate the one-hot vectors for atomic numbers 0 .. max_atomic_num - 1.
        # TODO: XY_ONE_HOT_FULL seems to have an issue at atomic number 89, and it
        # returns repeated vectors for e.g. 0 and 71 (presumably why the table is
        # capped at 64 entries).
        max_atomic_num = 64
        periodic_lookup = [XY_ONE_HOT_FULL(i) for i in range(max_atomic_num)]

        self.register_buffer(
            "periodic_lookup",
            torch.tensor(periodic_lookup, dtype=dtype, device=device, requires_grad=False),
        )

        # Check that no two atomic numbers share a vector. All pairs (i < j) are
        # compared in one vectorized op instead of ~2000 separate comparisons,
        # each of which forced a host sync when the buffer lives on the GPU.
        flat = self.periodic_lookup.reshape(max_atomic_num, -1)
        same = (flat[:, None, :] == flat[None, :, :]).all(dim=-1)
        rows, cols = torch.triu_indices(
            max_atomic_num, max_atomic_num, offset=1, device=same.device
        )
        assert not same[rows, cols].any(), "XY_ONE_HOT_FULL produced duplicate vectors"

    def forward(self, atoms):
        """
        Embed a batch of atomic numbers.

        Args:
            atoms: batch X max_n_atom long tensor of atomic numbers (0 = mask)

        Returns:
            batch X max_n_atom X out_features tensor.
        """
        assert atoms.dim() == 2
        assert atoms.min() >= 0
        # NOTE: atomic numbers beyond the end of the table are clamped to its last
        # row (63) rather than rejected, so all heavier elements share that vector.
        max_index = self.periodic_lookup.shape[0] - 1
        # Equivalent to stacking XY_ONE_HOT_FULL(a) for every a in atoms (a <= max_index).
        periodic_emb = self.periodic_lookup[atoms.clamp(0, max_index)]
        return self.embedding(periodic_emb)
