import einops
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

def pair(t):
    """Return *t* as-is when it is a tuple, otherwise duplicate it into ``(t, t)``.

    Lets callers give one value for a square size or an explicit tuple for a
    rectangular one. Only ``tuple`` instances pass through untouched; any
    other value (including a list) is duplicated.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

class ConvPatchEmbedding(nn.Module):
    """Patch embedding wrapped in residual 3-D convolutions, with an inverse.

    ``forward`` maps a channels-last tensor ``(b, T, H, W, channels)`` to patch
    tokens ``(b, T, H // patch_height, W // patch_width, target_dim)``:

    1. a linear projection lifts ``channels`` to ``conv_channels``;
    2. ``n_in_conv_layers`` residual ``Conv3d`` layers run over ``(T, H, W)``;
    3. each of the ``T`` slices is cut into non-overlapping patches, which are
       normalised, linearly projected to ``target_dim`` and normalised again;
    4. a learned positional embedding (one vector per patch position, shared
       by every batch element and every ``T`` index) is added.

    ``unpatch`` goes the other way: tokens are linearly projected back to
    pixels, passed through ``n_out_conv_layers`` residual ``Conv3d`` layers and
    projected from ``conv_channels`` back to ``channels``.

    Args:
        image_size: Image height/width as an int (square) or ``(height, width)``
            tuple.
        patch_size: Patch height/width as an int (square) or ``(height, width)``
            tuple. Must divide ``image_size`` exactly.
        target_dim: Size of each patch token.
        channels: Size of the last dimension of the input to ``forward`` and of
            the output of ``unpatch``.
        n_in_conv_layers: Number of residual ``Conv3d`` layers applied before
            patching.
        n_out_conv_layers: Number of residual ``Conv3d`` layers applied after
            un-patching.
        conv_channels: Channel width used by every convolution.
        act: Zero-argument callable (typically an ``nn.Module`` class) that
            builds the activation used inside the residual blocks.

    Raises:
        AssertionError: If the image size is not divisible by the patch size.
    """

    def __init__(self,
                 image_size,
                 patch_size,
                 target_dim,
                 channels=2,
                 n_in_conv_layers=2,
                 n_out_conv_layers=2,
                 conv_channels=32,
                 act=nn.SiLU,
                 ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        grid_height = image_height // patch_height
        grid_width = image_width // patch_width
        num_patches = grid_height * grid_width

        self.image_height = image_height
        self.image_width = image_width
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.channels = channels
        self.num_patches = num_patches
        self.target_dim = target_dim
        # Records the number of *input* conv layers only.
        self.n_conv_layers = n_in_conv_layers
        self.conv_channels = conv_channels
        self.act = act()

        # NOTE: submodules are created in a fixed order so that seeded
        # parameter initialisation and state_dict keys stay stable.

        # Input convolutional layers (kernel 3, padding 1 keeps T/H/W unchanged).
        self.input_convs = nn.ModuleList(
            [nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1)
             for _ in range(n_in_conv_layers)]
        )

        # Number of values in one flattened patch.
        patch_dim = conv_channels * patch_height * patch_width

        # Patch embedding: (B, C, H, W) -> (B, num_patches, target_dim).
        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, target_dim),
            nn.LayerNorm(target_dim),
        )

        # Positional embedding (only for patches), broadcast over the batch axis.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, target_dim))

        # Unpatch layers: (B, num_patches, target_dim) -> (B, C, H, W).
        self.patch_unembed = nn.Sequential(
            nn.Linear(target_dim, patch_dim),
            Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
                      h=grid_height,
                      w=grid_width,
                      p1=patch_height,
                      p2=patch_width)
        )

        # Output convolutional layers (same shape-preserving configuration).
        self.output_convs = nn.ModuleList(
            [nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1)
             for _ in range(n_out_conv_layers)]
        )

        # Per-position channel projections applied on the channels-last layout.
        self.in_proj = nn.Linear(channels, conv_channels)
        self.last_proj = nn.Linear(conv_channels, channels)

    def _apply_residual_convs(self, x, convs):
        """Apply each conv in *convs* as ``x = x + act(conv(x))`` and return the result."""
        for conv in convs:
            x = x + self.act(conv(x))
        return x

    def forward(self, x):
        """Embed ``x`` into patch tokens.

        Args:
            x: Tensor of shape ``(b, T, image_height, image_width, channels)``.

        Returns:
            Tensor of shape
            ``(b, T, image_height // patch_height, image_width // patch_width, target_dim)``.
        """
        # The 5-way unpacking doubles as a rank check on the input.
        _, T, _, _, _ = x.shape

        # Lift channels on the channels-last layout, then move them to the
        # position Conv3d expects.
        x = self.in_proj(x)
        x = einops.rearrange(x, 'b T h w c -> b c T h w')

        # Apply input convolutions
        x = self._apply_residual_convs(x, self.input_convs)

        # Patch embedding: fold T into the batch axis so each slice is patched
        # independently.
        x = einops.rearrange(x, 'b c T h w -> (b T) c h w')
        x = self.patch_embedding(x)  # (b T) (h w) d
        # Out-of-place add: avoids mutating an autograd intermediate in place.
        x = x + self.pos_embedding
        return einops.rearrange(
            x,
            '(b T) (h w) d -> b T h w d',
            T=T,
            h=self.image_height // self.patch_height,
        )

    def unpatch(self, x):
        """Map patch tokens back to a channels-last image tensor.

        Args:
            x: Tensor of shape
                ``(b, T, image_height // patch_height, image_width // patch_width, target_dim)``.

        Returns:
            Tensor of shape ``(b, T, image_height, image_width, channels)``.
        """
        # The 5-way unpacking doubles as a rank check on the input.
        _, T, _, _, _ = x.shape

        # Unpatch: flatten the patch grid and fold T into the batch axis.
        x = einops.rearrange(x, 'b T h w d -> (b T) (h w) d')
        x = self.patch_unembed(x)  # (b T) c H W

        # Restore the T axis directly in the layout Conv3d expects.
        x = einops.rearrange(x, '(b T) c h w -> b c T h w', T=T)

        # Apply output convolutions
        x = self._apply_residual_convs(x, self.output_convs)

        # Back to channels-last so the final linear layer acts on channels.
        x = einops.rearrange(x, 'b c T h w -> b T h w c')
        return self.last_proj(x)
