"""TransferDPQuantile implementation built on the multi-chain estimator."""

from __future__ import annotations

from typing import Sequence

import numpy as np

from .multi_chain import MultiChainDPQuantile
from .weights import (
    compute_conservative_variance_weights,
    compute_conservative_weights,
    compute_optimal_weights,
)


class TransferDPQuantile:
    """Fit one multi-chain estimator per site and aggregate with custom weights.

    Site 0 is the *target* site. After :meth:`fit`, each site's bias is its
    global mean minus site 0's global mean, so ``global_biases[0]`` is always 0.

    Args:
        K_list: Per-site value forwarded as ``K`` to ``MultiChainDPQuantile``
            (each entry coerced with ``int``).
        rs: Per-site value forwarded as ``r`` to ``MultiChainDPQuantile``
            (each entry coerced with ``float``). Must match ``K_list`` in length.
        tau, mechanism, burn_in_ratio, c0, a, b0, true_q: Shared settings
            forwarded unchanged to every site's ``MultiChainDPQuantile``.

    Raises:
        ValueError: If ``K_list`` and ``rs`` differ in length, or are empty.
    """

    def __init__(
        self,
        K_list: Sequence[int],
        rs: Sequence[float],
        tau: float,
        mechanism: str = "rr",
        burn_in_ratio: float = 0.0,
        c0: float = 1.0,
        a: float = 0.6,
        b0: float = 0.0,
        true_q: float | None = None,
    ) -> None:
        if len(K_list) != len(rs):
            raise ValueError("K_list and rs must have the same length.")
        if len(K_list) == 0:
            # fit() uses site 0 as the bias reference, so an empty site list
            # could never be fitted; fail here with a clear message instead
            # of an IndexError later.
            raise ValueError("At least one site is required.")
        self.K_list = [int(k) for k in K_list]
        self.rs = [float(r) for r in rs]
        self.tau = tau
        self.mechanism = mechanism
        self.burn_in_ratio = burn_in_ratio
        self.c0 = c0
        self.a = a
        self.b0 = b0
        self.true_q = true_q
        # One estimator per site, paired positionally with K_list / rs.
        self.estimators = [
            MultiChainDPQuantile(
                K=k,
                tau=self.tau,
                r=r,
                mechanism=self.mechanism,
                burn_in_ratio=self.burn_in_ratio,
                c0=self.c0,
                a=self.a,
                b0=self.b0,
                true_q=self.true_q,
            )
            for k, r in zip(self.K_list, self.rs)
        ]
        # Per-site summaries populated by fit(); empty until the first
        # successful fit, which aggregate() uses as its "not fitted" check.
        self.global_means: list[float] = []
        self.global_vars: list[float] = []
        self.global_biases: list[float] = []

    def fit(self, datas: Sequence[Sequence[float]]) -> "TransferDPQuantile":
        """Fit every site's estimator on its own data stream.

        Args:
            datas: One data stream per site, in the same order as ``K_list``.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If the number of streams differs from the number of sites.

        The per-site summaries (``global_means``, ``global_vars``,
        ``global_biases``) are replaced only after *every* site fits
        successfully. If an estimator raises, the previously stored summaries
        stay mutually consistent (though the estimators fitted before the
        failure hold new state).
        """
        if len(datas) != len(self.estimators):
            raise ValueError("datas must match the number of sites.")
        means: list[float] = []
        variances: list[float] = []
        for estimator, stream in zip(self.estimators, datas):
            estimator.fit(stream)
            means.append(estimator.global_mean)
            variances.append(estimator.global_var)
        # Publish atomically. Slice assignment keeps the existing list objects,
        # matching the original in-place clear()/append() behaviour.
        self.global_means[:] = means
        self.global_vars[:] = variances
        target_mean = means[0]  # site 0 is the target / bias reference
        self.global_biases = [mean - target_mean for mean in means]
        return self

    def aggregate(self, lambd: float, strategy: str = "opt"):
        """Combine the fitted per-site estimates with a chosen weighting scheme.

        Args:
            lambd: Tuning parameter forwarded to the weight function.
            strategy: Case-insensitive name of the weighting scheme:
                ``"opt"`` -> ``compute_optimal_weights``,
                ``"cons"`` -> ``compute_conservative_weights``,
                ``"consvar"`` -> ``compute_conservative_variance_weights``
                (which additionally receives ``K_list``).

        Returns:
            ``(weights, estimate, variance)``, where ``weights`` is whatever
            the selected weight function returns, ``estimate`` is
            ``dot(weights, global_means)`` and ``variance`` is
            ``dot(weights, global_vars)``.

        Raises:
            RuntimeError: If called before a successful :meth:`fit`.
            ValueError: If ``strategy`` is not one of the names above.
        """
        if not self.global_means:
            raise RuntimeError("Call fit() before aggregate().")
        strategy = strategy.lower()
        if strategy == "opt":
            weights = compute_optimal_weights(self.global_biases, self.global_vars, lambd)
        elif strategy == "cons":
            weights = compute_conservative_weights(self.global_biases, self.global_vars, lambd)
        elif strategy == "consvar":
            weights = compute_conservative_variance_weights(
                self.global_biases, self.global_vars, self.K_list, lambd
            )
        else:
            raise ValueError(f"Unknown aggregation strategy: {strategy}")
        estimate = float(np.dot(weights, self.global_means))
        # NOTE(review): this is sum_i w_i * var_i (linear in the weights). If
        # the site estimates are independent and global_var is the variance of
        # each site's estimate, the variance of the weighted sum would be
        # sum_i w_i**2 * var_i. Confirm which quantity is intended.
        variance = float(np.dot(weights, self.global_vars))
        return weights, estimate, variance

