"""Multi-chain DP quantile estimator."""

from __future__ import annotations

from typing import Sequence

import numpy as np

from .data import split_stream
from .dp import DPQuantile, LearningRateSchedule


class MultiChainDPQuantile:
    """Run several ``DPQuantile`` chains and pool their averaged estimates.

    ``fit`` splits the incoming stream into ``K`` chunks via
    ``split_stream`` and fits one chain per chunk. Afterwards the chains'
    ``Q_avg`` values are summarised into ``global_means`` (per-chain
    values), ``global_mean`` (their mean) and ``global_var`` (their sample
    variance divided by ``K``).

    Args:
        K: Requested number of chains; coerced with ``int`` and clamped to
            a minimum of 1.
        tau: Forwarded unchanged to every ``DPQuantile``.
        r: Forwarded unchanged to every ``DPQuantile``.
        mechanism: Forwarded unchanged to every ``DPQuantile``.
        burn_in_ratio: Forwarded unchanged to every ``DPQuantile``.
        c0: Passed to ``LearningRateSchedule``.
        a: Passed to ``LearningRateSchedule``.
        b0: Passed to ``LearningRateSchedule``.
        true_q: Forwarded unchanged to every ``DPQuantile``.
        track_history: Forwarded unchanged to every ``DPQuantile``.
    """

    def __init__(
        self,
        K: int,
        tau: float,
        r: float,
        mechanism: str = "rr",
        burn_in_ratio: float = 0.0,
        c0: float = 1.0,
        a: float = 0.6,
        b0: float = 0.0,
        true_q: float | None = None,
        track_history: bool = False,
    ) -> None:
        # Always keep at least one chain, whatever the caller asked for.
        requested = int(K)
        self.K = requested if requested >= 1 else 1

        self.tau = tau
        self.r = r
        self.mechanism = mechanism
        self.burn_in_ratio = burn_in_ratio
        # One schedule object, handed to every chain when fit() runs.
        self.lr_schedule = LearningRateSchedule(c0=c0, a=a, b0=b0)
        self.true_q = true_q
        self.track_history = track_history

        # Every chain is built from the same configuration.
        chain_config = {
            "tau": self.tau,
            "r": self.r,
            "mechanism": self.mechanism,
            "true_q": self.true_q,
            "burn_in_ratio": self.burn_in_ratio,
            "track_history": self.track_history,
        }
        self.chains = [DPQuantile(**chain_config) for _ in range(self.K)]

        # Aggregates stay at their zero/empty defaults until fit() is called.
        self.global_mean = 0.0
        self.global_var = 0.0
        self.global_means: list[float] = []

    def fit(self, data_stream: Sequence[float]) -> None:
        """Fit each chain on its chunk of *data_stream*, then refresh the aggregates."""
        # Chains are paired with chunks positionally; zip stops at the
        # shorter of the two sequences.
        for worker, segment in zip(self.chains, split_stream(data_stream, self.K)):
            worker.lr_schedule = self.lr_schedule
            worker.fit(segment)
        self._compute_global_stats()

    def _compute_global_stats(self) -> None:
        """Recompute ``global_means``, ``global_mean`` and ``global_var`` from the chains."""
        estimates = [member.Q_avg for member in self.chains]
        self.global_means = estimates

        if estimates:
            self.global_mean = float(np.mean(estimates))
        else:
            self.global_mean = 0.0

        # The unbiased (ddof=1) sample variance needs at least two values.
        between_chain_var = 0.0
        if len(estimates) > 1:
            between_chain_var = float(np.var(estimates, ddof=1))

        if self.K > 0:
            self.global_var = between_chain_var / self.K
        else:
            self.global_var = 0.0

