import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from src.models.kappa_overrides.dit_conditioning import Dit

def modulate_scale_shift(x, scale, shift):
    """Apply an affine modulation: ``x * (1 + scale) + shift``.

    ``scale`` and ``shift`` must already be broadcastable against ``x``;
    no dimensions are inserted here.
    """
    factor = 1 + scale
    scaled = x * factor
    return scaled + shift

def modulate_gate(x, gate):
    """Scale ``x`` element-wise by ``gate`` (computed as ``gate * x``).

    ``gate`` must already be broadcastable against ``x``; no dimensions
    are inserted here.
    """
    gated = gate * x
    return gated

class TimeMixingBlock(nn.Module):
    """Residual, condition-modulated block built from two 3D convolutions.

    The input is layer-normalised over its last (channel) axis, scaled and
    shifted with parameters produced from ``cond``, passed through
    ``conv1 -> act -> conv2``, gated, and added back to the input.

    Args:
        in_channels: Channel count of the input (last axis of ``x``).
        out_channels: Channel count produced by the convolutions. Must equal
            ``in_channels`` because the block output is added to its input.
        cond_dim: Feature size of the conditioning input, forwarded to ``Dit``.
        kernel_size: Forwarded to both ``nn.Conv3d`` layers.
        stride: Forwarded to both ``nn.Conv3d`` layers.
        padding: Forwarded to both ``nn.Conv3d`` layers. The residual sum
            only works when kernel/stride/padding preserve the (T, h, w)
            extent, as the defaults (3, 1, 1) do.
        act: Name of a function in ``torch.nn.functional``. If no such
            attribute exists, ``nn.GELU()`` is used silently instead.
        init_weights: Init scheme forwarded to ``Dit``. It is not applied to
            the convolutions, which keep PyTorch's default initialisation.
        init_modulation_zero: If True, ``Dit`` is built with
            ``init_weights="zero"`` instead of ``init_weights``.
        init_gate_zero: Forwarded to ``Dit``.

    Raises:
        ValueError: If ``in_channels != out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cond_dim: int,
        kernel_size: tuple = (3, 3, 3),
        stride: tuple = (1, 1, 1),
        padding: tuple = (1, 1, 1),
        act: str = "gelu",
        init_weights="xavier_uniform",
        init_modulation_zero=False,
        init_gate_zero=False,
    ):
        super().__init__()
        # Validate before building any layers. An ``assert`` would be
        # stripped under ``python -O``, and the mismatch would then only
        # surface later as an opaque shape error in ``forward``.
        if in_channels != out_channels:
            raise ValueError(
                "TimeMixingBlock requires in_channels == out_channels for its "
                f"residual connection, got {in_channels} and {out_channels}"
            )
        self.conv1 = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.conv2 = nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # Unknown activation names fall back to GELU without any warning.
        self.act = getattr(F, act) if hasattr(F, act) else nn.GELU()
        # No learnable affine here: scale/shift are applied afterwards from
        # the modulation output in ``forward``.
        self.norm = nn.LayerNorm(in_channels, elementwise_affine=False)

        # Modulation: three outputs, unpacked in ``forward`` as
        # (scale, shift, gate); index 2 is declared to Dit as the gate.
        self.modulation = Dit(
            cond_dim=cond_dim,
            out_dim=out_channels,
            init_weights="zero" if init_modulation_zero else init_weights,
            num_outputs=3,
            gate_indices=[2],
            init_gate_zero=init_gate_zero,
        )

    def forward(self, x: torch.Tensor, cond) -> torch.Tensor:
        """Apply the modulated conv update with a residual connection.

        Args:
            x: Channels-last input ``(b, T, h, w, d)`` with ``d == in_channels``.
                The 5-D layout is fixed by the rearrange patterns below.
            cond: Conditioning input consumed by ``self.modulation``.

        Returns:
            ``x`` plus the gated convolutional update.
        """
        # NOTE(review): assumes Dit returns (scale, shift, gate) already shaped
        # to broadcast against the 5-D channels-last ``x`` — confirm against Dit.
        scale, shift, gate = self.modulation(cond)
        residual = x
        y = modulate_scale_shift(self.norm(x), scale=scale, shift=shift)
        # nn.Conv3d expects channels-first input (N, C, D, H, W).
        y = einops.rearrange(y, 'b T h w d -> b d T h w')
        y = self.conv1(y)
        y = self.act(y)
        y = self.conv2(y)
        y = einops.rearrange(y, 'b d T h w -> b T h w d')
        y = modulate_gate(y, gate=gate)
        return residual + y

