import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt
import numpy as np

class FrequencyFilter(nn.Module):
    """
    Fixed spectral filter that splits a (B, L, C) tensor into two additive parts.

    The coefficients listed in ``mask_spectrum`` are zeroed in the transform
    taken along dim 1. The result is transformed back to give ``x_sto``, and
    the remainder ``x_det = x - x_sto`` is returned alongside it.

    Args:
        mask_spectrum: indices of the coefficients to remove. For ``"fft"`` they
            index the ``rfft`` bins along dim 1 (``L // 2 + 1`` of them). For a
            wavelet they index the concatenation ``[cA, cD]`` of the
            single-level DWT coefficients.
        freq_type: ``"fft"`` or one of the PyWavelets names in ``_WAVELETS``.
    """

    _WAVELETS = ('db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3')

    def __init__(self, mask_spectrum, freq_type):
        super().__init__()
        self.mask_spectrum = mask_spectrum
        self.freq_type = freq_type

    def forward(self, x):
        """
        Split ``x`` into ``(x_sto, x_det)``, both shaped like ``x``.

        Args:
            x: rank-3 tensor ``(B, L, C)``; filtering is applied along dim 1.

        Returns:
            ``(x_sto, x_det)`` where ``x_sto`` is ``x`` with the selected
            coefficients removed and ``x_det = x - x_sto``.

        Raises:
            ValueError: if ``freq_type`` is not supported.
        """
        B, L, C = x.shape
        if self.freq_type == "fft":
            xf = torch.fft.rfft(x, dim=1)
            mask = torch.ones_like(xf)
            mask[:, self.mask_spectrum, :] = 0
            # n=L keeps odd-length inputs at their true length; without it
            # irfft returns 2 * (L // 2) samples and drops one.
            x_sto = torch.fft.irfft(xf * mask, n=L, dim=1)
        elif self.freq_type in self._WAVELETS:
            x_sto = self._wavelet_filter(x, self.freq_type, self.mask_spectrum)
        else:
            raise ValueError(f"unsupported freq_type: {self.freq_type!r}")

        x_sto = self._match_length(x_sto, L)
        x_det = x - x_sto
        return x_sto, x_det

    @staticmethod
    def _wavelet_filter(x, wavelet_name, indices):
        """Zero ``indices`` of the [cA, cD] DWT coefficients of ``x`` (dim 1) and reconstruct."""
        wavelet = pywt.Wavelet(wavelet_name)
        # PyWavelets operates on NumPy arrays, so this path carries no gradient.
        data = x.detach().permute(0, 2, 1).cpu().numpy()  # (B, C, L)
        cA, cD = pywt.dwt(data, wavelet)  # transform along the last axis
        n_approx = cA.shape[-1]
        coeffs = np.concatenate((cA, cD), axis=2)
        mask = np.ones_like(coeffs)
        mask[:, :, indices] = 0
        coeffs = coeffs * mask
        rec = pywt.idwt(coeffs[:, :, :n_approx], coeffs[:, :, n_approx:], wavelet)
        return torch.from_numpy(rec).to(device=x.device, dtype=x.dtype).permute(0, 2, 1)

    @staticmethod
    def _match_length(x_sto, length):
        """Truncate, or pad with the overall mean, along dim 1 to exactly ``length`` samples."""
        current = x_sto.shape[1]
        if current > length:
            return x_sto[:, :length, :]
        if current < length:
            fill = x_sto.mean().item()
            return F.pad(x_sto, (0, 0, 0, length - current), mode='constant', value=fill)
        return x_sto
    

class LearnableFrequencyFilter(nn.Module):
    """
    Spectral filter whose set of removed coefficients is learned.

    ``self.mask`` holds one logit per rfft bin of a length-``len_seq`` input.
    Coefficients whose ``sigmoid(logit) >= 0.5`` are zeroed. ``x_sto`` is the
    reconstruction without them, and ``x_det = x - x_sto``.

    The hard threshold is not differentiable, so ``self.mask`` receives
    gradient only through the returned ``mask_reg``. That term is the MSE
    between the sorted mask probabilities and a 0/1 target whose top
    ``int(alpha * n_bins)`` entries are 1. It therefore pushes about an
    ``alpha`` fraction of the bins towards being removed.

    Args:
        alpha: fraction of bins the regulariser targets for removal.
        len_seq: sequence length the mask is sized for. It presumably should
            equal ``x.shape[1]`` at call time — TODO confirm with callers.
        freq_type: ``"fft"`` or one of the PyWavelets names in ``_WAVELETS``.
    """

    _WAVELETS = ('db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3')

    def __init__(self, alpha, len_seq, freq_type):
        super().__init__()
        n_bins = len_seq // 2 + 1
        self.mask = nn.Parameter(torch.rand(n_bins), requires_grad=True)
        self.freq_type = freq_type
        self.alpha = alpha
        n_on = int(self.alpha * n_bins)
        reg_mask = torch.cat((torch.zeros(n_bins - n_on), torch.ones(n_on)))
        # Non-persistent buffer: it follows .to(device), but stays out of
        # state_dict so existing checkpoints still load.
        self.register_buffer("reg_mask", reg_mask, persistent=False)

    def forward(self, x):
        """
        Split ``x`` using the learned mask.

        Args:
            x: rank-3 tensor ``(B, L, C)``; filtering is applied along dim 1.

        Returns:
            ``(x_sto, x_det, mask_reg)``. ``x_sto`` is shaped like ``x``, with
            the masked coefficients removed. ``x_det = x - x_sto``.
            ``mask_reg`` is a scalar regularisation loss on ``self.mask``.

        Raises:
            ValueError: if ``freq_type`` is not supported.
        """
        B, L, C = x.shape
        mask_probs = torch.sigmoid(self.mask)
        # Hard selection of the bins to remove; no gradient flows through this.
        mask_spectrum = torch.where(mask_probs >= 0.5)[0]

        if self.freq_type == "fft":
            xf = torch.fft.rfft(x, dim=1)
            mask_map = torch.ones_like(xf)
            mask_map[:, mask_spectrum, :] = 0
            # n=L keeps odd-length inputs at their true length.
            x_sto = torch.fft.irfft(xf * mask_map, n=L, dim=1)
        elif self.freq_type in self._WAVELETS:
            x_sto = self._wavelet_filter(x, self.freq_type, mask_spectrum.cpu().numpy())
        else:
            raise ValueError(f"unsupported freq_type: {self.freq_type!r}")

        x_sto = self._match_length(x_sto, L)
        x_det = x - x_sto

        mask_sorted = torch.sort(mask_probs)[0]
        mask_reg = ((self.reg_mask.to(mask_sorted.device) - mask_sorted) ** 2).mean()
        return x_sto, x_det, mask_reg

    @staticmethod
    def _wavelet_filter(x, wavelet_name, indices):
        """Zero ``indices`` of the [cA, cD] DWT coefficients of ``x`` (dim 1) and reconstruct."""
        wavelet = pywt.Wavelet(wavelet_name)
        # PyWavelets operates on NumPy arrays, so this path carries no gradient.
        data = x.detach().permute(0, 2, 1).cpu().numpy()  # (B, C, L)
        cA, cD = pywt.dwt(data, wavelet)  # transform along the last axis
        n_approx = cA.shape[-1]
        coeffs = np.concatenate((cA, cD), axis=2)
        mask = np.ones_like(coeffs)
        mask[:, :, indices] = 0
        coeffs = coeffs * mask
        rec = pywt.idwt(coeffs[:, :, :n_approx], coeffs[:, :, n_approx:], wavelet)
        return torch.from_numpy(rec).to(device=x.device, dtype=x.dtype).permute(0, 2, 1)

    @staticmethod
    def _match_length(x_sto, length):
        """Truncate, or pad with the overall mean, along dim 1 to exactly ``length`` samples."""
        current = x_sto.shape[1]
        if current > length:
            return x_sto[:, :length, :]
        if current < length:
            fill = x_sto.mean().item()
            return F.pad(x_sto, (0, 0, 0, length - current), mode='constant', value=fill)
        return x_sto