import einops
import torch
from torch import nn

from models.base.single_model_base import SingleModelBase
from optimizers.param_group_modifiers.exclude_from_wd_by_name_modifier import ExcludeFromWdByNameModifier
from custommodules.vit import VitPatchEmbed, VitPosEmbed
from custommodules.transformer import PerceiverPoolingBlock, PrenormBlock

class RansGridVit(SingleModelBase):
    """ViT-style encoder for a gridded input that pools to a fixed number of tokens.

    Pipeline (see ``forward``): patch embedding -> learnable positional embedding ->
    ``depth`` pre-norm transformer blocks -> perceiver pooling with
    ``num_output_tokens`` query tokens -> addition of a learnable type token.

    Args:
        patch_size: patch size forwarded to ``VitPatchEmbed``.
        dim: embedding dimension shared by all submodules.
        depth: number of ``PrenormBlock`` layers.
        num_attn_heads: number of attention heads of the blocks and the perceiver.
        num_output_tokens: number of query tokens of the perceiver pooling block.
        init_weights: weight init scheme forwarded to the blocks and the perceiver.
        **kwargs: forwarded unchanged to ``SingleModelBase``.
    """

    def __init__(
            self,
            patch_size,
            dim,
            depth,
            num_attn_heads,
            num_output_tokens,
            init_weights="xavier_uniform",
            **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.dim = dim
        self.depth = depth
        self.num_attn_heads = num_attn_heads
        self.num_output_tokens = num_output_tokens
        # NOTE(review): data_container/static_ctx are not defined in this class,
        # presumably they are provided by SingleModelBase.__init__ -- confirm
        self.resolution = self.data_container.get_dataset().grid_resolution
        self.ndim = len(self.resolution)
        # sdf + grid_pos
        if self.data_container.get_dataset().concat_pos_to_sdf:
            input_dim = 4
        else:
            input_dim = 1

        self.patch_embed = VitPatchEmbed(
            dim=dim,
            num_channels=input_dim,
            resolution=self.resolution,
            patch_size=patch_size,
        )
        self.pos_embed = VitPosEmbed(
            seqlens=self.patch_embed.seqlens,
            dim=dim,
            is_learnable=True,
            allow_interpolation=False,
        )
        self.blocks = nn.ModuleList([
            PrenormBlock(dim=dim, num_heads=num_attn_heads, init_weights=init_weights)
            for _ in range(depth)
        ])
        self.perceiver = PerceiverPoolingBlock(
            dim=dim,
            num_heads=num_attn_heads,
            num_query_tokens=num_output_tokens,
            perceiver_kwargs=dict(init_weights=init_weights),
        )

        # allocated uninitialized; values are set in model_specific_initialization
        self.type_token = nn.Parameter(torch.empty(size=(1, 1, dim)))

        self.static_ctx["grid_resolution"] = self.resolution
        self.static_ctx["ndim"] = self.ndim
        # forward returns the perceiver-pooled sequence, so the number of output
        # tokens is the number of perceiver query tokens and not the number of patches
        # (assumes PerceiverPoolingBlock emits one token per query token -- confirm)
        self.output_shape = (num_output_tokens, dim)

    def model_specific_initialization(self):
        """Initialize the type token (created via ``torch.empty``) with a truncated normal."""
        nn.init.trunc_normal_(self.type_token)

    def get_model_specific_param_group_modifiers(self):
        """Return the param group modifiers of this model.

        Returns:
            A list with a single ``ExcludeFromWdByNameModifier`` targeting ``type_token``
            (going by the modifier's name, this excludes it from weight decay).
        """
        return [ExcludeFromWdByNameModifier(name="type_token")]

    def forward(self, x):
        """Encode a gridded input into a pooled token sequence.

        Args:
            x: tensor in dim-last layout with three spatial axes,
                i.e. ``(batch_size, height, width, depth, dim)``.

        Returns:
            The perceiver output with the type token added (broadcast over the first
            two axes); presumably ``(batch_size, num_output_tokens, dim)`` -- confirm
            against ``PerceiverPoolingBlock``.
        """
        # sdf is passed as dim-last with spatial -> convert to dim-first with spatial
        x = einops.rearrange(x, "batch_size height width depth dim -> batch_size dim height width depth")
        # embed
        x = self.patch_embed(x)
        x = self.pos_embed(x)
        # flatten
        x = einops.rearrange(x, "batch_size height width depth dim -> batch_size (height width depth) dim")
        # transformer
        for block in self.blocks:
            x = block(x)
        # perceiver
        x = self.perceiver(x)
        x = x + self.type_token
        return x
