from typing import Iterable

import numpy as np
from scipy.signal import butter, filtfilt, sosfiltfilt

from regression.session_story_configs import SessionStoryConfig
from regression.regression_utils import normalize

def low_pass_filter(data, cutoff, fs, order=5, axis=-1):
    """
    Applies a zero-phase Butterworth low-pass filter to the given data.

    The filter is designed in second-order-sections (SOS) form and applied
    forward and backward with ``sosfiltfilt``. SOS form stays numerically
    stable for higher orders and for cutoffs that are small relative to the
    Nyquist frequency, where the transfer-function (``b, a``) form can give
    inaccurate or diverging output.

    Args:
        data (array-like): The time series data to filter.
        cutoff (float): The cutoff frequency in Hz. Must satisfy
            ``0 < cutoff < fs / 2``.
        fs (float): The sampling rate of the data in Hz.
        order (int, optional): The order of the filter. Defaults to 5.
        axis (int, optional): The axis of ``data`` along which to filter.
            Defaults to -1 (the last axis), the axis used previously.

    Returns:
        numpy.ndarray: The filtered data, with the same shape as ``data``.

    Raises:
        ValueError: If ``cutoff`` is not strictly between 0 and the Nyquist
            frequency ``fs / 2``. ``sosfiltfilt`` also raises ValueError when
            ``data`` is too short along ``axis`` for its edge padding.
    """
    nyquist = 0.5 * fs
    # Validate in Hz so the error message is meaningful to the caller; butter()
    # would otherwise only complain about the normalized frequency.
    if not 0 < cutoff < nyquist:
        raise ValueError(
            f"cutoff must satisfy 0 < cutoff < fs/2 ({nyquist} Hz); got {cutoff} Hz"
        )
    normal_cutoff = cutoff / nyquist
    # SOS output avoids the precision loss of the b/a polynomial form.
    sos = butter(order, normal_cutoff, btype='low', analog=False, output='sos')
    return sosfiltfilt(sos, data, axis=axis)


def notch_filter(data, low_cut, high_cut, fs, order=4, axis=-1):
    """
    Apply a zero-phase Butterworth bandstop (notch) filter that removes
    frequencies from low_cut to high_cut Hz.

    The filter is designed in second-order-sections (SOS) form and applied
    forward and backward with ``sosfiltfilt``. A bandstop design of order N
    is a 2N-order filter, and narrow stop bands make the transfer-function
    (``b, a``) form numerically ill-conditioned; SOS form avoids that.

    Parameters:
    - data: array-like containing the signal.
    - low_cut: Lower cutoff frequency (Hz) of the notch.
    - high_cut: Upper cutoff frequency (Hz) of the notch.
    - fs: Sampling frequency (Hz).
    - order: Order of the Butterworth prototype (the resulting bandstop
      filter has order 2 * order).
    - axis: Axis of ``data`` along which to filter. Defaults to -1 (the last
      axis), the axis used previously.

    Returns:
    - The filtered signal as a numpy array with the same shape as ``data``.

    Raises:
    - ValueError: If the band does not satisfy
      ``0 < low_cut < high_cut < fs / 2``. ``sosfiltfilt`` also raises
      ValueError when ``data`` is too short along ``axis`` for its padding.
    """
    nyq = 0.5 * fs
    # Check the band in Hz up front so a swapped or out-of-range band fails
    # with a clear message instead of an obscure normalized-frequency error.
    if not 0 < low_cut < high_cut < nyq:
        raise ValueError(
            f"notch band must satisfy 0 < low_cut < high_cut < fs/2 ({nyq} Hz); "
            f"got low_cut={low_cut} Hz, high_cut={high_cut} Hz"
        )
    low = low_cut / nyq
    high = high_cut / nyq

    # Design the Butterworth bandstop filter as cascaded second-order sections.
    sos = butter(order, [low, high], btype='bandstop', analog=False, output='sos')

    # Apply zero-phase (forward-backward) filtering.
    return sosfiltfilt(sos, data, axis=axis)


def load_meg_targets(session_story_configs: Iterable[SessionStoryConfig]) -> list:
    """
    Load and normalize the MEG target arrays for a collection of session/story configs.

    For each config, this takes index ``[1]`` of the value returned by
    ``load_aligned_downsampled_meg()``, transposes it, passes it to
    ``normalize``, and keeps index ``[0]`` of that result.

    Each config is loaded and normalized before the next one is loaded. The
    un-normalized arrays are therefore not all accumulated in a list
    alongside the normalized ones.

    Args:
        session_story_configs: Iterable of ``SessionStoryConfig`` objects.
            It is iterated exactly once, so a generator is acceptable.

    Returns:
        list: The normalized arrays, one per config, in input order.
    """
    # NOTE(review): presumably index [1] of load_aligned_downsampled_meg() is the
    # MEG data array and index [0] of normalize() is the normalized array —
    # confirm against regression.session_story_configs / regression.regression_utils.
    return [
        normalize(config.load_aligned_downsampled_meg()[1].T)[0]
        for config in session_story_configs
    ]
            
