import einops
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

def pair(t):
    """Normalise a size argument to a 2-tuple.

    A tuple is passed through untouched; any other value ``t`` is duplicated
    into ``(t, t)``.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

class ConvPatchEmbedding(nn.Module):
    """Embed channel-last frame sequences as patch tokens, and map tokens back.

    ``forward`` takes frames of shape ``(b, T, H, W, C)``, cuts every frame into
    non-overlapping ``patch_height x patch_width`` patches, layer-normalises and
    linearly projects each flattened patch to ``target_dim``, and adds a learned
    positional embedding. It returns
    ``(b, T, H // patch_height, W // patch_width, target_dim)``.

    ``unpatch`` does the reverse layout change. It projects tokens back to pixel
    patches, reassembles ``(b, T, H, W, channels)`` frames and applies the
    output convolutions.

    Args:
        image_size: Frame size ``H`` (square) or ``(H, W)``.
        patch_size: Patch size (square) or ``(patch_height, patch_width)``.
            Must divide ``image_size`` exactly.
        target_dim: Token embedding dimension.
        channels: Number of image channels ``C`` on input and output.
        n_conv_layers: Number of 3x3 convs in the input stem and in the output
            stack. NOTE: the input stem is unfinished, so ``forward`` raises
            ``NotImplementedError`` whenever this is > 0.
        first_layer_out_channels: Output channels of the first stem conv. Each
            later stem conv doubles the channel count.

    Raises:
        ValueError: If the image size is not divisible by the patch size.
    """

    def __init__(self,
                 image_size,
                 patch_size,
                 target_dim,
                 channels=2,
                 n_conv_layers=1,
                 first_layer_out_channels=16,
                 ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if image_height % patch_height != 0 or image_width % patch_width != 0:
            raise ValueError('Image dimensions must be divisible by the patch size.')
        grid_height = image_height // patch_height
        grid_width = image_width // patch_width
        num_patches = grid_height * grid_width
        # Flattened size of one pixel patch on the output (``channels``-wide) side.
        patch_dim = channels * patch_height * patch_width

        self.image_height = image_height
        self.image_width = image_width
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.channels = channels
        self.num_patches = num_patches
        self.target_dim = target_dim
        self.n_conv_layers = n_conv_layers
        self.grid_height = grid_height  # patches along the height axis
        self.grid_width = grid_width    # patches along the width axis

        # Input conv stem: 3x3, padding 1 (spatial size preserved). The channel
        # width doubles after each layer.
        # NOTE: submodules are created in the original order so parameter
        # order and seeded initialisation stay unchanged.
        self.input_convs = nn.ModuleList()
        in_channels = channels
        out_channels = first_layer_out_channels
        for _ in range(n_conv_layers):
            self.input_convs.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            in_channels = out_channels
            out_channels *= 2

        # After the loop, ``in_channels`` is the stem's output width (or
        # ``channels`` when there is no stem), i.e. what the patchifier sees.
        in_patch_dim = in_channels * patch_height * patch_width

        # Patch embedding: flatten patches, normalise, project, normalise.
        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(in_patch_dim),
            nn.Linear(in_patch_dim, target_dim),
            nn.LayerNorm(target_dim),
        )

        # Learned positional embedding, one vector per patch (no class token).
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, target_dim))

        # Output convolutions operate on the reconstructed ``channels``-wide frames.
        self.output_convs = nn.ModuleList()
        for _ in range(n_conv_layers):
            self.output_convs.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))

        # Unpatch: project tokens back to pixel patches and tile them into frames.
        self.patch_unembed = nn.Sequential(
            nn.Linear(target_dim, patch_dim),
            Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
                      h=grid_height,
                      w=grid_width,
                      p1=patch_height,
                      p2=patch_width)
        )

    def forward(self, x):
        """Embed frames ``x`` of shape ``(b, T, H, W, C)`` into patch tokens.

        Returns:
            Tensor of shape ``(b, T, H // patch_height, W // patch_width, target_dim)``.

        Raises:
            ValueError: If ``(H, W, C)`` does not match the configured image size
                and channel count.
            NotImplementedError: If the module was built with ``n_conv_layers > 0``,
                because the input conv stem is incomplete.
        """
        _, T, h, w, c = x.shape
        # Validate up front. A transposed frame with the same patch count would
        # otherwise broadcast against ``pos_embedding`` and be silently scrambled.
        if (h, w, c) != (self.image_height, self.image_width, self.channels):
            raise ValueError(
                f'expected frames of shape ({self.image_height}, {self.image_width}, '
                f'{self.channels}), got ({h}, {w}, {c})'
            )
        if len(self.input_convs) > 0:
            # TODO: each stem layer still needs an activation and a residual
            # path. Fail loudly here instead of with ``assert``, which is
            # stripped under ``python -O`` and would silently run the bare convs.
            raise NotImplementedError(
                "misses residual connection, activation function and conv_residual connection"
            )

        x = einops.rearrange(x, 'b T h w c -> (b T) c h w')
        x = self.patch_embedding(x)  # -> ((b T), num_patches, target_dim)
        # Out-of-place add; the positional embedding broadcasts over the (b T) axis.
        x = x + self.pos_embedding
        return einops.rearrange(x, '(b T) (h w) d -> b T h w d', T=T, h=self.grid_height)

    def unpatch(self, x):
        """Map patch tokens back to frames.

        Args:
            x: Tokens of shape ``(b, T, H // patch_height, W // patch_width, target_dim)``.

        Returns:
            Frames of shape ``(b, T, H, W, channels)``, after the output convolutions.

        Raises:
            ValueError: If the token grid or embedding size does not match the
                configuration.
        """
        _, T, h, w, d = x.shape
        # A transposed grid with the same patch count would otherwise be
        # silently re-tiled in the wrong order by ``patch_unembed``.
        if (h, w, d) != (self.grid_height, self.grid_width, self.target_dim):
            raise ValueError(
                f'expected tokens of shape ({self.grid_height}, {self.grid_width}, '
                f'{self.target_dim}), got ({h}, {w}, {d})'
            )
        x = einops.rearrange(x, 'b T h w d -> (b T) (h w) d')
        x = self.patch_unembed(x)  # -> ((b T), channels, H, W)
        # Output convs are applied back to back, with no nonlinearity in between.
        for conv in self.output_convs:
            x = conv(x)
        return einops.rearrange(x, '(b T) c h w -> b T h w c', T=T)
