from functools import partial

import einops
import torch
from einops import rearrange
from kappamodules.layers import ContinuousSincosEmbed, LinearProjection
from kappamodules.transformer import DitBlock, DitPerceiverBlock, PerceiverBlock
from kappamodules.vit import VitBlock
from torch import nn

from src.modules.act import GEGLU


class UptTransformerPerceiverOccupancy(nn.Module):
    """Transformer over latent tokens followed by two perceiver read-out heads.

    The incoming latent tokens are projected from ``input_dim`` to ``dim`` and
    processed by ``depth`` transformer blocks (``DitBlock`` when
    ``condition_dim`` is set, otherwise ``VitBlock``). Query tokens are built
    from positional embeddings of requested positions, passed through a
    shared MLP (``query_mlp``). Each head then cross-attends from its queries
    into the processed latent tokens:

    * feature head: ``feat_proj`` -> ``feat_perc`` -> ``feat_pred`` (``feat_dim`` outputs)
    * occupancy head: ``occ_proj`` -> ``occ_perc`` -> ``occ_pred``
      (``n_particle_types`` outputs)

    Args:
        dim: Width of the transformer blocks.
        depth: Number of transformer blocks.
        num_attn_heads: Attention heads of the transformer blocks.
        pos_embed: Constructor/factory (not an instance) for the positional
            embedding; it is called as ``pos_embed(dim=perc_dim, ndim=ndim)``.
        n_particle_types: Output width of the occupancy head. Required, despite
            the ``None`` default.
        perc_dim: Width of the perceiver heads; defaults to ``dim``.
        perc_num_attn_heads: Attention heads of the perceiver heads; defaults to
            ``num_attn_heads``.
        input_dim: Width of the incoming latent tokens.
        feat_dim: Output width of the feature head.
        drop_path_rate: Drop-path rate forwarded to the transformer blocks.
        init_weights: Weight-init scheme forwarded to every sub-module.
        condition_dim: If not ``None``, conditioned blocks (``DitBlock`` /
            ``DitPerceiverBlock`` with ``cond_dim=condition_dim``) are used; pass
            the matching ``condition`` to ``forward``.
        ndim: Number of coordinate dimensions, forwarded to ``pos_embed``.
        act: Currently unused; kept for backward compatibility of the signature.

    Raises:
        ValueError: If ``n_particle_types`` is ``None``.
    """

    def __init__(
        self,
        dim,
        depth,
        num_attn_heads,
        pos_embed: nn.Module,
        n_particle_types=None,
        perc_dim=None,
        perc_num_attn_heads=None,
        input_dim=192,
        feat_dim=4,
        drop_path_rate=0.0,
        init_weights="truncnormal",
        condition_dim=None,
        ndim=2,
        act: nn.Module = GEGLU,
    ):
        super().__init__()
        # The occupancy head is always built, so fail early with a clear message
        # instead of deep inside the LinearProjection constructor.
        if n_particle_types is None:
            raise ValueError(
                "n_particle_types must be given (output width of the occupancy head)"
            )
        perc_dim = perc_dim or dim
        perc_num_attn_heads = perc_num_attn_heads or num_attn_heads
        self.dim = dim
        self.depth = depth
        self.num_attn_heads = num_attn_heads
        self.perc_dim = perc_dim
        self.perc_num_attn_heads = perc_num_attn_heads
        self.drop_path_rate = drop_path_rate
        self.init_weights = init_weights
        self.condition_dim = condition_dim
        self.input_dim = input_dim
        self.feat_dim = feat_dim
        self.ndim = ndim
        self.n_particle_types = n_particle_types

        # input projection: input_dim -> dim
        self.input_proj = LinearProjection(input_dim, dim, init_weights=init_weights)

        # transformer blocks (conditioned variant when a condition dim is configured)
        if self.condition_dim is not None:
            block_ctor = partial(DitBlock, cond_dim=condition_dim)
        else:
            block_ctor = VitBlock
        self.blocks = nn.ModuleList(
            [
                block_ctor(
                    dim=dim,
                    num_heads=num_attn_heads,
                    init_weights=init_weights,
                    drop_path=drop_path_rate,
                )
                for _ in range(self.depth)
            ]
        )

        # query tokens are created from a positional embedding followed by an MLP
        # that is shared between the feature and occupancy heads
        self.pos_embed = pos_embed(dim=perc_dim, ndim=ndim)
        self.query_mlp = nn.Sequential(
            LinearProjection(perc_dim, perc_dim * 4, init_weights=init_weights),
            nn.GELU(),
            LinearProjection(perc_dim * 4, perc_dim * 4, init_weights=init_weights),
            nn.GELU(),
            LinearProjection(perc_dim * 4, perc_dim, init_weights=init_weights),
        )

        # perceiver (cross-attention) block constructor for both read-out heads
        if self.condition_dim is not None:
            perc_ctor = partial(DitPerceiverBlock, cond_dim=condition_dim)
        else:
            perc_ctor = PerceiverBlock

        # latent -> per-point features
        self.feat_proj = LinearProjection(dim, perc_dim, init_weights=init_weights)
        self.feat_perc = perc_ctor(
            dim=perc_dim, num_heads=perc_num_attn_heads, init_weights=init_weights
        )
        self.feat_pred = nn.Sequential(
            nn.LayerNorm(perc_dim, eps=1e-6),
            LinearProjection(perc_dim, feat_dim, init_weights=init_weights),
        )

        # latent -> per-point occupancy
        self.occ_proj = LinearProjection(dim, perc_dim, init_weights=init_weights)
        self.occ_perc = perc_ctor(
            dim=perc_dim, num_heads=perc_num_attn_heads, init_weights=init_weights
        )
        self.occ_pred = nn.Sequential(
            nn.LayerNorm(perc_dim, eps=1e-6),
            LinearProjection(perc_dim, n_particle_types, init_weights=init_weights),
        )

    def _decode_head(self, x, query_pos, batch_size, proj, perc, pred, block_kwargs):
        """Read out one head at ``query_pos`` by cross-attending into ``x``.

        The feature and occupancy heads share this pipeline and differ only in
        their ``proj`` / ``perc`` / ``pred`` modules. Returns a rank-2 tensor
        with the batch and point axes merged.
        """
        # self.pos_embed yields one row per point for the whole batch; split it
        # back per sample. This requires every sample to have the same number of
        # query points.
        pos_embed = rearrange(
            self.pos_embed(query_pos),
            "(batch_size num_points) dim -> batch_size num_points dim",
            batch_size=batch_size,
        )
        query = self.query_mlp(pos_embed)
        out = perc(q=query, kv=proj(x), **block_kwargs)
        out = pred(out)
        return rearrange(
            out,
            "batch_size num_points dim -> (batch_size num_points) dim",
        )

    def forward(self, x, pos=None, occ_pos=None, condition=None):
        """Run the transformer and the requested read-out heads.

        Args:
            x: Latent tokens; must be rank-3, presumably
                ``(batch_size, num_tokens, input_dim)`` -- TODO confirm.
            pos: Optional feature-head query positions, flattened over the
                batch: ``self.pos_embed(pos)`` must yield ``batch_size *
                num_points`` rows with the same ``num_points`` per sample.
            occ_pos: Optional occupancy-head query positions, same layout as
                ``pos``.
            condition: Optional conditioning, passed as ``cond`` to every
                transformer and perceiver block.

        Returns:
            Tuple ``(occ_out, feat_out)``. ``occ_out`` has shape
            ``(batch_size * num_occ_points, n_particle_types)`` and ``feat_out``
            has shape ``(batch_size * num_points, feat_dim)``. Each is ``None``
            when the corresponding position argument is ``None``.
        """
        batch_size = x.shape[0]
        assert x.ndim == 3, f"expected a rank-3 latent tensor, got ndim={x.ndim}"
        block_kwargs = {} if condition is None else {"cond": condition}

        # input projection + transformer blocks
        x = self.input_proj(x)
        for blk in self.blocks:
            x = blk(x, **block_kwargs)

        feat_out = None
        if pos is not None:
            feat_out = self._decode_head(
                x,
                pos,
                batch_size,
                self.feat_proj,
                self.feat_perc,
                self.feat_pred,
                block_kwargs,
            )

        occ_out = None
        if occ_pos is not None:
            occ_out = self._decode_head(
                x,
                occ_pos,
                batch_size,
                self.occ_proj,
                self.occ_perc,
                self.occ_pred,
                block_kwargs,
            )

        return occ_out, feat_out
