#
# a very basic GCN with few bells and whistles.
#

import torch
from torch import nn


from coarsebind_public.mol_encoder.models.encoder_graph.graph_utils import (
    get_moment_encoding,
    triu_to_sparse,
)
from coarsebind_public.mol_encoder.models.encoder_3d.attentive_allegro import SwiGLUNet
from coarsebind_public.mol_encoder.models.loose_modules.stoich_encoding import (
    StoichiometryEncoder,
)
from coarsebind_public.mol_encoder.models.transformer.attention import (
    MaskedAttentionBlock,
    AttAgg,
)


def gather_message(T, Is, Js, batch_size=0, max_n_atoms=0, vp=True):
    """
    A variance preserving message gather.

    Every row of ``T`` is scatter-added into the node slot addressed by
    ``(Is[e], Js[e])``. With ``vp=True`` each node's summed message is then
    divided by ``sqrt(count + 1e-6)``, where ``count`` is the number of rows
    that landed on that node, so the magnitude of the result does not grow
    linearly with the node's degree.

    Args:
        T: 2-D tensor of messages, one row per edge, shape ``(n_edges, dim)``.
        Is: 1-D integer tensor, batch index of each edge.
        Js: 1-D integer tensor, node index (within its batch element) that
            receives each edge's message.
        batch_size: number of batch elements in the output.
        max_n_atoms: number of node slots per batch element in the output.
        vp: if True apply the ``1 / sqrt(count)`` normalisation, otherwise
            return the plain sum.

    Returns:
        Tensor of shape ``(batch_size, max_n_atoms, dim)``. Nodes that receive
        no message are zero.
    """
    n_slots = batch_size * max_n_atoms
    dim = T.shape[-1]
    # Flatten (batch, node) into a single row index; computed once and shared
    # by the sum and the count below.
    flat_index = Is * max_n_atoms + Js

    summed = torch.zeros((n_slots, dim), device=T.device, dtype=T.dtype)
    summed.index_add_(dim=0, index=flat_index, source=T)

    if vp:
        # The message count is the same for every feature of a node, so keep
        # it 1-D and broadcast instead of materialising an (n_slots, dim)
        # buffer of ones.
        counts = torch.zeros(n_slots, device=T.device, dtype=T.dtype)
        counts.index_add_(
            dim=0,
            index=flat_index,
            source=torch.ones(T.shape[0], device=T.device, dtype=T.dtype),
        )
        # The epsilon keeps the division finite for nodes with zero messages
        # (their numerator is zero as well, so they stay zero).
        summed = summed / (1e-6 + counts).sqrt().unsqueeze(-1)

    return summed.reshape(batch_size, max_n_atoms, dim)

class AttentiveGCNLayer(torch.nn.Module):
    """
    One block of the attentive GCN.

    Edge features are projected to node width and gathered onto nodes, the
    nodes are refined by a masked attention block, and (unless this is the
    last layer) the global feature and the edge features are refreshed from
    the updated nodes.
    """

    def __init__(
        self,
        dim_node=256,
        dim_edge=256,
        dim_global=512,
        last_layer=False,
        max_nodes=80,
    ):
        """
        Args:
            dim_node: width of the per-node features.
            dim_edge: width of the per-edge features.
            dim_global: width of the global feature.
            last_layer: when True the global/edge update modules are not
                created and ``forward`` only updates the nodes.
            max_nodes: accepted for signature compatibility; not used here.
        """
        super().__init__()
        self.last_layer = last_layer

        self.dim_node = dim_node
        self.dim_edge = dim_edge
        self.dim_global = dim_global

        # Submodules are created in a fixed order so parameter
        # initialisation stays reproducible under a given seed.
        self.node_attention_block = MaskedAttentionBlock(dim_node, dim_node // 8)
        self.X_to_dY = SwiGLUNet(dim_edge, dim_node)
        if last_layer:
            # The final layer never touches edges or the global feature.
            return

        # Edge update consumes both endpoint nodes plus the global feature.
        self.dim_two_body = 2 * dim_node + dim_global
        self.two_body = SwiGLUNet(self.dim_two_body, dim_edge)
        self.global_agg = AttAgg(
            dim_node,
            n_head=16,
            n_layers=2,
            n_out_tokens=dim_global // dim_node,
        )

    def forward(self, Y, X, G, Is, Js, Ks, node_mask):
        """
        Args:
            Y: node features, batch X max_n_nodes X dim_node.
            X: edge features, one row per edge (n_edges X dim_edge); they are
                scatter-added with a flat index built from ``Is`` and ``Js``.
            G: global features, indexed by batch element (``G[Is]``).
            Is: per-edge batch index.
            Js: per-edge index of the first endpoint node.
            Ks: per-edge index of the second endpoint node.
            node_mask: batch X max_n_nodes, forwarded to the attention and
                aggregation modules.

        Returns:
            ``(Y, X, G)`` with updated tensors, or ``(Y, None, None)`` when
            this is the last layer.
        """
        n_batch, n_nodes = Y.shape[0], Y.shape[1]

        # Edge -> node messages. They are accumulated on node ``Js`` only.
        # NOTE(review): ``Ks`` does not receive a message here - presumably
        # intentional given how the edge list is built; confirm.
        edge_messages = self.X_to_dY(X)
        node_update = gather_message(
            edge_messages,
            Is,
            Js,
            batch_size=n_batch,
            max_n_atoms=n_nodes,
            vp=True,
        )

        # Residual attention update over the message-augmented nodes.
        Y = Y + self.node_attention_block(Y + node_update, node_mask)
        if self.last_layer:
            return Y, None, None

        # Residual updates of the global feature and then of the edges.
        G = G + self.global_agg(Y, node_mask)
        two_body_input = torch.cat((Y[Is, Js], Y[Is, Ks], G[Is]), dim=-1)
        X = X + self.two_body(two_body_input)
        return Y, X, G


class AttentiveGCN(torch.nn.Module):
    """
    A stack of ``AttentiveGCNLayer`` blocks over embedded node/edge types,
    optionally followed by an attention aggregation over the nodes.
    """

    def __init__(
        self,
        n_layers=6,
        dim_node=256,
        dim_edge=256,
        dim_global=512,
        max_nodes=80,
        max_node_types=50,
        max_edge_types=50,
        n_out_tokens=1,
        do_lap_pe=False,
    ):
        """
        Args:
            n_layers: number of ``AttentiveGCNLayer`` blocks; the final one is
                built with ``last_layer=True``.
            dim_node: width of the per-node features.
            dim_edge: width of the per-edge features.
            dim_global: width of the global feature.
            max_nodes: stored and forwarded to every layer.
            max_node_types: vocabulary size of the node embedding (index 0 is
                the padding index).
            max_edge_types: vocabulary size of the edge embedding (index 0 is
                the padding index).
            n_out_tokens: multiplier on the number of tokens produced by the
                final aggregation
                (``dim_global * n_out_tokens // dim_node``).
            do_lap_pe: when True, 16 positional features from
                ``get_moment_encoding`` are concatenated to the node
                embeddings before the initial node network.
        """
        super().__init__()
        self.n_layers = n_layers
        self.dim_node = dim_node
        self.dim_edge = dim_edge
        self.dim_global = dim_global
        self.max_nodes = max_nodes
        self.max_node_types = max_node_types
        self.max_edge_types = max_edge_types
        self.n_out_tokens = n_out_tokens

        self.stoich_embed = StoichiometryEncoder(dim=dim_global)
        self.node_embed = nn.Embedding(self.max_node_types, self.dim_node, padding_idx=0)
        self.edge_embed = nn.Embedding(self.max_edge_types, self.dim_edge, padding_idx=0)
        # also do a laplacian position embed
        # after I translate miles' out of sparse
        if do_lap_pe:
            self.n_pe_dim = 16
        else:
            self.n_pe_dim = 0
        self.get_y0 = SwiGLUNet(
            self.dim_node + self.n_pe_dim, self.dim_node
        )  # does allegro_xy_oh + the graph types.
        # Initial edge features: both endpoint nodes, the edge-type embedding
        # and the global feature.
        self.dim_two_body = 2 * self.dim_node + self.dim_edge + self.dim_global
        self.two_body = SwiGLUNet(self.dim_two_body, self.dim_edge)
        self.layers = torch.nn.ModuleList(
            [
                AttentiveGCNLayer(
                    dim_node=dim_node,
                    dim_edge=dim_edge,
                    dim_global=dim_global,
                    max_nodes=max_nodes,
                    last_layer=layer_idx == n_layers - 1,
                )
                for layer_idx in range(n_layers)
            ]
        )
        self.global_agg = AttAgg(
            self.dim_node,
            n_head=16,
            n_layers=2,
            n_out_tokens=self.dim_global * n_out_tokens // self.dim_node,
        )

    def forward(self, atoms, nodes, edges_, graph_tokenizer, apply_global_agg=True):
        """
        Args:
            atoms: per-node values; positions with ``atoms > 0`` are treated
                as real nodes, and the tensor is also fed to the
                stoichiometry encoder.
            nodes: batch X max_n_atoms node-type indices # ZERO IS PAD
            edges_: batch X max_edges (dense coding.), converted to an edge
                list by ``triu_to_sparse``.
            graph_tokenizer: not used by this method; kept for interface
                compatibility.
            apply_global_agg: when True return ``global_agg(Y, node_mask)``,
                otherwise return the per-node features ``Y``.
        """
        node_mask = atoms > 0
        node_embs = self.node_embed(nodes)

        # Convert the dense coded edges into Is,Js,Ks, and edges
        Is, Js, Ks, edges, adj = triu_to_sparse(edges_)

        edge_embs = self.edge_embed(edges)
        if self.n_pe_dim:
            # Request exactly as many moments as get_y0 was sized for, so the
            # two cannot drift apart.
            Y_pe = get_moment_encoding(adj, n_moments=self.n_pe_dim)
            Y = self.get_y0(torch.cat([node_embs, Y_pe], -1)) * (node_mask.unsqueeze(-1))
        else:
            # NOTE(review): unlike the branch above, the initial node features
            # are not multiplied by node_mask here - confirm this is intended.
            Y = self.get_y0(node_embs)
        G = self.stoich_embed(atoms)
        X = self.two_body(torch.cat([Y[Is, Js], Y[Is, Ks], edge_embs, G[Is]], -1))

        # The last layer returns None for X and G; they are not used after
        # the loop.
        for layer in self.layers:
            Y, X, G = layer(Y, X, G, Is, Js, Ks, node_mask)

        if apply_global_agg:
            return self.global_agg(Y, node_mask)
        else:
            return Y
