import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt
import numpy as np

class FrequencyFilter(nn.Module):
    """
    Split a (B, L, C) series into two additive parts by masking transform
    coefficients along the length axis (dim 1).

    The input is transformed with either a real FFT (``freq_type="fft"``) or a
    single-level discrete wavelet transform (a wavelet name from
    ``_WAVELETS``). The coefficients selected by ``mask_spectrum`` are zeroed,
    and the result is transformed back:

    * ``x_sto`` -- the reconstruction with the selected coefficients removed;
    * ``x_det`` -- ``x - x_sto``, i.e. the part carried by the selected
      coefficients.

    The original description speaks of splitting into time-variant and
    time-invariant terms. Which output plays which role is not spelled out
    by this code.

    Args:
        mask_spectrum: Index into the coefficient axis (e.g. an int, slice or
            list of ints). For ``"fft"`` that axis holds the ``L // 2 + 1``
            rfft bins. For wavelets it is the concatenation ``[cA, cD]`` of
            approximation and detail coefficients.
        freq_type: ``"fft"`` or one of the PyWavelets names in ``_WAVELETS``.
    """

    # Wavelet families accepted by ``forward`` (passed to ``pywt.Wavelet``).
    _WAVELETS = ('db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3')

    def __init__(self, mask_spectrum, freq_type):
        super(FrequencyFilter, self).__init__()
        self.mask_spectrum = mask_spectrum
        self.freq_type = freq_type

    def forward(self, x):
        """
        Decompose ``x`` into ``(x_sto, x_det)``.

        Args:
            x: Tensor of shape (B, L, C); the transform runs over dim 1.

        Returns:
            Tuple ``(x_sto, x_det)``, both shaped like ``x``, where
            ``x_det = x - x_sto``.

        Raises:
            ValueError: If ``freq_type`` is neither ``"fft"`` nor one of
                ``_WAVELETS``.
        """
        _, length, _ = x.shape

        if self.freq_type == "fft":
            xf = torch.fft.rfft(x, dim=1)
            mask = torch.ones_like(xf)
            mask[:, self.mask_spectrum, :] = 0
            # Pass n explicitly. Without it irfft returns 2 * (L // 2)
            # samples, i.e. L - 1 for odd L, and one sample would be lost.
            x_sto = torch.fft.irfft(xf * mask, n=length, dim=1)
        elif self.freq_type in self._WAVELETS:
            wavelet = pywt.Wavelet(self.freq_type)
            # pywt works on NumPy arrays. Detach so this also works when x
            # requires grad; no gradient flows through x_sto on this path.
            data = x.permute(0, 2, 1).detach().cpu().numpy()  # (B, C, L)
            cA, cD = pywt.dwt(data, wavelet)  # along the last axis
            n_approx = cA.shape[-1]
            coeffs = np.concatenate((cA, cD), axis=2)  # (B, C, len(cA)+len(cD))
            mask = np.ones_like(coeffs)
            mask[:, :, self.mask_spectrum] = 0
            coeffs = coeffs * mask
            x_sto = pywt.idwt(coeffs[:, :, :n_approx], coeffs[:, :, n_approx:], wavelet)
            x_sto = (
                torch.from_numpy(x_sto)
                .to(device=x.device, dtype=x.dtype)
                .permute(0, 2, 1)
            )
        else:
            raise ValueError(
                f"Unsupported freq_type {self.freq_type!r}; "
                f"expected 'fft' or one of {self._WAVELETS}"
            )

        # The inverse DWT may return more than L samples (e.g. for odd L):
        # trim the extra. Any shortfall is padded with the mean of x_sto.
        if x_sto.shape[1] > length:
            x_sto = x_sto[:, :length, :]
        elif x_sto.shape[1] < length:
            fill = x_sto.mean().item()
            x_sto = F.pad(
                x_sto, (0, 0, 0, length - x_sto.shape[1]),
                mode='constant', value=fill,
            )

        x_det = x - x_sto
        return x_sto, x_det