import torch
import torch.nn as nn
from einops import rearrange
from modules.attention import ContinuousPositionBias, Dual_attention_Transformer

class MLP_Encoder(nn.Module):
    """Residual MLP that maps ``input_dim`` features to ``output_dim`` features.

    The network is a linear input projection, followed by ``hidden_depth``
    residual blocks (LayerNorm -> ReLU -> Linear, added back onto their own
    input), followed by a LayerNorm -> ReLU -> Linear output head.

    Args:
        input_dim: Feature size accepted by the input projection.
        hidden_dim: Width of the hidden representation used by every block.
        output_dim: Feature size produced by the output head.
        hidden_depth: Number of residual blocks (``0`` gives no blocks).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, hidden_depth):
        super().__init__()

        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList(
            [self._norm_relu_linear(hidden_dim, hidden_dim) for _ in range(hidden_depth)]
        )
        self.output_proj = self._norm_relu_linear(hidden_dim, output_dim)

    @staticmethod
    def _norm_relu_linear(in_features, out_features):
        """Build one LayerNorm -> ReLU -> Linear stack."""
        return nn.Sequential(
            nn.LayerNorm(in_features),
            nn.ReLU(),
            nn.Linear(in_features, out_features),
        )

    def forward(self, x):
        """Encode ``x`` and return the output head's result."""
        hidden = self.input_proj(x)

        # Each block contributes a residual update on top of the running state.
        for block in self.hidden_layers:
            hidden = hidden + block(hidden)

        return self.output_proj(hidden)
    
    
class MLP_Decoder(nn.Module):
    """Residual MLP that maps ``n_factors`` inputs to ``output_dim`` outputs.

    Structure: a linear projection into ``hidden_dim``, then ``hidden_depth``
    residual blocks of LayerNorm -> ReLU -> Linear, then a final
    LayerNorm -> ReLU -> Linear projection to ``output_dim``.

    Args:
        n_factors: Feature size accepted by the input projection.
        hidden_dim: Width of the hidden representation.
        output_dim: Feature size of the returned tensor.
        hidden_depth: Number of residual blocks (``0`` gives no blocks).
    """

    def __init__(self, n_factors, hidden_dim, output_dim, hidden_depth):
        super().__init__()

        def make_stack(out_features):
            # Every stack normalises and activates at hidden width before its Linear.
            return nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_features),
            )

        self.input_proj = nn.Linear(n_factors, hidden_dim)

        self.hidden_layers = nn.ModuleList()
        for _ in range(hidden_depth):
            self.hidden_layers.append(make_stack(hidden_dim))

        self.output_proj = make_stack(output_dim)

    def forward(self, z):
        """Decode ``z`` and return the output projection's result."""
        state = self.input_proj(z)

        for residual_block in self.hidden_layers:
            update = residual_block(state)
            state = state + update

        decoded = self.output_proj(state)
        return decoded
    

class Fusion_decoder(nn.Module):
    """Spatial attention decoder built on ``Dual_attention_Transformer``.

    The module flattens a 5-D input ``(b, t, h, w, d)`` into per-frame token
    sequences ``((b t), (h w), d)``, runs them through the spatial transformer
    together with up to two optional context tensors and a position bias
    computed from ``(h, w)``, and restores the ``(b, t, h, w, d)`` layout.

    Args:
        e_dim: Base embedding size. The transformer is built with
            ``dim = e_dim // patch_size ** 2``; the position-bias module is
            built with ``dim = e_dim``.
        patch_size: Divisor (squared) applied to ``e_dim`` for the transformer.
        dim_head: Forwarded to ``Dual_attention_Transformer``.
        heads: Forwarded to both the transformer and the position bias.
        attn_dropout: Forwarded to ``Dual_attention_Transformer``.
        ff_dropout: Forwarded to ``Dual_attention_Transformer``.
        peg: Forwarded to ``Dual_attention_Transformer``.
        peg_causal: Forwarded to ``Dual_attention_Transformer``.
        spatial_depth: Passed as the transformer's ``depth``.
        dim_context: Forwarded to ``Dual_attention_Transformer``.
        has_cross_attn: Forwarded to ``Dual_attention_Transformer``.
    """

    def __init__(self,
                 e_dim,
                 patch_size=4,
                 dim_head=64,
                 heads=16,
                 attn_dropout=0.,
                 ff_dropout=0.1,
                 peg=True,
                 peg_causal=True,
                 spatial_depth=1,
                 dim_context=64,
                 has_cross_attn=True,
                 ):

        super().__init__()

        dec_spatial_transformer_kwargs = dict(
            dim = e_dim // (patch_size ** 2),
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            peg = peg,
            peg_causal = peg_causal,
            has_cross_attn = has_cross_attn,
            dim_context = dim_context,
            peg_spatial_or_temporal = 'spatial',
        )

        # NOTE(review): the bias is built with dim=e_dim while the transformer
        # uses e_dim // patch_size**2 -- confirm this difference is intended.
        self.spatial_rel_pos_bias = ContinuousPositionBias(dim = e_dim, heads = heads)
        self.dec_spatial_transformer = Dual_attention_Transformer(depth = spatial_depth, **dec_spatial_transformer_kwargs)

    def forward(self, x, context=(None, None)):
        """Run spatial attention over ``x``.

        Args:
            x: Tensor with five dimensions, unpacked as ``(b, t, h, w, d)``.
            context: Pair of optional tensors laid out as ``(b, t, h, w, d)``;
                each entry that is not ``None`` is flattened the same way as
                ``x`` before being handed to the transformer.

        Returns:
            The transformer output reshaped back to ``(b, t, h, w, d)``.
        """
        # The feature size is not needed here; only the leading dims are reused.
        b, t, h, w, _ = x.shape

        tokens = rearrange(x, 'b t h w d -> (b t) (h w) d')
        context_1 = rearrange(context[0], 'b t h w d -> (b t) (h w) d') if context[0] is not None else None
        context_2 = rearrange(context[1], 'b t h w d -> (b t) (h w) d') if context[1] is not None else None
        attn_bias = self.spatial_rel_pos_bias(h, w, device=tokens.device)
        tokens = self.dec_spatial_transformer(
            tokens,
            context=(context_1, context_2),
            attn_bias=attn_bias,
            video_shape=(b, t, h, w)
        )
        tokens = rearrange(tokens, '(b t) (h w) d -> b t h w d', b=b, t=t, h=h, w=w)

        return tokens

    def to(self, *args, **kwargs):
        """Move and/or cast the module, accepting everything ``nn.Module.to`` does.

        The previous override took a single ``device`` argument, which made
        calls such as ``to(device, dtype)``, ``to(dtype=...)`` or
        ``to(device, non_blocking=True)`` raise ``TypeError``. All arguments
        are now forwarded unchanged; ``to(device)`` and ``to(device=...)``
        behave as before.
        """
        # Children are still moved explicitly via their own ``to`` so any
        # custom ``to`` logic they define keeps running, as in the original.
        self.spatial_rel_pos_bias = self.spatial_rel_pos_bias.to(*args, **kwargs)
        self.dec_spatial_transformer = self.dec_spatial_transformer.to(*args, **kwargs)
        return super().to(*args, **kwargs)