import numpy as np

class ConstScheduler:
    """Scheduler whose coefficient never changes.

    Mirrors the ``update_coefficient`` / ``get_coefficient`` interface of the
    other schedulers in this module, but always reports the value it was
    constructed with.
    """

    def __init__(self, start_value: float):
        """Create the scheduler.

        Args:
            start_value: the value reported for every epoch.
        """
        self.coefficient = start_value

    def get_coefficient(self) -> float:
        """Return the (constant) coefficient."""
        return self.coefficient

    def update_coefficient(self, current_epoch: int) -> float:
        """Validate ``current_epoch`` and return the unchanged coefficient.

        Raises:
            ValueError: if ``current_epoch`` is negative.
        """
        epoch_is_negative = current_epoch < 0
        if epoch_is_negative:
            raise ValueError("current_epoch must be non-negative.")
        value = self.coefficient
        return value
    

class LinearScheduler():
    """Anneal a coefficient linearly from ``start_value`` to ``end_value``.

    The coefficient equals ``start_value`` at epoch 0, reaches ``end_value``
    at epoch ``num_epochs - 1``, and is held at ``end_value`` afterwards.
    """

    def __init__(self, start_value: float, end_value: float, num_epochs: int):
        """
        Initialize the LinearScheduler.

        Args:
            start_value: coefficient at epoch 0 (also the initial value).
            end_value: coefficient at epoch ``num_epochs - 1`` and later.
            num_epochs: number of epochs the interpolation spans.

        Raises:
            ValueError: if ``num_epochs`` is not positive.
        """
        if num_epochs <= 0:
            raise ValueError("num_epochs must be a positive integer.")

        self.start_value = start_value
        self.end_value = end_value
        self.num_epochs = num_epochs
        # Previously referenced the undefined name ``start`` (NameError).
        self.coefficient = start_value

    def update_coefficient(self, current_epoch: int) -> float:
        """
        Update the coefficient linearly based on the current epoch.

        Args:
            current_epoch: zero-based epoch index.

        Returns:
            The new coefficient (also stored in ``self.coefficient``).

        Raises:
            ValueError: if ``current_epoch`` is negative.
        """
        if current_epoch < 0:
            raise ValueError("current_epoch must be non-negative.")

        # With num_epochs == 1 there is no interval to interpolate over; the
        # max(...) avoids dividing by zero (epoch 0 -> start, epoch >= 1 -> end).
        span = max(self.num_epochs - 1, 1)
        # Clamp so the schedule holds at end_value instead of extrapolating.
        progress = min(current_epoch / span, 1.0)
        self.coefficient = self.start_value + progress * (self.end_value - self.start_value)
        return self.coefficient

    def get_coefficient(self) -> float:
        """
        Get the current value of the coefficient.
        """
        return self.coefficient


class CosineScheduler():
    """Oscillate a coefficient between ``start_value`` and ``end_value``.

    The coefficient follows ``start + 0.5 * (1 - cos(2*pi*epoch/period)) *
    (end - start)``: it equals ``start_value`` at epoch 0, reaches
    ``end_value`` at ``period / 2``, and returns to ``start_value`` at
    ``period``, repeating thereafter.
    """

    def __init__(self, start_value: float, end_value: float, period: int):
        """
        Initialize the CosineScheduler.

        Args:
            start_value: coefficient at the start of each period.
            end_value: coefficient at the middle of each period.
            period: length of one full cycle, in epochs.

        Raises:
            ValueError: if ``period`` is not positive (a zero period would
                otherwise only fail later, inside ``update_coefficient``).
        """
        if period <= 0:
            raise ValueError("period must be positive.")

        self.start_value = start_value
        self.end_value = end_value
        self.period = period
        self.coefficient = start_value

    def update_coefficient(self, current_epoch: int) -> float:
        """
        Update the coefficient based on the cosine function.

        Args:
            current_epoch: zero-based epoch index.

        Returns:
            The new coefficient (also stored in ``self.coefficient``).

        Raises:
            ValueError: if ``current_epoch`` is negative.
        """
        if current_epoch < 0:
            raise ValueError("current_epoch must be non-negative.")

        # Fraction of the way through the current cycle (may exceed 1; cos is periodic).
        progress = current_epoch / self.period
        self.coefficient = self.start_value + 0.5 * (1 - np.cos(progress * 2 * np.pi)) * (self.end_value - self.start_value)
        return self.coefficient

    def get_coefficient(self) -> float:
        """
        Get the current value of the coefficient.
        """
        return self.coefficient


class SigmoidScheduler():
    """Move a coefficient from ``start_value`` to ``end_value`` along a sigmoid.

    The coefficient is ``start + (end - start) * sigmoid((epoch - midpoint) / scale)``,
    so it is halfway between the two values at ``epoch == midpoint``.
    """

    def __init__(self, start_value: float, end_value: float, midpoint: int, scale: int):
        """
        Initialize the SigmoidScheduler.

        Args:
            start_value: asymptotic value well before ``midpoint``.
            end_value: asymptotic value well after ``midpoint``.
            midpoint: epoch at which the coefficient is halfway.
            scale: width of the transition; larger means a gentler slope.

        Raises:
            ValueError: if ``scale`` is zero (it is used as a divisor).
        """
        if scale == 0:
            raise ValueError("scale must be non-zero.")

        self.start_value = start_value
        self.end_value = end_value
        self.midpoint = midpoint
        self.scale = scale
        self.coefficient = start_value

    def update_coefficient(self, current_epoch: int) -> float:
        """
        Update the coefficient based on the sigmoid function.

        Args:
            current_epoch: zero-based epoch index.

        Returns:
            The new coefficient (also stored in ``self.coefficient``).

        Raises:
            ValueError: if ``current_epoch`` is negative.
        """
        if current_epoch < 0:
            raise ValueError("current_epoch must be non-negative.")

        z = (current_epoch - self.midpoint) / self.scale
        delta = self.end_value - self.start_value
        # Numerically stable sigmoid: only ever exponentiate a non-positive
        # number, so np.exp never overflows (the naive form warned and hit inf
        # when the epoch was far below the midpoint).
        if z >= 0:
            self.coefficient = self.start_value + delta / (1 + np.exp(-z))
        else:
            e = np.exp(z)
            self.coefficient = self.start_value + delta * e / (1 + e)
        return self.coefficient

    def get_coefficient(self) -> float:
        """
        Get the current value of the coefficient.
        """
        return self.coefficient


class SparsityScheduler():
    """Adjust a coefficient in response to changes in an observed L1 value.

    The first call only records the observation. Each later call moves the
    coefficient by ``lambda_l1 * (current_l1 - previous_l1)``, where
    ``previous_l1`` is the value passed on the immediately preceding call.
    """

    def __init__(self, min_l1: float, start_value: float, lambda_l1: float = 0.05):
        """
        Initialize the SparsityScheduler.

        Args:
            min_l1: stored as ``self.min_l1``; not used by the update rule in
                this class. NOTE(review): confirm whether it was meant to act
                as a target or floor.
            start_value: initial coefficient.
            lambda_l1: step size applied to the change in L1 between calls.
        """
        self.min_l1 = min_l1
        self.lambda_l1 = lambda_l1
        self.prev_l1 = None  # last observed L1; None until the first update
        self.coefficient = start_value

    def update_coefficient(self, current_l1: float) -> float:
        """
        Update the coefficient from the change in L1 since the previous call.

        Args:
            current_l1: the latest observed L1 value.

        Returns:
            The new coefficient (also stored in ``self.coefficient``).
        """
        if self.prev_l1 is not None:
            self.coefficient = self.coefficient + self.lambda_l1 * (current_l1 - self.prev_l1)
        # Refresh on every call; previously prev_l1 stayed frozen at the first
        # observation, so each step re-applied the cumulative change.
        self.prev_l1 = current_l1
        return self.coefficient

    def get_coefficient(self) -> float:
        """
        Get the current value of the coefficient.
        """
        return self.coefficient