import torch
import torch.nn as nn
from .mlp import MLP

class BandPassFilter(nn.Module):
    """Band-pass filter that keeps the frequency components in a given range.

    The input is tapered with a Hann window along the time axis and transformed
    with a real FFT. Every frequency bin outside ``[min_freq, max_freq]`` is
    zeroed, and the result is transformed back to the time domain.

    Args:
        min_freq: Lower edge of the pass band in Hz (inclusive).
        max_freq: Upper edge of the pass band in Hz (inclusive). A value above
            the Nyquist frequency (``sample_rate / 2``) means the band simply
            extends up to Nyquist.
        sample_rate: Sampling rate of the time axis in Hz.

    Attributes:
        min: ``min_freq`` normalized so that the Nyquist frequency is 1.0.
        max: ``max_freq`` normalized so that the Nyquist frequency is 1.0.
    """

    def __init__(self, min_freq=5, max_freq=20, sample_rate=30):
        super().__init__()
        nyquist = sample_rate / 2
        self.min = min_freq / nyquist
        self.max = max_freq / nyquist

    def forward(self, x):
        """Band-pass ``x`` along dim 1.

        Args:
            x: Real tensor of shape ``(B, T, D)``. Each of the ``D`` channels is
                filtered independently along the ``T`` axis.

        Returns:
            Real tensor with the same shape as ``x``. Because the Hann window
            is applied before filtering and never undone, samples near both
            ends of the sequence are attenuated.
        """
        _, T, _ = x.shape
        # The window is built per call so that any sequence length T works.
        window = torch.hann_window(T, device=x.device).view(1, T, 1)
        x_fft = torch.fft.rfft(x * window, dim=1)

        # rfftfreq(T) (d=1.0) yields cycles/sample in [0, 0.5], i.e. Nyquist is
        # 0.5. Scale by 2 so that Nyquist == 1.0, matching the normalization of
        # self.min / self.max. Without this, the pass band is shifted up by a
        # factor of 2.
        freq = torch.fft.rfftfreq(T, device=x.device) * 2.0
        mask = ((freq >= self.min) & (freq <= self.max)).view(1, -1, 1)

        x_filtered = x_fft * mask.float()
        # n=T restores the original length, including when T is odd.
        return torch.fft.irfft(x_filtered, n=T, dim=1)

class MicroExpressionMLP(nn.Module):
    """Predict small per-point offsets driven by band-passed expression input.

    Pipeline, as implemented in ``forward``:

    1. ``self.filter`` band-passes ``exp_coeff_seq`` along dim 1.
    2. ``self.mlp`` followed by ``self.exp_mlp`` turn it into a feature tensor.
    3. That feature is tiled ``deform.shape[1]`` times along dim 1 and
       concatenated with ``deform`` along dim 2.
    4. The result is permuted (dims 1 and 2 swapped), passed through
       ``self.deform_mlp``, permuted back, and scaled by 0.03.

    NOTE(review): the permutes around ``deform_mlp`` suggest ``MLP`` consumes
    channel-first ``(B, C, N)`` input — confirm against ``.mlp.MLP``.

    Args:
        input_dim: First entry of ``self.mlp``'s ``dims``.
        output_dim: Last entry of ``self.mlp``'s ``dims``. NOTE(review):
            ``exp_mlp`` hard-codes 64 as its first dim, so values other than
            the default 64 presumably break the pipeline — confirm.
    """

    def __init__(self, input_dim=3, output_dim=64):
        super().__init__()
        # Submodules are registered in the same order as before, keeping the
        # parameter-initialization sequence and state_dict layout identical.
        self.filter = BandPassFilter()
        self.mlp = MLP(dims=[input_dim, 256, 512, output_dim], last_op=nn.Tanh())
        self.exp_mlp = MLP(dims=[64, 256, 256, 1], last_op=nn.Tanh())
        # Presumably 67 = 64 tiled feature channels + 3 channels from
        # ``deform`` (the concat in forward) — TODO confirm.
        self.deform_mlp = MLP(dims=[67, 256, 256, 3], last_op=nn.Tanh())

    def forward(self, exp_coeff_seq, deform):
        """Compute scaled offsets for every entry along ``deform``'s dim 1.

        Args:
            exp_coeff_seq: Rank-3 tensor; the rank is enforced by the shape
                unpack below. It is filtered along dim 1.
            deform: Tensor whose dim 1 sets how many times the expression
                feature is tiled; it is concatenated with that tiled feature
                along dim 2.

        Returns:
            ``deform_mlp`` output with dims 1 and 2 swapped back, times 0.03.
        """
        # Unpack only to enforce a rank-3 input (ValueError otherwise).
        _batch, _frames, _coeffs = exp_coeff_seq.shape

        expr_feat = self.exp_mlp(self.mlp(self.filter(exp_coeff_seq)))

        num_points = deform.shape[1]
        tiled = expr_feat.repeat(1, num_points, 1)
        fused = torch.cat([tiled, deform], dim=2).permute(0, 2, 1)

        offsets = self.deform_mlp(fused).permute(0, 2, 1)
        # Fixed output scale keeps the predicted offsets small.
        return offsets * 0.03