import torch
import torch.nn as nn
from einops import rearrange, repeat

from src.modules.act import GEGLU
from src.modules.torch_modules import Attention, FeedForward, PreNorm, cache_fn


class PerceiverEncoder(nn.Module):
    """Perceiver-style encoder that compresses point features into a fixed latent array.

    Per-point features (``occ`` concatenated with ``field``) are aggregated onto
    supernodes by ``supernode_pooling``. A learned latent array then cross-attends
    to the supernode tokens. ``depth`` rounds of latent self-attention plus
    feed-forward follow. Every sub-layer is wrapped in ``PreNorm`` and combined
    with a residual connection.

    Args:
        dim: Feature width of the pooled supernode tokens; used as the context
            dimension of the cross-attention.
        num_latents: Number of learned latent vectors.
        latent_dim: Width of each latent vector.
        supernode_pooling: Module that maps per-point inputs to supernode tokens.
            It is called with the keyword arguments ``x``, ``pos``,
            ``batch_index``, ``supernode_index`` and ``super_node_batch_index``.
            It must return a 2-D tensor whose rows are grouped sample by sample.
        cross_heads: Number of heads in the cross-attention.
        cross_dim_head: Per-head width of the cross-attention.
        latent_heads: Number of heads in each latent self-attention.
        latent_dim_head: Per-head width of each latent self-attention.
        weight_tie_layers: Forwarded to ``cache_fn`` as ``_cache``. Presumably,
            when True all ``depth`` latent layers share one attention module and
            one feed-forward module (confirm against ``cache_fn``).
        depth: Number of latent self-attention blocks.
        act: Activation class forwarded to every ``FeedForward``.
    """

    def __init__(
        self,
        dim: int,
        num_latents: int,
        latent_dim: int,
        supernode_pooling: nn.Module,
        cross_heads: int = 1,
        cross_dim_head: int = 64,
        latent_heads: int = 4,
        latent_dim_head: int = 64,
        weight_tie_layers: bool = False,
        depth: int = 4,
        act: "type[nn.Module]" = GEGLU,
    ):
        super().__init__()
        self.supernode_pooling = supernode_pooling

        # Learned latent array; expanded across the batch in every forward call.
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        # One cross-attention block: latents are the queries, supernode tokens the context.
        self.cross_attend_blocks = nn.ModuleList(
            [
                PreNorm(
                    latent_dim,
                    Attention(latent_dim, dim, heads=cross_heads, dim_head=cross_dim_head),
                    context_dim=dim,
                ),
                PreNorm(latent_dim, FeedForward(latent_dim, act=act)),
            ]
        )

        # Factories wrapped by cache_fn; the `_cache` keyword passed below controls
        # whether a cached instance is reused (weight tying) or a fresh one is built.
        @cache_fn
        def make_latent_attn():
            return PreNorm(
                latent_dim,
                Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head),
            )

        @cache_fn
        def make_latent_ff():
            return PreNorm(latent_dim, FeedForward(latent_dim, act=act))

        # Build attention before feed-forward in each layer, matching the original
        # construction (and hence parameter-initialisation) order.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        make_latent_attn(_cache=weight_tie_layers),
                        make_latent_ff(_cache=weight_tie_layers),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self,
        pos,
        occ,
        field,
        batch_index,
        supernode_index,
        supernode_batch_index,
        mask=None,
    ):
        """Encode a batch of point sets into latent vectors.

        Args:
            pos: Point positions, forwarded unchanged to ``supernode_pooling``.
            occ: Per-point features; concatenated with ``field`` along the last axis.
            field: Per-point field values. A new axis is inserted at dim 1 before
                concatenation, so this is presumably shape ``(num_points,)``
                (confirm against callers).
            batch_index: Sample index of each point. The batch size is taken as
                ``batch_index.max() + 1``, so indices are assumed to be 0-based
                and the tensor must be non-empty.
            supernode_index: Forwarded to ``supernode_pooling``.
            supernode_batch_index: Forwarded to ``supernode_pooling`` as
                ``super_node_batch_index``.
            mask: Optional mask forwarded to the cross-attention.

        Returns:
            Latent tensor with the same shape as the expanded queries,
            ``(batch_size, num_latents, latent_dim)``.
        """
        # `.item()` copies the value to the host (a device sync for CUDA tensors).
        # A Python int is required by `repeat` and `rearrange`.
        batch_size = batch_index.max().item() + 1
        queries = repeat(self.latents, "n d -> b n d", b=batch_size)

        point_features = torch.cat([occ, field.unsqueeze(1)], dim=-1)
        supernode_tokens = self.supernode_pooling(
            x=point_features,
            pos=pos,
            batch_index=batch_index,
            supernode_index=supernode_index,
            super_node_batch_index=supernode_batch_index,
        )
        # The pooled output is flat; split it into per-sample groups. This assumes
        # rows are ordered sample by sample, with the same supernode count per sample.
        supernode_tokens = rearrange(
            supernode_tokens,
            "(batch_size num_supernodes) dim -> batch_size num_supernodes dim",
            batch_size=batch_size,
        )

        cross_attn, cross_ff = self.cross_attend_blocks
        latents = cross_attn(queries, context=supernode_tokens, mask=mask) + queries
        latents = cross_ff(latents) + latents

        for self_attn, self_ff in self.layers:
            latents = self_attn(latents) + latents
            latents = self_ff(latents) + latents

        return latents
