import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
import math

from src.models.blocks.attention1d import BasicTransformerBlock, CrossAttention

class CrossViewEncoder(nn.Module):
    """Mix information between a canonical view and the remaining views.

    The first view along the ``t`` axis is treated as the canonical view.
    Each layer first runs ``CrossAttention`` with the non-canonical tokens as
    the primary input and the canonical tokens as ``encoder_hidden_states``,
    then runs a ``BasicTransformerBlock`` over the tokens of all views.

    Args:
        in_channels: Feature width of the input (the ``c`` axis of ``x``).
        channels: Feature width used by the projection and attention blocks.
        num_layers: Number of (cross-attention, transformer-block) pairs.
        n_heads: Number of attention heads handed to both block types. The
            per-head width is ``channels // n_heads`` (floor division).
    """

    def __init__(self, in_channels, channels, num_layers, n_heads=8) -> None:
        super().__init__()
        # Registration and construction order are kept stable on purpose:
        # they determine ``parameters()`` order and seeded initialisation.
        self.transformers_cross = nn.ModuleList()
        self.transformers_self = nn.ModuleList()
        head_dim = channels // n_heads
        self.proj = nn.Linear(in_channels, channels)
        for _ in range(num_layers):
            self.transformers_cross.append(
                CrossAttention(
                    channels,
                    heads=n_heads,
                    dim_head=head_dim,
                    cross_attention_dim=channels,
                )
            )
            self.transformers_self.append(
                BasicTransformerBlock(
                    channels,
                    num_attention_heads=n_heads,
                    attention_head_dim=head_dim,
                )
            )

    def forward(self, x: torch.Tensor, time_embeddings=None) -> torch.Tensor:
        """Encode ``x`` of shape ``[b, t, c, h, w]``.

        Args:
            x: Input views; index 0 along ``t`` is the canonical view.
            time_embeddings: Accepted for interface compatibility; this
                method does not read it.

        Returns:
            Tensor laid out as ``[b, t, c', h, w]`` with the canonical view
            first, where ``c'`` is the feature width produced by the last
            block (``channels`` if the attention blocks preserve width --
            their implementation is not visible here).
        """
        _, t, _, h, w = x.shape
        tokens_per_view = h * w

        # Flatten each view to a token sequence [b, n, c] and project it.
        x_canonical = self.proj(rearrange(x[:, 0], 'b c h w -> b (h w) c'))
        x = self.proj(rearrange(x[:, 1:], 'b t c h w -> b (t h w) c'))

        for cross_attn, self_attn in zip(self.transformers_cross, self.transformers_self):
            # Non-canonical tokens attend with the canonical tokens as context.
            x = cross_attn(x, encoder_hidden_states=x_canonical)
            # Tokens are ordered (t h w) with t outermost, so prepending the
            # canonical tokens on the token axis gives the same sequence as
            # stacking the views on t and flattening; no 5-D round trip needed.
            tokens = self_attn(torch.cat([x_canonical, x], dim=1))
            # Split back: the first h*w tokens belong to the canonical view.
            x_canonical = tokens[:, :tokens_per_view]
            x = tokens[:, tokens_per_view:]

        x_canonical = rearrange(x_canonical, 'b (t h w) c -> b t c h w', t=1, h=h, w=w)
        x = rearrange(x, 'b (t h w) c -> b t c h w', t=t - 1, h=h, w=w)
        return torch.cat([x_canonical, x], dim=1)

if __name__ == '__main__':
    # Smoke test: build a small encoder and print the output shape.
    dim = 768
    # A small spatial grid keeps the token count (t * h * w) modest; the
    # attention blocks see every token of every view at once.
    x = torch.rand(2, 8, dim, 8, 8)
    # The constructor takes ``in_channels`` and ``channels`` (there is no
    # ``in_dim`` keyword); use the same width for both here.
    model = CrossViewEncoder(in_channels=dim, channels=dim, num_layers=3)
    with torch.no_grad():
        # ``forward`` does not read ``time_embeddings``, so None is enough.
        out = model(x, time_embeddings=None)
    print(out.shape)