import numpy as np

def lr_schedule(step, c0=1, a=0.6):
    """
    Polynomially decaying step size: eta_t = c0 / t**a.

    Parameters
    ----------
    step : int
        Current iteration, counted from 1.
    c0 : float, default=1
        Multiplicative scale of the step size.
    a : float, default=0.6
        Decay exponent (0 < a <= 1).

    Returns
    -------
    float
        Step size to use at iteration ``step``.
    """
    return c0 / pow(step, a)

class DPQuantile:
    """
    Base class for LDP quantile estimation.

    The estimator follows an online (streaming) procedure: each observation is
    privatized by randomized response (the true comparison is reported with
    probability ``r``, otherwise a fair coin flip) and used for one
    stochastic-gradient step on the quantile estimate ``q_est``. Running sums
    over the post-burn-in iterates are kept so that the averaged estimate
    ``Q_avg`` and :meth:`get_variance` need O(1) memory.

    Parameters
    ----------
    tau : float, default=0.5
        Target quantile level.
    r : float, default=0.5
        Randomized-response rate: probability of using the true comparison.
    true_q : float or None, default=None
        Ground-truth quantile, used only for the optional warm start
        (``use_true_q_init``) and for the error history.
    track_history : bool, default=False
        If True and ``true_q`` is given, append ``|Q_avg - true_q|`` to
        ``self.errors`` after every post-burn-in step.
    burn_in_ratio : float, default=0
        Fraction of the stream (by ``len``) whose iterates are excluded from
        the averaged statistics (they still update ``q_est``).
    use_true_q_init : bool, default=False
        Start ``q_est`` at ``true_q`` instead of 0.0.
    lr_c0 : float, default=1
        Scale passed to ``lr_schedule`` in :meth:`fit`.
    lr_a : float, default=0.6
        Decay exponent passed to ``lr_schedule`` in :meth:`fit`.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Randomness source for the randomized response. ``None`` uses the
        global ``numpy.random`` state.
    """

    # Class-level fallbacks so subclasses whose __init__ does not call
    # super().__init__() still find these attributes.
    lr_c0 = 1
    lr_a = 0.6
    rng = None

    def __init__(self, tau=0.5, r=0.5, true_q=None,
                 track_history=False, burn_in_ratio=0, use_true_q_init=False,
                 lr_c0=1, lr_a=0.6, rng=None):
        self.tau = tau  # target quantile level
        self.r = r  # randomized-response rate
        self.true_q = true_q  # ground-truth quantile (for diagnostics only)
        self.track_history = track_history
        self.burn_in_ratio = burn_in_ratio
        self.use_true_q_init = use_true_q_init
        self.lr_c0 = lr_c0  # learning-rate scale
        self.lr_a = lr_a  # learning-rate decay exponent
        self.rng = rng  # None -> global numpy.random state
        # Create the running state now so q_est / get_variance() are usable
        # before fit(). The base implementation is called explicitly because a
        # subclass reset() may rely on attributes its own __init__ sets later.
        DPQuantile.reset(self)

    def reset(self):
        """Re-initialise the running state before processing a new stream."""
        if self.use_true_q_init and self.true_q is not None:
            self.q_est = self.true_q  # warm start at the ground truth
        else:
            self.q_est = 0.0  # cold start
        self.Q_avg = 0.0  # running mean of post-burn-in iterates
        self.n = 0  # number of iterates contributing to the statistics
        self.step = 0  # total iterations processed

        # Online sums used by get_variance(); Qbar_s is the running mean
        # after s contributing iterates.
        self.v_a = 0.0  # sum of s^2 * Qbar_s^2
        self.v_b = 0.0  # sum of s^2 * Qbar_s
        self.v_s = 0.0  # count of contributing iterates (always equals n)
        self.v_q = 0.0  # sum of s^2
        self.errors = []  # |Q_avg - true_q| history when tracking is on

    def _compute_gradient(self, x):
        """
        Compute the privatized stochastic gradient for one data point ``x``.

        With probability ``r`` the true comparison ``s = 1{x > q_est}`` is
        used; otherwise ``s`` is a fair coin flip. The returned step equals
        ``s - (1 + r)/2 + tau*r``, whose expectation is
        ``r * (tau - P(x <= q_est))``, i.e. zero at the ``tau``-quantile.
        """
        # Generator and RandomState (and the numpy.random module) all expose
        # random() and binomial(); numpy.random.random() draws exactly like
        # numpy.random.rand().
        rng = np.random if self.rng is None else self.rng
        if rng.random() < self.r:
            s = int(x > self.q_est)
        else:
            s = rng.binomial(1, 0.5)

        delta = ((1 - self.r + 2*self.tau*self.r)/2 if s
                 else -(1 + self.r - 2*self.tau*self.r)/2)
        return delta

    def _update_estimator(self, delta, lr):
        """Gradient-descent update of the current quantile estimate."""
        self.q_est += lr * delta
        self.step += 1

    def _update_stats(self):
        """
        Fold the current iterate ``q_est`` into the averaged estimator and
        the online sums consumed by :meth:`get_variance`.
        """
        self.n += 1
        prev_weight = (self.n - 1) / self.n
        self.Q_avg = prev_weight * self.Q_avg + self.q_est / self.n

        term = self.n**2
        self.v_a += term * self.Q_avg**2
        self.v_b += term * self.Q_avg
        self.v_q += term
        self.v_s += 1

        # Track absolute error curve if requested
        if self.track_history and self.true_q is not None:
            self.errors.append(np.abs(self.Q_avg - self.true_q))

    def fit(self, data_stream):
        """
        Process an i.i.d. data stream and learn the LDP quantile online.

        Parameters
        ----------
        data_stream : sized iterable
            Sequence of observations to be processed exactly once; ``len`` is
            required to size the burn-in period.
        """
        self.reset()
        n_samples = len(data_stream)
        # Iterates with index < burn_in still update q_est but are excluded
        # from Q_avg and the variance sums.
        burn_in = int(n_samples * self.burn_in_ratio)
        for idx, x in enumerate(data_stream):
            # 1. get learning rate for the *next* step (1-based indexing)
            lr = lr_schedule(self.step + 1, self.lr_c0, self.lr_a)

            # 2. stochastic gradient update
            delta = self._compute_gradient(x)
            self._update_estimator(delta, lr)

            # 3. update statistics after burn-in period
            if idx >= burn_in:
                self._update_stats()

            # Safeguard: never process more than len(data_stream) items,
            # even if iteration yields more than the reported length.
            if self.step >= n_samples:
                break

    def get_variance(self):
        """
        Return the variance estimate for the averaged estimator ``Q_avg``.

        Evaluates ``sum_{s=1}^{n} s^2 (Qbar_s - Qbar_n)^2 / n^3`` from the
        online sums (``v_s == n``). Returns 0.0 when no iterate has
        contributed yet.
        """
        if self.n == 0:
            return 0.0
        numerator = (self.v_a - 2*self.Q_avg*self.v_b
                     + (self.Q_avg**2)*self.v_q)
        # The expanded sum can round slightly below zero; a variance cannot.
        return max(numerator / (self.n**2 * self.v_s), 0.0)