import torch
import torch.nn as nn
from einops import rearrange
from kappamodules.layers import ContinuousSincosEmbed

from src.modules.act import GEGLU
from src.modules.torch_modules import Attention, FeedForward, PreNorm


class PerceiverDecoder(nn.Module):
    """Perceiver-style decoder with two cross-attention query heads.

    Query positions are embedded with ``ContinuousSincosEmbed`` and used as
    queries that cross-attend to the latent tokens ``x``. Two independent heads
    share the position embedding:

    * a feature head (``feat_decoder`` / ``feat_ff`` / ``feat_pred``) whose
      final linear layer outputs ``n_channels`` values per point, and
    * an occupancy head (``occ_decoder`` / ``occ_ff`` / ``occ_pred``) whose
      final linear layer outputs ``n_particle_types`` values per point.

    Args:
        queries_dim: Width of the position embedding and of each head.
        latent_dim: Width of the latent tokens used as attention context.
        cross_heads: Number of heads in each cross-attention layer.
        cross_dim_head: Per-head width of each cross-attention layer.
        n_particle_types: Output width of the occupancy head.
        n_channels: Output width of the feature head.
        act: Activation passed to both ``FeedForward`` blocks.
        ndim: Number of coordinates per position given to the embedding.
    """

    def __init__(
        self,
        queries_dim,
        latent_dim,
        cross_heads=1,
        cross_dim_head=64,
        n_particle_types=1,
        n_channels=1,
        act=GEGLU,
        ndim: int = 3,
    ):
        super().__init__()
        # Attribute names and construction order are deliberately kept as-is so
        # state_dict keys and parameter initialisation (RNG draws) are unchanged.
        self.pos_embed = ContinuousSincosEmbed(dim=queries_dim, ndim=ndim)
        # NOTE(review): not used by forward() (adding it to the embedding is
        # disabled); kept so existing checkpoints still load. An unused trainable
        # parameter may require find_unused_parameters=True under DDP.
        self.pos_vector = nn.Parameter(torch.randn(1, queries_dim))

        self.feat_decoder = PreNorm(
            queries_dim,
            Attention(queries_dim, latent_dim, heads=cross_heads, dim_head=cross_dim_head),
            context_dim=latent_dim,
        )
        self.feat_ff = PreNorm(queries_dim, FeedForward(queries_dim, act=act))
        self.feat_pred = nn.Sequential(
            nn.LayerNorm(queries_dim),
            nn.Linear(queries_dim, n_channels),
        )

        self.occ_decoder = PreNorm(
            queries_dim,
            Attention(queries_dim, latent_dim, heads=cross_heads, dim_head=cross_dim_head),
            context_dim=latent_dim,
        )
        self.occ_ff = PreNorm(queries_dim, FeedForward(queries_dim, act=act))
        self.occ_pred = nn.Sequential(
            nn.LayerNorm(queries_dim),
            nn.Linear(queries_dim, n_particle_types),
        )

    def forward(self, x, pos, occ_pos=None, condition=None):
        """Decode feature and occupancy outputs at the given query positions.

        Args:
            x: Latent tokens used as cross-attention context. ``x.shape[0]`` is
                taken as the batch size.
            pos: Query positions for the feature head, flattened over the batch
                (presumably ``(batch_size * num_points, ndim)`` — confirm with
                callers); the leading dim must be divisible by the batch size.
            occ_pos: Optional query positions for the occupancy head, in the
                same flattened layout. When ``None`` the occupancy head is
                queried at ``pos``.
            condition: Accepted for interface compatibility; currently unused.

        Returns:
            ``(occ_out, feat_out)``. Both are flattened over batch and points;
            ``occ_out`` has ``n_particle_types`` columns and ``feat_out`` has
            ``n_channels`` columns.
        """
        batch_size = x.shape[0]

        feat_queries = self._embed_queries(pos, batch_size)
        feat_out = self._run_head(
            feat_queries, x, self.feat_decoder, self.feat_ff, self.feat_pred
        )

        # Without dedicated occupancy positions, reuse the feature queries.
        if occ_pos is None:
            occ_queries = feat_queries
        else:
            occ_queries = self._embed_queries(occ_pos, batch_size)
        occ_out = self._run_head(
            occ_queries, x, self.occ_decoder, self.occ_ff, self.occ_pred
        )

        return occ_out, feat_out

    def _embed_queries(self, positions, batch_size):
        """Embed batch-flattened positions and unflatten them to (batch, points, dim)."""
        embedded = self.pos_embed(positions)
        return rearrange(
            embedded,
            "(batch_size num_points) dim -> batch_size num_points dim",
            batch_size=batch_size,
        )

    @staticmethod
    def _run_head(queries, latents, decoder, ff, pred):
        """Cross-attend queries to latents, apply FF + projection, flatten batch."""
        out = decoder(queries, context=latents)
        # NOTE(review): the FF output replaces the attention output here (no
        # explicit residual); whether PreNorm adds one is defined in
        # src.modules.torch_modules — confirm this is intended.
        out = ff(out)
        out = pred(out)
        return rearrange(
            out, "batch_size max_num_points dim -> (batch_size max_num_points) dim"
        )
