import torch
import torch.nn as nn
from modules.attention import ContinuousPositionBias, Transformer
from einops import rearrange

class STTransformer(nn.Module):
    """Factorised spatial-then-temporal transformer over video tokens.

    ``forward`` takes a 5-D tensor laid out as ``(b, t, h, w, d)``. It first
    runs ``enc_spatial_transformer`` over the ``h * w`` tokens of every frame
    (using an attention bias produced by ``spatial_rel_pos_bias``), then runs
    ``enc_temporal_transformer`` over the ``t`` steps of every spatial
    location, and finally restores the ``b t h w d`` layout.

    Args:
        enc_hidden_size: Combined with ``patch_size`` to derive the width
            handed to every sub-module:
            ``enc_hidden_size // (patch_size ** 2)``.
        patch_size: See ``enc_hidden_size``.
        dim_head: Forwarded to both ``Transformer`` instances.
        heads: Forwarded to both ``Transformer`` instances and to
            ``ContinuousPositionBias``.
        attn_dropout: Forwarded to both ``Transformer`` instances.
        ff_dropout: Forwarded to both ``Transformer`` instances.
        peg: Forwarded to both ``Transformer`` instances.
        peg_causal: Forwarded to both ``Transformer`` instances.
        spatial_depth: ``depth`` of the spatial ``Transformer``.
        temporal_depth: ``depth`` of the temporal ``Transformer``.
        causal: Forwarded to the temporal ``Transformer`` only.

    NOTE(review): the exact meaning of ``peg`` / ``peg_causal`` / ``causal``
    is defined by ``modules.attention.Transformer`` and is not visible here.
    """

    def __init__(self,
                 enc_hidden_size,
                 patch_size=4,
                 dim_head=64,
                 heads=16,
                 attn_dropout=0.,
                 ff_dropout=0.1,
                 peg=True,
                 peg_causal=True,
                 spatial_depth=1,
                 temporal_depth=1,
                 causal=True):
        super().__init__()

        # Width shared by the position bias and both transformers.
        dim = enc_hidden_size // (patch_size ** 2)

        # Settings that are identical for the spatial and temporal stacks.
        shared_transformer_kwargs = dict(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            peg=peg,
            peg_causal=peg_causal,
        )

        # Construction order is kept (bias, spatial, temporal) so parameter
        # initialisation order and state_dict key order do not change.
        self.spatial_rel_pos_bias = ContinuousPositionBias(dim=dim, heads=heads)
        self.enc_spatial_transformer = Transformer(
            depth=spatial_depth,
            peg_spatial_or_temporal='spatial',
            **shared_transformer_kwargs,
        )
        self.enc_temporal_transformer = Transformer(
            depth=temporal_depth,
            peg_spatial_or_temporal='temporal',
            causal=causal,
            **shared_transformer_kwargs,
        )

    def forward(self, x):
        """Apply spatial attention per frame, then temporal attention per location.

        Args:
            x: 5-D tensor laid out as ``(b, t, h, w, d)``.
                # assumes d matches the ``dim`` the sub-modules were built
                # with — TODO confirm against ``Transformer``.

        Returns:
            The processed tokens rearranged back to the ``b t h w d`` layout.
        """
        # Unpacking also enforces that the input is exactly 5-D.
        b, t, h, w, _ = x.shape
        video_shape = (b, t, h, w)

        # Spatial pass: every frame becomes an independent sequence of h*w tokens.
        tokens = rearrange(x, 'b t h w d -> (b t) (h w) d')
        attn_bias = self.spatial_rel_pos_bias(h, w, device=tokens.device)
        tokens = self.enc_spatial_transformer(
            tokens,
            attn_bias=attn_bias,
            video_shape=video_shape,
        )

        # Temporal pass: every spatial location becomes an independent
        # sequence of t tokens. Going straight from the spatial layout to the
        # temporal one is equivalent to unfolding to 'b t h w d' first.
        tokens = rearrange(tokens, '(b t) (h w) d -> (b h w) t d', b=b, t=t, h=h, w=w)
        tokens = self.enc_temporal_transformer(tokens, video_shape=video_shape)

        return rearrange(tokens, '(b h w) t d -> b t h w d', b=b, t=t, h=h, w=w)

    def to(self, *args, **kwargs):
        """Move and/or cast the module; accepts everything ``nn.Module.to`` does.

        The three sub-modules are registered children, so ``nn.Module.to``
        already handles them. Forwarding all arguments keeps ``to(device)``
        working and additionally allows forms such as ``to(dtype=...)`` or
        ``to(device, dtype)``.

        Returns:
            ``self``, as returned by ``nn.Module.to``.
        """
        return super().to(*args, **kwargs)