from typing import *
import numpy as np
import torch
import torch.nn as nn 
from einops import rearrange

from src.models.tokenizer import Downsample, RMSGroupNorm, conv_module, SubsampledLinear
from src.models.attention import SpaceTimeBlock
from src.utils.database import standardize


class Encoder(nn.Module):
    """Downsampling stem made of three ``Downsample`` stages ending at ``embed_dim`` channels.

    Stages use ``(kernel_size, stride)`` of ``(4, 4)``, ``(2, 2)`` and ``(2, 2)``;
    all are bias-free and each is followed by an affine ``RMSGroupNorm``. GELU
    follows the first two stages only. The channel arguments go
    ``embed_dim // 4 -> embed_dim // 4 -> embed_dim // 4 -> embed_dim``
    (assuming ``Downsample``'s first two arguments are in/out channels — TODO
    confirm). If ``Downsample`` is a strided convolution, the combined stride is
    4 * 2 * 2 = 16 along each spatial axis.

    Args:
        embed_dim: Output channel count; the earlier stages use ``embed_dim // 4``.
        spatial_ndims: Number of spatial dimensions, forwarded as ``ndim`` to every
            ``Downsample``.
        padding_mode: Padding mode forwarded to every ``Downsample``
            (``Hypernetwork`` passes a string here).
        groups: Number of groups in every ``RMSGroupNorm``.
    """

    def __init__(self, embed_dim: int, spatial_ndims: int, padding_mode: str, groups: int):
        super().__init__()
        # Keep this order: nn.Sequential indices ("0", "1", ...) become the
        # state_dict keys, so reordering would break loading saved weights.
        self.encoder = nn.Sequential(*[
            Downsample(embed_dim//4, embed_dim//4, kernel_size=4, stride=4, padding_mode=padding_mode, bias=False, ndim=spatial_ndims),
            RMSGroupNorm(groups, embed_dim//4, affine=True),
            nn.GELU(),
            Downsample(embed_dim//4, embed_dim//4, kernel_size=2, stride=2, padding_mode=padding_mode, bias=False, ndim=spatial_ndims),
            RMSGroupNorm(groups, embed_dim//4, affine=True),
            nn.GELU(),
            Downsample(embed_dim//4, embed_dim, kernel_size=2, stride=2, padding_mode=padding_mode, bias=False, ndim=spatial_ndims),
            RMSGroupNorm(groups, embed_dim, affine=True),
        ])
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the downsampling stack to ``x``.

        ``x`` is presumably laid out as ``(batch, embed_dim // 4, *spatial)`` —
        TODO confirm against ``Downsample``.
        """
        return self.encoder(x)
    

class ParamProj(nn.Module):
    """MLP regression head: Linear -> GELU -> Linear -> GELU -> Linear.

    Both hidden layers keep width ``hidden_dim``; the final layer maps to
    ``out_dim``. ``nn.Linear`` acts on the last axis, so any leading axes of the
    input are carried through unchanged.
    """

    def __init__(self, hidden_dim: int, out_dim: int):
        super().__init__()
        # Two (Linear, GELU) pairs occupy Sequential slots 0-3 and the output
        # Linear takes slot 4, giving state_dict keys proj.0.*, proj.2.*, proj.4.*.
        # Layers are created in that same order, so parameter init consumes the
        # RNG identically.
        hidden = [
            layer
            for _ in range(2)
            for layer in (nn.Linear(hidden_dim, hidden_dim), nn.GELU())
        ]
        self.proj = nn.Sequential(*hidden, nn.Linear(hidden_dim, out_dim))

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the MLP and return the projected features."""
        out = self.proj(x)
        return out


class Hypernetwork(nn.Module):
    """Regress an output vector from a ``(b, t, c, h, w)`` input.

    Pipeline (see ``forward``):
    1. Standardize the input.
    2. Embed its channels to ``embed_dim // 4`` with ``space_bag``.
    3. Downsample with an ``Encoder``.
    4. Run ``processor_blocks`` ``SpaceTimeBlock`` layers.
    5. Average over time and space.
    6. Map the pooled features with a ``ParamProj`` regression head.

    Two modes, selected by ``in_chans``:

    * ``in_chans is None`` ("customized"): the channel embedding is a
      ``SubsampledLinear`` over ``n_states`` inputs, called with
      ``state_labels``. The encoder and regression head are ``nn.ModuleDict``s
      looked up by ``dset_name`` in ``forward``. ``spatial_ndims`` and
      ``out_chans`` must be mappings keyed by dataset name. ``padding_mode``
      may be such a mapping or a single string shared by all datasets.
    * otherwise: a single kernel-size-1 convolution from ``in_chans`` channels,
      with one encoder and one head.

    The ``SpaceTimeBlock`` processor is shared across datasets in both modes.

    Args:
        in_chans: Number of input channels, or ``None`` for the customized mode.
        out_chans: Output size of the regression head (mapping per dataset in
            customized mode).
        embed_dim: Processor width; the embedding layer outputs
            ``embed_dim // 4`` channels.
        spatial_ndims: Spatial dimensionality passed to ``Encoder`` (mapping per
            dataset in customized mode).
        padding_mode: Padding mode for the conv layers. In customized mode,
            either one string for every dataset or a mapping per dataset.
        groups: Group count for ``RMSGroupNorm`` inside ``Encoder``.
        processor_blocks: Number of ``SpaceTimeBlock`` layers.
        drop_path: Largest drop-path rate; rates grow linearly from 0 to this
            value across the blocks.
        num_heads: Attention heads per ``SpaceTimeBlock``.
        bias_type: Forwarded to every ``SpaceTimeBlock``.
        finetune: Customized mode only; adds 5 extra state inputs for the
            fine-tuning experiments.
    """

    def __init__(self, 
        in_chans: int | None, 
        out_chans: int | Dict,
        embed_dim: int, 
        spatial_ndims: int | Dict, 
        padding_mode: str | Dict,
        groups: int,
        processor_blocks: int,
        drop_path: float,
        num_heads: int,
        bias_type: str,
        finetune: bool = False, 

    ) -> None:
        super().__init__()

        # Submodules are created in the original order so parameter init and
        # state_dict keys are unchanged.
        self.customized = in_chans is None
        if self.customized:
            # encoder: state-label-aware channel embedding + one Encoder per dataset
            n_states = 12
            if finetune:
                n_states += 5  # accounting for shearflow and euler finetuning experiments
            self.space_bag = SubsampledLinear(dim_in=n_states, dim_out=embed_dim//4, subsample_in=True)
            self.encoder = nn.ModuleDict({
                k: Encoder(
                    embed_dim,
                    spatial_ndims[k],
                    # A plain string is shared by every dataset; indexing a str
                    # by a dataset name would raise TypeError.
                    padding_mode if isinstance(padding_mode, str) else padding_mode[k],
                    groups,
                )
                for k in spatial_ndims
            })

            # regressor: one head per dataset
            self.proj_channel = nn.ModuleDict({
                k: ParamProj(embed_dim, out_chans[k])
                for k in out_chans
            })
        else:
            # encoder: kernel-size-1 conv lifting in_chans -> embed_dim//4
            # (conv_module(2, False) presumably selects a 2-D, non-transposed conv,
            # matching the 4-D (b t) c h w tensor fed to it in forward — TODO confirm)
            self.space_bag = conv_module(2,False)(in_chans, embed_dim//4, kernel_size=1, stride=1, padding_mode=padding_mode, bias=True)
            self.encoder = Encoder(embed_dim, spatial_ndims, padding_mode, groups)

            # regressor 
            self.proj_channel = ParamProj(embed_dim, out_chans)

        # processor (common to different datasets); drop-path rate increases
        # linearly from 0 to drop_path over the blocks
        self.dp = np.linspace(0, drop_path, processor_blocks)
        self.blocks = nn.ModuleList([
            SpaceTimeBlock(dim=embed_dim, num_heads=num_heads, bias_type=bias_type, drop_path=dp)
            for dp in self.dp
        ])

    def forward(self, x: torch.Tensor, state_labels: torch.Tensor, dset_name: str) -> torch.Tensor:
        """Predict the regression output for a batch.

        Args:
            x: Input with layout ``(b, t, c, h, w)``, as required by the
                ``rearrange`` patterns below. A size-1 ``h`` or ``w`` axis is
                squeezed before the encoder, so lower-dimensional data can use a
                lower-dimensional ``Encoder``.
            state_labels: Customized mode only. Passed to ``space_bag`` with the
                channels-last input (presumably identifies which state variables
                the ``c`` channels hold — TODO confirm). Ignored otherwise.
            dset_name: Customized mode only. Key selecting the per-dataset encoder
                and regression head. Ignored otherwise.

        Returns:
            The regression output, presumably of shape ``(b, out_chans)``. This
            assumes the processor blocks keep the ``(b, t, c, h, w)`` layout —
            TODO confirm.
        """

        # dimensions
        B = x.shape[0]
        # Non-singleton spatial axes (3 = h, 4 = w). Counting axes after a squeeze
        # would pick axis 3 whenever either trailing axis has size 1. That is
        # wrong for h == 1, w > 1, where w is the axis that must be reduced over.
        spatial_dims = tuple(d for d in range(3, x.ndim) if x.shape[d] != 1)

        # preprocess: normalize over the time axis and the spatial axes
        x = standardize(x, dims=(1, *spatial_dims), return_stats=False)

        # adapt to the number of in channels -> ((b t), embed_dim//4, h, w)
        if self.customized:
            x = rearrange(x, 'b t c h w -> (b t) h w c')
            x = self.space_bag(x, state_labels)
            x = rearrange(x, 'bt h w c -> bt c h w')
        else:
            x = rearrange(x, 'b t c h w -> (b t) c h w')
            x = self.space_bag(x)

        # encode: dropping size-1 spatial axes lets a lower-dimensional Encoder
        # run; a 1-D result gets a trailing singleton axis back to stay 4-D
        x = x.squeeze((-2, -1))
        encoder = self.encoder[dset_name] if self.customized else self.encoder
        x = encoder(x)
        if x.ndim == 3:
            x = x.unsqueeze(-1)
        x = rearrange(x, '(b t) c h w ->  b t c h w', b=B)

        # attention layers; attention maps are neither requested nor used
        for blk in self.blocks:
            x, _ = blk(x, return_att=False)

        # regression layers
        x = x.mean(1, keepdim=True)  # average over time
        x = x.mean((-2,-1))  # average over space
        head = self.proj_channel[dset_name] if self.customized else self.proj_channel
        x = head(x)[:,0,:]  # drop the pooled time axis

        return x
