import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class PatchEmbed2D(nn.Module):
    """
    Embed a 2D input as a sequence of patch vectors using a strided convolution.

    The size of the patch grid (``H_out`` x ``W_out``) and the resulting
    ``num_patches`` are derived once, at construction time, from the declared
    input size together with ``kernel_size`` and ``stride``.

    Args:
        input_channels: Number of channels of the incoming tensor.
        in_height: Declared input height, used only to pre-compute the grid.
        in_width: Declared input width, used only to pre-compute the grid.
        kernel_size: ``(height, width)`` of each patch / convolution kernel.
        stride: ``(vertical, horizontal)`` step between neighbouring patches.
        embed_dim: Number of output channels, i.e. the embedding size per patch.
        flatten: When true, ``forward`` returns a token sequence instead of a
            feature map.
    """

    def __init__(
        self,
        input_channels=1,
        in_height=32,
        in_width=160,
        kernel_size=(1, 50),
        stride=(1, 50),
        embed_dim=128,
        flatten=True
    ):
        super().__init__()

        # Keep the configuration around for callers that want to inspect it.
        self.in_height, self.in_width = in_height, in_width
        self.kernel_size, self.stride = kernel_size, stride
        self.embed_dim = embed_dim
        self.flatten = flatten

        # One convolution step per patch; the layer applies no padding.
        self.proj = nn.Conv2d(
            input_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            bias=True,
        )

        # Grid size of an unpadded convolution:
        # floor((size - kernel) / stride) + 1 along each axis.
        (k_h, k_w), (s_h, s_w) = kernel_size, stride
        self.H_out = (in_height - k_h) // s_h + 1
        self.W_out = (in_width - k_w) // s_w + 1
        self.num_patches = self.H_out * self.W_out

    def init_patch_embed(self):
        """Re-initialise the projection weight in place with Xavier-uniform values.

        The 4D kernel is viewed as a 2D matrix ``[embed_dim, rest]`` so that
        the initialiser derives its fan-in/fan-out from that flattened shape.
        The view shares storage with the weight, so the weight itself changes.
        """
        weight = self.proj.weight.data
        as_matrix = weight.view([weight.shape[0], -1])
        nn.init.xavier_uniform_(as_matrix)

    def forward(self, x):
        """
        Args:
            x: [B, input_channels, in_height, in_width]
        Returns:
            [B, num_patches, embed_dim] when ``flatten`` is true,
            otherwise the raw feature map [B, embed_dim, H_out, W_out].
        """
        feature_map = self.proj(x)  # [B, embed_dim, H_out, W_out]
        if not self.flatten:
            return feature_map
        # Collapse the spatial grid into one axis, then move it ahead of the
        # channel axis: [B, embed_dim, H_out*W_out] -> [B, num_patches, embed_dim]
        tokens = feature_map.flatten(start_dim=2)
        return tokens.permute(0, 2, 1)


def patchify_2d(x, patch_size=(1, 50), padding_mode='constant'):
    """
    Split a [B, C, H, W] tensor into non-overlapping patches.

    If H or W is not a multiple of the patch size, the input is padded at the
    bottom/right edge (using ``padding_mode``) up to the next multiple first.

    Args:
        x: Tensor of shape [B, C, H, W].
        patch_size: ``(patch_height, patch_width)``.
        padding_mode: Forwarded as ``mode`` to ``F.pad`` when padding is needed.
    Returns:
        Tensor of shape [B, nH*nW, C*ph*pw]; patches are ordered row by row
        and each patch is laid out as (channel, row, column).
    """
    batch, channels, height, width = x.shape
    patch_h, patch_w = patch_size

    # Distance to the next multiple of the patch size (0 when already aligned).
    extra_h = -height % patch_h
    extra_w = -width % patch_w

    if max(extra_h, extra_w) > 0:
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        x = F.pad(x, (0, extra_w, 0, extra_h), mode=padding_mode)
        height += extra_h
        width += extra_w

    # Sanity check: after padding both sides must be exact multiples.
    assert height % patch_h == 0, "Height is not divisible by patch height"
    assert width % patch_w == 0, "Width is not divisible by patch width"

    rows = height // patch_h  # patches along the vertical axis
    cols = width // patch_w   # patches along the horizontal axis

    # [B, C, H, W] -> [B, C, rows, ph, cols, pw]
    grid = x.reshape(batch, channels, rows, patch_h, cols, patch_w)
    # Bring the two grid axes forward -> [B, rows, cols, C, ph, pw]
    patches = grid.permute(0, 2, 4, 1, 3, 5)
    # Fuse the grid axes and the per-patch axes -> [B, rows*cols, C*ph*pw]
    return patches.reshape(batch, rows * cols, channels * patch_h * patch_w)

def unpatchify_2d(x, patch_size=(1, 50), out_channels=1, out_height=64, out_width=1000):
    """
    The inverse of patchify_2d:
    Transforms [B, num_patches, C*ph*pw] back to [B, C, H, W].

    Args:
        x: Tensor of shape [B, num_patches, C*ph*pw], patches ordered row by
            row with each patch laid out as (channel, row, column).
        patch_size: ``(patch_height, patch_width)``.
        out_channels: Channel count C. ``None`` or the default ``1`` is
            treated as "unspecified" when it does not match the patch
            dimension; C is then inferred as ``patch_dim // (ph * pw)``.
        out_height: Target height, or ``None`` to infer the patch grid.
        out_width: Target width, or ``None`` to infer the patch grid.
    Returns:
        Tensor of shape [B, C, H, W]. When the patch count corresponds to an
        input that was padded up to a multiple of the patch size, the padding
        is cropped so that H == out_height and W == out_width. Otherwise
        H == nH*ph and W == nW*pw.
    Raises:
        ValueError: If the patch size is not positive, the channel count
            cannot be reconciled with the patch dimension, or the number of
            patches does not match the requested output size.
    """
    B, N, patch_dim = x.shape
    ph, pw = patch_size
    if ph <= 0 or pw <= 0:
        raise ValueError(f"patch_size entries must be positive, got {patch_size}")
    patch_area = ph * pw

    # Resolve the channel count, inferring it when it was left unspecified.
    if out_channels is None or (out_channels == 1 and patch_dim != patch_area):
        C, remainder = divmod(patch_dim, patch_area)
        if remainder != 0 or C == 0:
            raise ValueError("Cannot infer channel count from patch dimension")
    else:
        C = out_channels
        if C * patch_area != patch_dim:
            raise ValueError(
                f"patch_dim mismatch, expected {C * patch_area} "
                f"(out_channels * ph * pw), got {patch_dim}"
            )

    # Resolve the patch grid (nH x nW).
    crop_to_output = False
    if out_height is None or out_width is None:
        # No full target size: pick the most square grid whose row count
        # is floor(sqrt(N)); this only works when that value divides N.
        nH = int(math.sqrt(N))
        if nH == 0 or N % nH != 0:
            raise ValueError("Cannot infer grid dimensions from patch count")
        nW = N // nH
    else:
        nH, nW = out_height // ph, out_width // pw
        if nH * nW != N:
            # patchify_2d pads inputs up to the next multiple of the patch
            # size, so also accept the rounded-up grid and crop afterwards.
            padded_nH = -(-out_height // ph)
            padded_nW = -(-out_width // pw)
            if padded_nH * padded_nW != N:
                raise ValueError(
                    f"num_patches mismatch, expected {nH*nW} "
                    f"(or {padded_nH * padded_nW} for a padded input), got {N}"
                )
            nH, nW = padded_nH, padded_nW
            crop_to_output = True

    # 1) split each patch vector => [B, nH, nW, C, ph, pw]
    x = x.reshape(B, nH, nW, C, ph, pw)
    # 2) interleave grid and in-patch axes => [B, C, nH, ph, nW, pw]
    x = x.permute(0, 3, 1, 4, 2, 5)
    # 3) merge them into image axes => [B, C, nH*ph, nW*pw]
    x = x.reshape(B, C, nH * ph, nW * pw)
    if crop_to_output:
        # Drop the bottom/right padding that patchify_2d added.
        x = x[:, :, :out_height, :out_width]
    return x
