"""
Filters for EEG data preprocessing.
Includes highpass, bandpass and notch filters (SciPy and PyTorch implementations).
"""

from scipy.signal import butter, sosfilt, iirnotch, filtfilt, sos2tf
import torch
from torchaudio.functional import lfilter

def highpass(x, fs, fc, order=4):
    """
    Applies a causal Butterworth highpass filter (second-order sections)

    Parameters:
     - x:     numpy.array, input signal, sampled at fs
     - fs:    float, sampling frequency
     - fc:    float, cutoff frequency (same unit as fs)
     - order: int, order of the Butterworth filter

    Returns: numpy.array, filtered signal with the same shape as x
    """
    # scipy expects the cutoff as a fraction of the Nyquist frequency.
    nyquist = 0.5 * fs
    sections = butter(order, fc / nyquist, btype='highpass', output='sos')
    return sosfilt(sections, x)


def bandpass(x, fs, fc_low, fc_high, order=4, ftype='sos'):
    """
    Applies a Butterworth bandpass filter

    Parameters:
     - x:       numpy.array, input signal, sampled at fs
     - fs:      float, sampling frequency
     - fc_low:  float, low cutoff frequency
     - fc_high: float, high cutoff frequency
     - order:   filter order
     - ftype:   str, filtering method:
                'sos'      - single forward pass with second-order sections
                'filtfilt' - forward-backward (zero-phase) pass with the
                             (b, a) transfer-function coefficients

    Returns: numpy.array

    Raises: ValueError if ftype is not 'sos' or 'filtfilt'
    """
    # Fail loudly instead of silently returning None for an unknown method.
    if ftype not in ('sos', 'filtfilt'):
        raise ValueError(
            "ftype must be 'sos' or 'filtfilt', got {!r}".format(ftype)
        )

    # scipy expects the band edges as fractions of the Nyquist frequency.
    nyq = 0.5 * fs
    band = [fc_low / nyq, fc_high / nyq]

    if ftype == 'sos':
        sos = butter(order, band, btype='bandpass', output='sos')
        return sosfilt(sos, x)

    b, a = butter(order, band, btype='bandpass', output='ba', analog=False)
    return filtfilt(b, a, x)

def bandpass_lfilter(x, fs, fc_low, fc_high, order=3, ftype='filtfilt',device=torch.device('cuda')):
    """
    Applies a Butterworth bandpass filter to a tensor using torchaudio's lfilter

    The filtering itself always runs on the CPU (x is moved there first);
    only the filtered result is moved to `device`.
    Note: despite its name, ftype='filtfilt' performs a single forward pass
    with the (b, a) coefficients; no backward (zero-phase) pass is applied.

    Parameters:
     - x:       torch.Tensor, input signal, sampled at fs
     - fs:      float, sampling frequency
     - fc_low:  float, low cutoff frequency
     - fc_high: float, high cutoff frequency
     - order:   filter order
     - ftype:   str, how the coefficients are designed:
                'sos'      - designed as second-order sections, then
                             converted to (b, a) with sos2tf
                'filtfilt' - designed directly in (b, a) form
     - device:  torch.device, device of the returned tensor

    Returns: torch.Tensor, filtered signal on `device`

    Raises: ValueError if ftype is not 'sos' or 'filtfilt'
    """
    # Previously an unknown ftype crashed later on unbound coefficients.
    if ftype not in ('sos', 'filtfilt'):
        raise ValueError(
            "ftype must be 'sos' or 'filtfilt', got {!r}".format(ftype)
        )

    # scipy expects the band edges as fractions of the Nyquist frequency.
    nyq = 0.5 * fs
    band = [fc_low / nyq, fc_high / nyq]

    if ftype == 'sos':
        sos = butter(order, band, btype='bandpass', output='sos')
        b, a = sos2tf(sos)
    else:
        b, a = butter(order, band, btype='bandpass', output='ba', analog=False)

    # Keep the coefficients on the CPU in both branches: the waveform is
    # filtered on the CPU, and lfilter needs all its tensors on one device.
    b = torch.from_numpy(b).float()
    a = torch.from_numpy(a).float()

    filtered = lfilter(waveform=x.cpu(), a_coeffs=a, b_coeffs=b, clamp=False)
    return filtered.to(device)

def bandpass_torch(input, low_f, high_f,fs_eeg, device=torch.device('cuda')):
    """
    Takes the input tensor to the frequency domain (real FFT along the last
    dimension) and zeroes every frequency bin that is not strictly between
    low_f and high_f, then transforms back.

    Parameters:
     - input:                   tensor, real-valued signal sampled at fs_eeg
     - low_f:                   float, frequencies <= low_f are removed
     - high_f:                  float, frequencies >= high_f are removed
     - fs_eeg:                  float, sampling frequency of the input
     - device:                  torch.device on which the masking and the
                                inverse transform are carried out

    Returns: filtered input tensor (on `device`), same length as the input
             along the last dimension

    """
    n_samples = input.shape[-1]

    # rfftfreq is non-negative, so a single grid serves both band edges.
    freqs = torch.fft.rfftfreq(n_samples, d=1 / fs_eeg, device=device)
    passband = (freqs > low_f) & (freqs < high_f)

    spectrum = torch.fft.rfft(input).to(device)

    # n must be given explicitly: without it irfft assumes an even-length
    # signal and returns one sample too few for odd-length inputs.
    return torch.fft.irfft(spectrum * passband, n=n_samples)
    
def bandpass_torch_(input, low_f, high_f, device=torch.device('cuda'), fs_eeg=160):
    """
    Takes the input tensor to the frequency domain (full complex FFT along
    the last dimension) and zeroes every frequency bin whose absolute
    frequency is not strictly between low_f and high_f, then transforms back.

    The default sampling frequency of 160 Hz is the PhysioNet setting this
    function was originally hard-coded for.

    Parameters:
     - input:                   tensor, signal sampled at fs_eeg
     - low_f:                   float, frequencies <= low_f are removed
     - high_f:                  float, frequencies >= high_f are removed
     - device:                  torch.device on which the masking and the
                                inverse transform are carried out
     - fs_eeg:                  float, sampling frequency (default 160)

    Returns: filtered input tensor on `device`. The result of ifft is a
             COMPLEX tensor; take `.real` if a real signal is required.

    """
    # One frequency grid is enough for both band edges; abs() folds the
    # negative half of the two-sided spectrum onto the positive one.
    freqs = torch.abs(torch.fft.fftfreq(input.shape[-1], d=1 / fs_eeg, device=device))
    passband = (freqs > low_f) & (freqs < high_f)

    spectrum = torch.fft.fft(input).to(device)
    return torch.fft.ifft(spectrum * passband)

def notch(x, fs, fc, Q=30):
    """
    Applies a zero-phase (forward-backward) IIR notch filter

    Parameters:
     - x:  numpy.array, input signal, sampled at fs
     - fs: float, sampling frequency
     - fc: float, frequency to remove
     - Q:  float, quality factor of the notch

    Returns: numpy.array
    """
    # iirnotch takes the notch frequency as a fraction of the Nyquist frequency.
    nyquist = 0.5 * fs
    numerator, denominator = iirnotch(fc / nyquist, Q)
    return filtfilt(numerator, denominator, x)