import torch
from torch import nn

from coarsebind_public.mol_encoder.models.transformer.attention import AttentionBlock


class StoichiometryEncoder(nn.Module):
    """Encode how many times each element occurs in a molecule.

    Repeated atoms are encoded explicitly: every element present in a row
    becomes one token made of ``[count embedding, element embedding]``.
    The tokens are run through a stack of attention blocks and pooled into
    a single vector per row, so the model does not have to recover counts
    from a weak sum over per-atom embeddings.

    Args:
        dim: Width of each token and of the returned vector. Must be even,
            because the two embeddings each fill ``dim // 2`` channels.
        layers: Number of ``AttentionBlock`` modules applied to the tokens.
        MAX_ATOM_COUNT: Largest per-element count with an embedding row;
            a larger count raises an index error in the embedding lookup.
        MAX_ATOMIC_NUMBER: Exclusive upper bound on the atom values that are
            counted in ``forward``.
    """

    def __init__(self, dim=256, layers=4, MAX_ATOM_COUNT=3000, MAX_ATOMIC_NUMBER=83):
        super().__init__()
        assert dim % 2 == 0, "dim must be even (half count, half element embedding)"
        self.dim = dim
        self.MAX_ATOMIC_NUMBER = MAX_ATOMIC_NUMBER
        self.atom_count_emb = nn.Embedding(MAX_ATOM_COUNT + 1, dim // 2, padding_idx=0)
        self.atom_emb = nn.Embedding(MAX_ATOMIC_NUMBER + 1, dim // 2, padding_idx=0)
        # 16 is AttentionBlock's second positional argument (presumably the
        # number of heads -- confirm against AttentionBlock).
        self.blocks = torch.nn.ModuleList([AttentionBlock(dim, 16) for _ in range(layers)])

    def forward(self, atoms):
        """Pool the per-element composition tokens of each row.

        Args:
            atoms: Integer tensor of shape ``(batch, max_n_atom)``. Entries
                equal to 0 and entries ``>= MAX_ATOMIC_NUMBER`` are ignored.

        Returns:
            Tensor of shape ``(batch, dim)``: the sum of the attended tokens
            divided by ``sqrt(number of distinct elements + 1)``.
        """
        batch = atoms.shape[0]

        # Per-row histogram of atom values. Column 0 (padding) is sliced off,
        # so column j of `counts` holds the count of atom value j + 1.
        counts = torch.stack(
            [
                torch.bincount(
                    row[row < self.MAX_ATOMIC_NUMBER],
                    minlength=self.MAX_ATOMIC_NUMBER + 1,
                )
                for row in atoms
            ],
            0,
        )[:, 1 : self.MAX_ATOMIC_NUMBER]

        # (row, column) coordinates of every element that is actually present.
        Is, Js = torch.nonzero(counts, as_tuple=True)
        denominator = torch.bincount(Is, minlength=batch)

        # NOTE(review): the element embedding is looked up with the column
        # index, i.e. atom value - 1. If `atoms` holds atomic numbers, value 1
        # therefore hits row 0, which is the zero padding row of `atom_emb`,
        # and values equal to MAX_ATOMIC_NUMBER are never counted. This is
        # kept as-is because shifting the index would silently remap the rows
        # of already-trained weights; confirm whether it is intended.
        tokens = torch.cat(
            [self.atom_count_emb(counts[Is, Js]), self.atom_emb(Js)], -1
        )

        # Sequence length is the highest occupied column in the batch. Read it
        # once as a Python int (single host sync) and fall back to one
        # all-zero token when no row contains a countable atom, where
        # `Js.max()` would raise on an empty tensor.
        width = int(Js.max()) + 1 if Js.numel() > 0 else 1

        # Scatter tokens into a dense (batch * width, dim) buffer; absent
        # elements stay zero. The buffer follows the embedding dtype because
        # index_add_ requires self and source to share a dtype (the module
        # may have been cast with .double() / .half()).
        to_attend = torch.zeros(
            (batch * width, self.dim), device=atoms.device, dtype=tokens.dtype
        )
        to_attend.index_add_(dim=0, index=Is * width + Js, source=tokens)
        to_attend = to_attend.reshape(batch, width, self.dim)

        for block in self.blocks:
            to_attend = block(to_attend)

        pooled = to_attend.sum(1)
        # Explicit float cast instead of relying on sqrt's implicit integer
        # promotion; promote_types keeps the float32-or-wider output dtype
        # that `pooled / sqrt(int64 tensor)` produced before.
        scale_dtype = torch.promote_types(pooled.dtype, torch.float32)
        scale = torch.sqrt((denominator + 1).to(scale_dtype))
        return pooled / scale.unsqueeze(-1)
