import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
import math
from src.models.blocks.attention1d import BasicTransformerBlock

class lifting(nn.Module):
    """Lift multi-view 2D feature maps into a 3D latent feature volume.

    A learnable latent grid of shape
    ``[latent_res, latent_res, latent_res, in_dim]`` is flattened into tokens
    and passed, together with the flattened 2D features, through a stack of
    ``BasicTransformerBlock`` layers. The tokens are then folded back into a
    cubic volume and refined by a small 3D convolutional head.

    Args:
        latent_res: Edge length of the cubic latent grid.
        in_dim: Channel width of the input features and of the latent tokens.
        num_layers: Number of transformer blocks to stack.
        n_heads: Number of attention heads per block. Defaults to 8, the value
            that used to be hard-coded. The per-head width is
            ``in_dim // n_heads``, so ``in_dim`` should be a multiple of it.
    """

    def __init__(self, latent_res: int, in_dim: int, num_layers: int, n_heads: int = 8) -> None:
        super().__init__()

        self.latent_res = latent_res
        self.n_heads = n_heads

        # NOTE(review): torch.rand draws from U[0, 1), so the embedding is
        # non-negative and its actual std is embedding_stdev / sqrt(12). The
        # variable name suggests torch.randn may have been intended; left
        # unchanged so seeded initialisation stays reproducible.
        embedding_stdev = 1.0 / math.sqrt(in_dim)
        self.latent_emb = nn.parameter.Parameter(
            torch.rand(latent_res, latent_res, latent_res, in_dim) * embedding_stdev
        )

        head_dim = in_dim // n_heads
        self.transformer = nn.ModuleList()
        for _ in range(num_layers):
            self.transformer.append(
                BasicTransformerBlock(
                    in_dim,
                    num_attention_heads=n_heads,
                    attention_head_dim=head_dim,
                    cross_attention_dim=in_dim,
                )
            )

        # ConvTranspose3d(kernel=4, stride=2, padding=1) doubles every spatial
        # extent; the following 3x3x3 convolution keeps it and maps to 128
        # channels.
        self.latent_refine = nn.Sequential(
            nn.ConvTranspose3d(in_dim, 256, 4, stride=2, padding=1),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(256, 128, 3, padding=1),
            nn.BatchNorm3d(128),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        x: 2D features in shape [b,t,c,h,w]; c is expected to equal the
           in_dim given at construction.

        Returns the refined latent volume in shape
        [b, 128, 2*latent_res, 2*latent_res, 2*latent_res], assuming every
        transformer block preserves the [b, N, c] token shape.

        Raises ValueError if x is not 5-dimensional.
        '''
        if x.dim() != 5:
            raise ValueError(
                f"expected x of shape [b, t, c, h, w], got {tuple(x.shape)}"
            )
        batch_size = x.shape[0]
        device = x.device

        # Flatten all views and pixels into one token sequence: [b, t*h*w, c].
        x = rearrange(x, 'b t c h w -> b (t h w) c')

        # One copy of the learnable latent tokens per sample: [b, N=d*h*w, c].
        latent = rearrange(self.latent_emb, 'd h w c -> (d h w) c')
        latent = latent.unsqueeze(0).repeat(batch_size, 1, 1).to(device)

        # Each block is called as block(latent, x); presumably the second
        # argument is the cross-attention context -- confirm against
        # BasicTransformerBlock.
        for block in self.transformer:
            latent = block(latent, x)

        # Fold the tokens back into a channels-first cubic volume.
        res = self.latent_res
        latent = rearrange(latent, 'b (d h w) c -> b c d h w', d=res, h=res, w=res)
        return self.latent_refine(latent)


if __name__ == '__main__':
    # Smoke test: push a random batch of 2 samples x 8 views of 8x8 feature
    # maps through a 4-layer lifting module and print the output shape.
    feat_dim = 768
    feats = torch.rand(2, 8, feat_dim, 8, 8)
    lifter = lifting(latent_res=16, in_dim=feat_dim, num_layers=4)
    volume = lifter(feats)
    print(volume.shape)
