import numpy as np

def lr_schedule(step, c0=1, a=0.6, b=0):
    """
    Polynomially decaying learning rate ``c0 / (step**a + b)``.

    Parameters
    ----------
    step : int
        The current step or iteration in each chain. With the default
        ``b=0`` this must be non-zero, otherwise the division fails.
    c0 : float, optional
        Numerator / scale of the learning rate (default is 1).
    a : float, optional
        Decay exponent applied to ``step`` (default is 0.6).
    b : float, optional
        Constant offset added to ``step**a`` (default is 0).

    Returns
    -------
    lr : float
        The learning rate to use at ``step``.
    """
    denominator = step**a + b
    return c0 / denominator

class DPQuantile:
    """
    Base class for Differential Privacy Quantile Estimation.

    Each incoming data point produces one stochastic-gradient step on a
    running quantile estimate ``q_est``. The gradient is built from a
    randomized-response bit: with probability ``r`` the true indicator
    ``x > q_est`` is used, otherwise it is replaced by a fair coin flip.
    After an optional burn-in period, running sums over the averaged
    estimate ``Q_avg`` are accumulated to produce an online variance
    estimate.
    """

    def __init__(self, tau=0.5, r=0.5, true_q=None, burn_in_ratio=0):
        """
        Initialize the DPQuantile estimator.

        Parameters
        ----------
        tau : float
            The quantile level to estimate.
        r : float
            Response rate: probability that the true indicator is used
            rather than a fair coin flip.
        true_q : float or None
            The true quantile value (if known). Not read by the methods
            of this class.
        burn_in_ratio : float
            Fraction of the data stream (in [0, 1]) treated as burn-in,
            i.e. excluded from the variance statistics.
        """
        self.tau = tau
        self.r = r
        self.true_q = true_q
        self.burn_in_ratio = burn_in_ratio

    def reset(self):
        """
        Reset the training state.
        """
        self.q_est = 0.8  # Hard-coded initial estimate
        self.Q_avg = self.q_est
        self.n = 0  # Number of post-burn-in statistic updates
        self.step = 0  # Number of estimator updates

        # Online inference statistics; with s = 1..n and Q_s the running
        # average after the s-th statistic update:
        self.v_a = 0.0  # sum of s^2 * Q_s^2
        self.v_b = 0.0  # sum of s^2 * Q_s
        self.v_s = 0.0  # count of updates (always equal to self.n)
        self.v_q = 0.0  # sum of s^2
        # Not populated by this class; presumably filled by subclasses or
        # callers -- TODO confirm.
        self.errors = []

    def _compute_gradient(self, x):
        """
        Core gradient computation via randomized response.

        Parameters
        ----------
        x : float
            Input data point.

        Returns
        -------
        delta : float
            Computed gradient for the current step. Its expectation is
            ``r * (tau - P(X <= q_est))``, so the update drifts toward the
            tau-quantile.
        """
        if np.random.rand() < self.r:
            # Truthful response: report whether x lies above the estimate.
            s = int(x > self.q_est)
        else:
            # Privatized response: replace the indicator with a fair coin.
            s = np.random.binomial(1, 0.5)

        # Debias the randomized bit so the expected step is unbiased up to
        # the factor r.
        delta = ((1 - self.r + 2 * self.tau * self.r) / 2 if s
                 else -(1 + self.r - 2 * self.tau * self.r) / 2)
        return delta

    def _update_estimator(self, delta, lr):
        """
        Apply one gradient step to the estimate and advance the step counter.

        Parameters
        ----------
        delta : float
            Gradient value.
        lr : float
            Learning rate.
        """
        self.q_est += lr * delta
        self.step += 1

    def _update_stats(self):
        """
        Update the running average and the variance accumulators.
        """
        self.n += 1
        prev_weight = (self.n - 1) / self.n
        # Incremental mean of the post-burn-in estimates.
        self.Q_avg = prev_weight * self.Q_avg + self.q_est / self.n

        # Accumulate sums weighted by n^2 so get_variance can expand
        # sum_s s^2 (Q_s - Q_n)^2 without storing the whole trajectory.
        term = self.n ** 2
        self.v_a += term * self.Q_avg ** 2
        self.v_b += term * self.Q_avg
        self.v_q += term
        self.v_s += 1

    def fit(self, data_stream):
        """
        Single-machine training method.

        Parameters
        ----------
        data_stream : sized iterable
            Input data stream for training. ``len()`` must be supported.

        Returns
        -------
        var : np.ndarray
            Variance estimate after each post-burn-in sample; its length is
            ``len(data_stream) - int(len(data_stream) * burn_in_ratio)``.

        Raises
        ------
        ValueError
            If ``burn_in_ratio`` is not within [0, 1].
        """
        self.reset()
        if not 0 <= self.burn_in_ratio <= 1:
            # A negative ratio would silently misalign `var`, and a ratio
            # above 1 would request a negative-size array.
            raise ValueError(
                f"burn_in_ratio must be in [0, 1], got {self.burn_in_ratio}"
            )
        n_samples = len(data_stream)
        burn_in = int(n_samples * self.burn_in_ratio)  # Burn-in sample count
        var = np.zeros(n_samples - burn_in)
        for idx, x in enumerate(data_stream):
            # Learning rate for the upcoming (1-based) step.
            lr = lr_schedule(self.step + 1)

            # Compute gradient and update estimate.
            delta = self._compute_gradient(x)
            self._update_estimator(delta, lr)

            # Skip statistics update during the burn-in period.
            if idx >= burn_in:
                self._update_stats()
                var[idx - burn_in] = self.get_variance()
            # With the base _update_estimator (one step per sample) this only
            # fires after the last element; kept as a guard in case an
            # override advances `step` differently.
            if self.step >= n_samples:
                break
        return var

    def get_variance(self):
        """
        Get the variance estimate.

        Expands to ``sum_s s^2 (Q_s - Q_avg)^2 / (n^2 * v_s)``. Because
        ``v_s == n``, this is normalized by ``n^3``.
        NOTE(review): this looks like a random-scaling variance of ``Q_avg``;
        confirm against the reference method.

        Returns
        -------
        variance : float
            Estimated variance, or 0.0 before any statistics were recorded.
        """
        if self.n == 0:
            return 0.0
        return (self.v_a - 2 * self.Q_avg * self.v_b +
                (self.Q_avg ** 2) * self.v_q) / (self.n ** 2 * self.v_s)