"""Single-chain LDP quantile estimator supporting multiple mechanisms."""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Sequence

import numpy as np


def _laplace_scale(r: float) -> float:
    r = float(np.clip(r, 1e-12, 1.0 - 1e-12))
    denom = np.log((1.0 + r) / (1.0 - r))
    if not np.isfinite(denom) or denom == 0.0:
        raise ValueError("Invalid randomized response rate for Laplace noise.")
    return 1.0 / denom


def _laplace_noise(r: float) -> float:
    """Draw one zero-centred Laplace sample whose scale is ``_laplace_scale(r)``.

    Uses numpy's global random state (``np.random.laplace``).
    """
    noise_scale = _laplace_scale(r)
    return np.random.laplace(loc=0.0, scale=noise_scale)


@dataclass
class LearningRateSchedule:
    """Polynomially decaying step size ``c0 / (t**a + b0)``.

    Steps below 1 are treated as step 1.
    """

    # Numerator of the step-size formula.
    c0: float = 1.0
    # Decay exponent applied to the step index.
    a: float = 0.6
    # Additive offset in the denominator.
    b0: float = 0.0

    def value(self, step: int) -> float:
        """Return the learning rate for iteration ``step`` (clamped to >= 1)."""
        t = max(step, 1)
        denominator = t ** float(self.a) + float(self.b0)
        return float(self.c0) / denominator


class DPQuantile:
    """Streaming quantile estimator with configurable privacy mechanism.

    A single stochastic-approximation chain ``q <- q + lr * g`` is run over
    the data stream, where ``g`` is a privatized quantile gradient produced
    by the selected mechanism:

    * ``"rr"``: randomized response. With probability ``r`` the true
      indicator ``x > q`` is used; otherwise a fair coin flip replaces it.
    * ``"laplace"``: the exact gradient plus Laplace noise whose scale is
      ``_laplace_scale(r)``.

    After the first ``burn_in_ratio`` fraction of the stream, the running
    average of the iterates is kept in ``Q_avg``. Alongside it, the method
    accumulates the weighted sums used by :meth:`get_variance`.

    Args:
        tau: Target quantile level (e.g. 0.5 for the median).
        r: Randomized-response rate. For ``"laplace"`` it sets the noise scale.
        mechanism: ``"rr"`` or ``"laplace"`` (matched case-insensitively).
            Any other value raises ``ValueError`` on the first gradient step.
        true_q: Optional true quantile, used for initialisation and for
            error tracking.
        burn_in_ratio: Fraction of the stream excluded from averaging.
        track_history: If true and ``true_q`` is given, append
            ``|Q_avg - true_q|`` to ``errors`` after every averaged step.
        use_true_q_init: Start the chain at ``true_q`` instead of 0.
        lr_schedule: Object exposing ``value(step) -> float``. Defaults to
            ``LearningRateSchedule()``.
    """

    def __init__(
        self,
        tau: float = 0.5,
        r: float = 0.5,
        mechanism: str = "rr",
        true_q: float | None = None,
        burn_in_ratio: float = 0.0,
        track_history: bool = False,
        use_true_q_init: bool = False,
        lr_schedule: LearningRateSchedule | None = None,
    ) -> None:
        self.tau = tau
        self.r = r
        self.true_q = true_q
        self.burn_in_ratio = burn_in_ratio
        self.track_history = track_history
        self.use_true_q_init = use_true_q_init
        self.mechanism = mechanism
        # Explicit None check: only substitute the default when no schedule
        # was supplied (not when a supplied object happens to be falsy).
        self.lr_schedule = (
            lr_schedule if lr_schedule is not None else LearningRateSchedule()
        )
        self.reset()

    def reset(self) -> None:
        """Re-initialise the iterate and clear all running statistics."""
        if self.use_true_q_init and self.true_q is not None:
            self.q_est = float(self.true_q)
        else:
            self.q_est = 0.0
        self.Q_avg = 0.0  # running mean of q_est over the averaged steps
        self.n = 0  # number of averaged (post burn-in) steps
        self.step = 0  # number of SGD updates performed
        # Weighted sums for get_variance (weight k**2 at the k-th averaged step):
        self.v_a = 0.0  # sum k^2 * Qbar_k^2
        self.v_b = 0.0  # sum k^2 * Qbar_k
        self.v_s = 0.0  # count of averaged steps (tracks n)
        self.v_q = 0.0  # sum k^2
        self.errors: List[float] = []

    def _gradient_rr(self, x: float) -> float:
        """Return the randomized-response gradient for sample ``x``.

        The reported bit ``s`` is the true indicator ``x > q_est`` with
        probability ``r`` and a fair coin flip otherwise. The two output
        values are chosen so that, if ``x`` has CDF ``F``, the expected
        gradient is ``r * (tau - F(q_est))``. The fixed point of the chain
        is therefore still the ``tau``-quantile.
        """
        if np.random.rand() < self.r:
            s = int(x > self.q_est)
        else:
            s = np.random.binomial(1, 0.5)
        if s:
            return (1 - self.r + 2 * self.tau * self.r) / 2
        return -(1 + self.r - 2 * self.tau * self.r) / 2

    def _gradient_laplace(self, x: float) -> float:
        """Return the exact (noise-free) quantile gradient for sample ``x``."""
        s = int(x > self.q_est)
        return self.tau if s else -(1 - self.tau)

    def _compute_gradient(self, x: float) -> float:
        """Dispatch to the configured mechanism and return its gradient.

        Raises:
            ValueError: If ``mechanism`` is neither ``"rr"`` nor ``"laplace"``.
        """
        mechanism = self.mechanism.lower()
        if mechanism == "rr":
            return self._gradient_rr(x)
        if mechanism == "laplace":
            grad = self._gradient_laplace(x)
            grad += _laplace_noise(self.r)
            return grad
        raise ValueError(f"Unsupported mechanism: {self.mechanism}")

    def _update_estimator(self, delta: float, lr: float) -> None:
        """Apply one SGD step ``q_est += lr * delta`` and advance the counter."""
        self.q_est += lr * delta
        self.step += 1

    def _update_stats(self) -> None:
        """Fold the current iterate into the running average and variance sums."""
        self.n += 1
        prev = (self.n - 1) / self.n
        self.Q_avg = prev * self.Q_avg + self.q_est / self.n
        term = self.n**2
        self.v_a += term * self.Q_avg**2
        self.v_b += term * self.Q_avg
        self.v_q += term
        self.v_s += 1
        if self.track_history and self.true_q is not None:
            self.errors.append(abs(self.Q_avg - self.true_q))

    def fit(self, data_stream: Sequence[float]) -> None:
        """Reset the estimator and run one pass over ``data_stream``.

        Every sample produces an SGD update. Only samples at index
        ``>= int(len(data) * burn_in_ratio)`` contribute to ``Q_avg`` and the
        variance statistics.
        """
        self.reset()
        data = np.asarray(data_stream, dtype=float)
        burn_in = int(len(data) * self.burn_in_ratio)
        for idx, x in enumerate(data):
            lr = self.lr_schedule.value(self.step + 1)
            delta = self._compute_gradient(float(x))
            self._update_estimator(delta, lr)
            if idx >= burn_in:
                self._update_stats()

    def get_variance(self) -> float:
        """Return the variance estimate for ``Q_avg`` from the averaged steps.

        Computes ``sum_k k^2 * (Qbar_k - Qbar_n)^2 / (n^2 * v_s)`` (``v_s``
        equals ``n``) from the expanded running sums ``v_a``, ``v_b`` and
        ``v_q``. Returns 0.0 when no averaged step has been taken.
        """
        if self.n == 0 or self.v_s == 0:
            return 0.0
        numerator = self.v_a - 2 * self.Q_avg * self.v_b + (self.Q_avg**2) * self.v_q
        # The expanded sum of squares is non-negative in exact arithmetic.
        # Floating-point cancellation between the large n^2-weighted terms
        # can still push it slightly below zero. Clamp so callers never get
        # a negative variance (or NaN from its square root).
        numerator = max(numerator, 0.0)
        return numerator / (self.n**2 * self.v_s)

