from functools import partial

import torch
from einops import rearrange
from kappamodules.layers import LinearProjection
from kappamodules.transformer import (
    DitBlock,
    DitPerceiverPoolingBlock,
    PerceiverPoolingBlock,
    PrenormBlock,
)
from torch import nn

from src.modules.act import GEGLU


class UptPoolTransformerPerceiver(nn.Module):
    """Supernode pooling -> transformer encoder -> perceiver pooling.

    ``forward`` runs the following stages:

    1. ``supernode_pooling`` aggregates the input field; its flat output is
       regrouped per sample into ``(batch_size, num_supernodes, dim)``.
    2. ``enc_proj`` maps ``gnn_dim -> enc_dim``, then ``enc_depth``
       transformer blocks are applied.
    3. ``perc_proj`` maps ``enc_dim -> perc_dim``, then a perceiver pooling
       block with ``num_latent_tokens`` query tokens attends over the sequence.
    4. If ``output_ln`` is set, a layer norm without affine parameters
       (``eps=1e-6``) is applied over the last ``perc_dim`` features.

    When ``condition_dim`` is not None, the DiT variants (``DitBlock`` and
    ``DitPerceiverPoolingBlock``) are built with ``cond_dim=condition_dim``.
    Otherwise ``PrenormBlock`` and ``PerceiverPoolingBlock`` are used.

    Args:
        gnn_dim: Feature width of the ``supernode_pooling`` output.
        enc_dim: Width of the transformer encoder blocks.
        perc_dim: Width of the perceiver pooling block and of the output.
        enc_depth: Number of transformer encoder blocks.
        enc_num_attn_heads: Attention heads per encoder block.
        perc_num_attn_heads: Attention heads in the perceiver pooling block.
        supernode_pooling: Module called in ``forward`` with the keyword
            arguments ``x``, ``pos``, ``batch_index``, ``supernode_index`` and
            ``super_node_batch_index``.
        num_latent_tokens: Forwarded as ``num_query_tokens`` to the perceiver
            block. The default ``None`` is passed through unchanged.
            NOTE(review): presumably callers must supply an int here; confirm.
        drop_path_rate: ``drop_path`` value given to every encoder block.
        init_weights: Init scheme name forwarded to both projections and all
            blocks.
        condition_dim: Enables the conditioned (DiT) blocks when not None.
        output_ln: Whether to apply the final non-affine layer norm.
        act: Accepted but not used anywhere in this class.
    """

    def __init__(
        self,
        gnn_dim,
        enc_dim,
        perc_dim,
        enc_depth,
        enc_num_attn_heads,
        perc_num_attn_heads,
        supernode_pooling,
        num_latent_tokens=None,
        drop_path_rate=0.0,
        init_weights="truncnormal",
        condition_dim=None,
        output_ln=False,
        act: nn.Module = GEGLU,
    ):
        super().__init__()
        # Record the hyperparameters as plain attributes.
        self.gnn_dim = gnn_dim
        self.enc_dim = enc_dim
        self.perc_dim = perc_dim
        self.enc_depth = enc_depth
        self.enc_num_attn_heads = enc_num_attn_heads
        self.perc_num_attn_heads = perc_num_attn_heads
        self.num_latent_tokens = num_latent_tokens
        self.drop_path_rate = drop_path_rate
        self.init_weights = init_weights
        self.condition_dim = condition_dim
        self.output_ln = output_ln

        self.supernode_pooling = supernode_pooling
        is_conditioned = condition_dim is not None

        # Transformer encoder.
        # NOTE: keep the construction order (projection, then blocks, then the
        # perceiver) so parameter initialisation draws from the RNG in the
        # same sequence.
        self.enc_proj = LinearProjection(gnn_dim, enc_dim, init_weights=init_weights)
        make_encoder_block = (
            partial(DitBlock, cond_dim=condition_dim) if is_conditioned else PrenormBlock
        )
        encoder_blocks = []
        for _ in range(enc_depth):
            encoder_blocks.append(
                make_encoder_block(
                    dim=enc_dim,
                    num_heads=enc_num_attn_heads,
                    init_weights=init_weights,
                    drop_path=drop_path_rate,
                )
            )
        self.blocks = nn.ModuleList(encoder_blocks)

        # Perceiver pooling head.
        self.perc_proj = LinearProjection(enc_dim, perc_dim, init_weights=init_weights)
        if is_conditioned:
            pooling_cls = DitPerceiverPoolingBlock
            pooling_inner_kwargs = dict(cond_dim=condition_dim, init_weights=init_weights)
        else:
            pooling_cls = PerceiverPoolingBlock
            pooling_inner_kwargs = dict(init_weights=init_weights)
        self.perceiver = pooling_cls(
            perceiver_kwargs=pooling_inner_kwargs,
            dim=perc_dim,
            num_heads=perc_num_attn_heads,
            num_query_tokens=num_latent_tokens,
        )

        # Shape of the result, excluding the batch dimension.
        self.output_shape = (num_latent_tokens, perc_dim)

    def forward(
        self,
        field,
        pos,
        batch_index,
        supernode_index,
        supernode_batch_index,
        condition=None,
    ):
        """Encode a batch of fields into perceiver latent tokens.

        Args:
            field: Input features, passed to ``supernode_pooling`` as ``x``.
            pos: Positions, passed to ``supernode_pooling``.
            batch_index: Per-point sample index. The number of samples is
                taken as ``batch_index.max() + 1``.
                NOTE(review): this assumes zero-based, contiguous sample ids.
            supernode_index: Passed to ``supernode_pooling``.
            supernode_batch_index: Passed to ``supernode_pooling`` under the
                keyword ``super_node_batch_index``.
            condition: Optional conditioning tensor. When given, it is passed
                as ``cond`` to every encoder block and to the perceiver.

        Returns:
            The perceiver output, optionally layer-normed over its last
            ``perc_dim`` features.
        """
        num_samples = batch_index.max().item() + 1
        pooled = self.supernode_pooling(
            x=field,
            pos=pos,
            batch_index=batch_index,
            supernode_index=supernode_index,
            super_node_batch_index=supernode_batch_index,
        )
        # The pooled output is flat over (sample, supernode). Regroup it per
        # sample; this requires the same supernode count for every sample.
        tokens = rearrange(
            pooled,
            "(batch_size num_supernodes) dim -> batch_size num_supernodes dim",
            batch_size=num_samples,
        )

        cond_kwargs = {} if condition is None else {"cond": condition}

        tokens = self.enc_proj(tokens)
        for encoder_block in self.blocks:
            tokens = encoder_block(tokens, **cond_kwargs)

        latents = self.perceiver(kv=self.perc_proj(tokens), **cond_kwargs)

        if not self.output_ln:
            return latents
        return nn.functional.layer_norm(latents, (self.perc_dim,), eps=1e-6)
