import einops
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

def pair(t):
    """Normalise a size argument to a ``(height, width)`` pair.

    Tuples are returned unchanged. Lists are converted to tuples, so sizes
    read from YAML/JSON configs (which yield lists) behave like tuples. Any
    other value ``v`` is duplicated into ``(v, v)``.
    """
    if isinstance(t, tuple):
        return t
    if isinstance(t, list):
        return tuple(t)
    return (t, t)

class PatchEmbedding(nn.Module):
    """Split frames into patches, embed them, and project tokens back to pixels.

    ``forward`` maps a ``(b, T, image_height, image_width, channels)`` tensor
    to patch tokens of shape ``(b, T, grid_height, grid_width, dim)``, where
    ``grid_height = image_height // patch_height`` and
    ``grid_width = image_width // patch_width``. ``unpatch`` maps such a token
    grid back to ``(b, T, image_height, image_width, channels)``.

    Args:
        image_size: Frame size, an int (square) or ``(height, width)``.
        patch_size: Patch size, an int (square) or ``(height, width)``.
        dim: Width of each patch token.
        channels: Number of channels on the last axis of the input.

    Raises:
        ValueError: If the image size is not divisible by the patch size.
    """

    def __init__(self,
                 image_size,
                 patch_size,
                 dim,
                 channels=2):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        # Raise instead of assert so the check survives `python -O`.
        if image_height % patch_height != 0 or image_width % patch_width != 0:
            raise ValueError('Image dimensions must be divisible by the patch size.')
        grid_height = image_height // patch_height
        grid_width = image_width // patch_width
        num_patches = grid_height * grid_width
        # One flattened patch holds every channel of every pixel in the patch.
        patch_dim = channels * patch_height * patch_width
        self.image_height = image_height
        self.image_width = image_width
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.channels = channels
        self.num_patches = num_patches
        self.dim = dim

        # (N, c, H, W) -> (N, num_patches, patch_dim) -> (N, num_patches, dim)
        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # (N, num_patches, dim) -> (N, num_patches, patch_dim) -> (N, c, H, W).
        # The pixel ordering mirrors the Rearrange in patch_embedding.
        self.patch_unembed = nn.Sequential(
            nn.Linear(dim, patch_dim),
            Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
                      h=grid_height,
                      w=grid_width,
                      p1=patch_height,
                      p2=patch_width)
        )

        # Trainable positional embedding with one vector per patch. Its leading
        # dim of 1 broadcasts it over the folded batch*time axis.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))

    def forward(self, x):
        """Embed every time step of ``x`` as a grid of patch tokens.

        Args:
            x: Tensor of shape ``(b, T, image_height, image_width, channels)``
                (channels-last).

        Returns:
            Tensor of shape ``(b, T, grid_height, grid_width, dim)``.
        """
        # The unpack also rejects inputs that are not rank 5.
        _, T, _, _, _ = x.shape
        # Fold time into the batch and move channels first for patching.
        x = einops.rearrange(x, 'b T h w d -> (b T) d h w')

        x = self.patch_embedding(x)
        # Out-of-place add: avoids mutating an autograd intermediate in place.
        x = x + self.pos_embedding

        return einops.rearrange(
            x, '(b T) (h w) d -> b T h w d',
            T=T,
            h=self.image_height // self.patch_height,
            w=self.image_width // self.patch_width,
        )

    def unpatch(self, x):
        """Project a grid of patch tokens back to pixel space.

        Args:
            x: Tensor of shape ``(b, T, grid_height, grid_width, dim)``, such
                as the output of ``forward``.

        Returns:
            Tensor of shape ``(b, T, image_height, image_width, channels)``
            (channels-last).
        """
        # The unpack also rejects inputs that are not rank 5.
        _, T, _, _, _ = x.shape
        # Fold time into the batch and flatten the patch grid into a token axis.
        x = einops.rearrange(x, 'b T h w d -> (b T) (h w) d')

        x = self.patch_unembed(x)

        # patch_unembed yields channels-first frames; restore channels-last
        # layout and split time back out of the batch.
        return einops.rearrange(x, '(b T) d h w -> b T h w d', T=T)

