import numpy as np
from scipy import signal


class LowPassFilter:
    """First-order low-pass filter applied step-by-step to action vectors.

    The continuous prototype ``H(s) = w0 / (s + w0)`` with
    ``w0 = 2 * pi * cutoff_freq`` is discretized via
    ``scipy.signal.TransferFunction.to_discrete`` and evaluated as the
    recurrence::

        y[k] = a[0] * y[k-1] + b[0] * u[k] + b[1] * u[k-1]

    All operations are element-wise numpy operations, so each element of the
    action vector is filtered independently.
    """

    def __init__(self, cutoff_freq, sampling_freq, action_dim, method='gbt', alpha=0.5,
                 clip_range=(-1.0, 1.0)):
        """Create the filter and compute its discrete coefficients.

        :param cutoff_freq: Cutoff (pole) frequency in Hz; must be positive.
        :param sampling_freq: Sampling frequency in Hz; must be positive.
        :param action_dim: Shape passed to ``np.zeros`` for the filter state.
        :param method: Discretization method forwarded to ``to_discrete``
            (e.g. 'gbt', 'bilinear', 'euler', 'backward_diff', 'zoh').
        :param alpha: GBT weighting parameter (0.5 gives Tustin / bilinear).
        :param clip_range: ``(low, high)`` bounds applied to the value returned
            by :meth:`apply`, or ``None`` to disable clipping.
        :raises ValueError: If ``cutoff_freq`` or ``sampling_freq`` is not positive.
        """
        # A non-positive cutoff gives a pole outside the unit circle (unstable
        # filter) and a non-positive sampling rate has no meaning; reject both.
        if cutoff_freq <= 0 or sampling_freq <= 0:
            raise ValueError(
                f"cutoff_freq and sampling_freq must be positive, got "
                f"cutoff_freq={cutoff_freq}, sampling_freq={sampling_freq}"
            )

        self.cutoff_freq = cutoff_freq  # Cutoff frequency in Hz
        self.sampling_freq = sampling_freq  # Sampling frequency in Hz
        self.dt = 1.0 / sampling_freq  # Time step in seconds
        self.method = method  # Discretization method passed to to_discrete()
        self.alpha = alpha  # GBT weighting (0.5 = Tustin / bilinear)
        self.action_dim = action_dim  # Shape of the filter state arrays
        self.clip_range = clip_range  # (low, high) output bounds, or None

        # Past input u[k-1] and past (unclipped) output y[k-1].
        self.prev_action = np.zeros(action_dim)
        self.prev_filtered_action = np.zeros(action_dim)

        # Design the low-pass filter
        self.design_filter()

    def design_filter(self):
        """Discretize ``w0 / (s + w0)`` and store the recurrence coefficients.

        Sets ``self.b`` (weights of ``u[k]`` and ``u[k-1]``) and ``self.a``
        (weight of ``y[k-1]``, sign already flipped so :meth:`apply` adds every
        term), then prints both.
        """
        # Pole frequency in rad/s.
        w0 = 2 * np.pi * self.cutoff_freq

        # Continuous-time prototype H(s) = w0 / (s + w0) (unit DC gain).
        low_pass = signal.TransferFunction(w0, [1, w0])

        # Discretize with the configured method; gbt with alpha=0.5 is Tustin.
        discrete = low_pass.to_discrete(self.dt, method=self.method, alpha=self.alpha)

        den = np.atleast_1d(np.asarray(discrete.den, dtype=float))
        den = den / den[0]  # ensure the leading denominator coefficient is 1
        num = np.ravel(np.asarray(discrete.num, dtype=float))

        # scipy trims leading (near-)zero numerator coefficients, which happens
        # for strictly causal discretizations (forward Euler / alpha=0, ZOH).
        # Right-align the numerator with the denominator so b always holds one
        # coefficient per delay: b[0] for u[k], b[1] for u[k-1].
        b = np.zeros(len(den))
        b[len(den) - len(num):] = num
        self.b = b

        # Drop the leading 1 and negate the rest so the difference equation
        # y[k] = a[0]*y[k-1] + b[0]*u[k] + b[1]*u[k-1] is a plain sum.
        self.a = -den[1:]

        print(f"Filter coefficients b_i: {self.b}")
        print(f"Filter coefficients a_i: {self.a}")

    def apply(self, action):
        """
        Apply the low-pass filter to the action based on past values.

        :param action: The current action (array-like, broadcastable against
            the ``action_dim`` state).
        :return: The filtered action, clipped to ``clip_range`` when it is set.
            The unclipped value is what is kept as filter state.
        """
        # Take a private copy so later in-place edits by the caller cannot
        # corrupt the stored history, and reset() can never touch the
        # caller's array.
        action = np.array(action, dtype=float)

        filtered_action = (self.a[0] * self.prev_filtered_action
                           + self.b[0] * action
                           + self.b[1] * self.prev_action)

        # Update the stored previous values
        self.prev_action = action
        self.prev_filtered_action = filtered_action

        if self.clip_range is None:
            # Return a copy so the caller cannot mutate the internal state.
            return filtered_action.copy()
        low, high = self.clip_range
        return np.clip(filtered_action, low, high)

    def reset(self):
        """Reset the filter history to zeros, as right after construction."""
        self.prev_action = np.zeros(self.action_dim)
        self.prev_filtered_action = np.zeros(self.action_dim)
