import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

class TimeMixingBlock(nn.Module):
    """Pre-norm residual block made of two 3-D convolutions.

    The input is layer-normalised over its channel axis (no learnable
    affine parameters), passed through ``conv1 -> act -> conv2`` and added
    back onto the skip path.

    Both convolutions share the same ``kernel_size``, ``stride`` and
    ``padding``. The residual addition requires the convolutional branch
    to keep the ``(T, h, w)`` extent of the input, which holds for the
    defaults (kernel 3, stride 1, padding 1). Other settings may fail at
    the addition in ``forward``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block. When it
            differs from ``in_channels`` the skip path is projected with a
            1x1x1 convolution so the residual addition stays well-defined.
        kernel_size: Kernel size used by both convolutions.
        stride: Stride used by both convolutions.
        padding: Padding used by both convolutions.
        act: Name of a function in ``torch.nn.functional`` to use as the
            activation. Names that do not exist there fall back to
            ``nn.GELU()``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: tuple = (3, 3, 3),
        stride: tuple = (1, 1, 1),
        padding: tuple = (1, 1, 1),
        act: str = "gelu",
    ):
        super().__init__()
        self.conv1 = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.conv2 = nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # Unknown activation names intentionally fall back to GELU instead
        # of raising, so existing callers relying on that keep working.
        act_fn = getattr(F, act, None)
        self.act = act_fn if act_fn is not None else nn.GELU()
        self.norm = nn.LayerNorm(in_channels, elementwise_affine=False)
        # The skip path can only be added to the conv branch when channel
        # counts agree. Project it when they differ; otherwise keep a plain
        # identity skip so equal-channel blocks gain no extra parameters.
        # Created last so conv1/conv2 initialisation is unaffected.
        self.shortcut = (
            nn.Conv3d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a 5-D tensor laid out as ``(b, d, T, h, w)``.

        Args:
            x: Input with ``in_channels`` entries along axis 1.

        Returns:
            Skip path plus convolutional branch, with ``out_channels``
            entries along axis 1.
        """
        residual = x if self.shortcut is None else self.shortcut(x)
        # LayerNorm normalises the trailing axis, so move channels last,
        # normalise, then restore the channels-first layout Conv3d expects.
        # permute (rather than movedim) keeps the strict 5-D requirement.
        branch = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
        branch = self.conv1(branch)
        branch = self.act(branch)
        branch = self.conv2(branch)
        return residual + branch

