import numpy as np
from DPQuantile import *
from util import distribute_data
from scipy.stats import t

def compute_radius(K, alpha=0.05):
    """
    Return the two-sided Student-t critical value used to build a
    (1 − α) confidence interval from K chains.

    Parameters
    ----------
    K : int
        Number of independent chains; the t distribution is evaluated
        with K − 1 degrees of freedom.
    alpha : float, default=0.05
        Significance level.

    Returns
    -------
    float
        The quantile t_{1-α/2, K-1}.

    Notes
    -----
    The value is returned as-is: it is NOT divided by √K. Any scaling by a
    standard error has to be applied by the caller.
    NOTE(review): an earlier version of this docstring advertised
    ``t_{1-α/2, K-1} / √K``; confirm against the callers which of the two
    is intended.
    """
    upper_prob = 1 - alpha/2
    dof = K - 1
    return t.ppf(upper_prob, df=dof)

class MultiChainDPQuantile(DPQuantile):
    """
    Multi-chain differentially-private quantile estimator (inherits from
    ``DPQuantile``).

    The input data is split into ``K`` parts; each part drives its own
    ``DPQuantile`` chain with the learning-rate schedule defined in
    ``_lr_schedule``. After fitting, ``global_mean`` holds the mean of the
    per-chain ``Q_avg`` values and ``global_var`` holds the across-chain
    sample variance divided by ``K``.
    """
    def __init__(self, K=5, burn_in_ratio=0,
                 c0=1, a=0.6, b0=0, **kwargs):
        """
        Parameters
        ----------
        K : int, default=5
            Number of chains.
        burn_in_ratio : float, default=0
            Fraction of the per-chain steps to run before chain statistics
            start being updated.
        c0 : float, default=1
            Learning-rate scale.
        a : float, default=0.6
            Learning-rate exponent.
        b0 : float, default=0
            Optional offset in the learning-rate denominator.
        **kwargs
            Forwarded unchanged to ``DPQuantile.__init__``.
        """
        super().__init__(burn_in_ratio=burn_in_ratio, **kwargs)
        self.K = K    # number of chains
        self.c0 = c0  # learning-rate scale
        self.a = a    # learning-rate exponent
        self.b0 = b0  # optional offset in the denominator

        # Initialize K independent DPQuantile chains sharing this
        # estimator's configuration.
        self.chains = [DPQuantile(tau=self.tau, r=self.r,
                                  true_q=self.true_q, track_history=self.track_history)
                       for _ in range(K)]
        self._reset_chains()

    def _reset_chains(self):
        """Reset every chain and re-apply this estimator's burn-in ratio."""
        for chain in self.chains:
            chain.reset()
            chain.burn_in_ratio = self.burn_in_ratio

    def _lr_schedule(self, step):
        """
        Custom learning-rate schedule for individual chains:
            ηₜ = c₀ / (tᵃ + b₀)

        Parameters
        ----------
        step : int
            Step index t; ``fit`` passes ``chain.step + 1``.

        Returns
        -------
        float
            Learning rate for that step.
        """
        return self.c0 / (step**self.a + self.b0)

    def _compute_global_stats(self):
        """
        Aggregate per-chain statistics to obtain the global mean and the
        variance of the averaged estimator.

        Sets ``self.global_mean`` and ``self.global_var``.
        """
        current_means = [chain.Q_avg for chain in self.chains]
        # Unbiased sample variance across chains (ddof = 1); a single chain
        # has no across-chain spread to estimate, so use 0.
        if len(current_means) > 1:
            var = np.var(current_means, ddof=1)
        else:
            var = 0.0
        self.global_mean = np.mean(current_means)
        # Variance of the average of K chains is var / K
        self.global_var = var / self.K

    def _get_x(self, chain_idx):
        """Return the next observation for chain `chain_idx`; `None` if exhausted."""
        try:
            return next(self.data_streams[chain_idx])
        except StopIteration:
            return None

    def _run_chains(self, max_steps, burn_in):
        """
        Advance all chains round-robin for up to `max_steps` rounds.

        Stops early as soon as any chain's data split is exhausted. Chain
        statistics are only updated for rounds with index >= `burn_in`.
        """
        for ii in range(max_steps):
            for c in range(self.K):
                x = self._get_x(c)
                if x is None:   # data exhausted for this chain
                    return
                chain = self.chains[c]
                delta = chain._compute_gradient(x)
                lr = self._lr_schedule(chain.step + 1)
                chain._update_estimator(delta, lr)
                # Update chain-specific running statistics after burn-in
                if ii >= burn_in:
                    chain._update_stats()

    def fit(self, data_stream):
        """
        Run K independent chains in parallel on disjoint data splits.

        Parameters
        ----------
        data_stream : array-like
            Full data sequence to be partitioned equally among chains.

        Notes
        -----
        The number of rounds is taken from the length of the first split.
        If another split runs out earlier, training stops at that point and
        the global statistics are still computed from the chains' current
        state.
        """
        self.reset()
        # Start every chain from a clean state so that repeated calls to
        # fit() do not continue from the previous run's step counters and
        # running averages.
        self._reset_chains()
        # Split the data stream into K roughly equal parts
        data_streams = distribute_data(data_stream, self.K)
        self.data_streams = [iter(data) for data in data_streams]
        max_steps = len(data_streams[0])
        burn_in = int(max_steps * self.burn_in_ratio)
        self._run_chains(max_steps, burn_in)
        # Always aggregate, even when a split ran out early; otherwise
        # global_mean / global_var would be left unset or stale.
        self._compute_global_stats()