import numpy as np
from DPQuantile import *
from util import distribute_data
from scipy.stats import t
import math
    
class MultiChainDPQuantile_kchange_log(DPQuantile):
    """
    Multi-chain Differential Privacy Quantile Estimator (inherits from DPQuantile).
    """

    def __init__(self, burn_in_ratio=0.01, c0=2, a=0.51, K=-1, b=0, n_samples=5000000, **kwargs):
        """
        Initialize the multi-chain DP quantile estimator.

        The chain-count schedule grows with ``log10`` of the sample index
        (8 chains per decade):

        * ``K_all = int(8 * log10(n_samples))`` is the total number of chains.
        * ``Knum_cur = int(8 * log10(n_samples / 5))`` is the current number
          of chains (recording starts at the 1/5 position).
        * ``incr_k_with_t`` maps a sample index ``t`` to the chain count ``i``
          that takes effect there. It has one entry for every ``i`` in
          ``(Knum_cur, K_all]`` with ``i >= 1``. ``t`` is the smallest
          multiple of ``i`` that is ``>= ceil(10 ** (i / 8))``.

        Parameters
        ----------
        burn_in_ratio : float
            Hyperparameter about ratio of burn-in samples among total samples.
            It is forwarded to the base class and copied onto every chain.
        c0 : float
            Initial learning rate coefficient.
        a : float
            Exponent for the learning rate schedule.
        K : int
            Number of chains. If negative, ``K_all`` chains are created
            (automatic schedule). If positive, exactly ``K`` chains are created.
        b : float
            Offset added to the denominator in the learning rate schedule.
        n_samples : int
            Number of samples. Also determines ``K_all``, ``Knum_cur`` and
            ``incr_k_with_t``.
        kwargs : dict
            Other keyword arguments for the base class.

        Raises
        ------
        ValueError
            If ``K == 0``, since no chains would be created.
        """
        super().__init__(burn_in_ratio=burn_in_ratio, **kwargs)
        self.c0 = c0  # initial learning-rate coefficient
        self.a = a    # exponent in learning rate
        self.b = b    # offset in the learning-rate denominator

        # Initialize number of chains
        self.K = K
        self.K_all = int(np.log10(n_samples) * 8)         # Total number of chains
        self.Knum_cur = int(np.log10(n_samples / 5) * 8)  # Current number of chains, start recording at 1/5 position
        self.incr_k_with_t = {}  # Record nodes where K changes: {t: new chain count}
        self.n_samples = n_samples

        # Build the K-change schedule. Chain counts below 1 are skipped: they
        # are not meaningful, and i == 0 would make the modulus below divide
        # by zero (this happens when n_samples is very small).
        for i in range(max(self.Knum_cur + 1, 1), self.K_all + 1):
            start_key = math.ceil(10 ** (i / 8))
            # Smallest integer >= start_key that is divisible by i. Use the same
            # divisor for the remainder and the adjustment so the result really
            # is a multiple of i.
            remainder = start_key % i
            new_key = start_key if remainder == 0 else start_key + (i - remainder)
            self.incr_k_with_t[int(new_key)] = i

        if self.K < 0:
            n_chains = self.K_all
        elif self.K > 0:
            n_chains = self.K
        else:
            # Previously this left self.chains unset and failed later with an
            # AttributeError; fail fast with a clear message instead.
            raise ValueError(
                "K must be non-zero: use a negative value for the automatic "
                "chain schedule or a positive number of chains."
            )
        self.chains = [DPQuantile(tau=self.tau, r=self.r,
                                  true_q=self.true_q)
                       for _ in range(n_chains)]
        for chain in self.chains:
            chain.reset()
            chain.burn_in_ratio = self.burn_in_ratio

    def _get_x(self, chain_idx):
        """
        Fetch the next data point from one chain's data stream.

        Parameters
        ----------
        chain_idx : int
            Index into ``self.data_streams`` selecting the chain's iterator.

        Returns
        -------
        x : object or None
            The next item produced by ``self.data_streams[chain_idx]``, or
            ``None`` once that iterator is exhausted.
        """
        stream = self.data_streams[chain_idx]
        # The two-argument form of next() returns the default instead of
        # raising StopIteration when the stream runs dry.
        return next(stream, None)