import numpy as np
import sys
from util import distribute_data, generate_data
from DPQuantile import *
from MultiChainDPQuantile import MultiChainDPQuantile_kchange_log
from scipy.special import lambertw,gamma

def compute_radius(T, rho=0.001, n_samples=5000000, alpha=0.05, m=1, typ='ub'):
    """
    Compute the confidence-boundary radius for each T.

    Two boundary families are supported:

    * ``typ='ub'``::

          radius = 1.7 * sqrt( log(log(max(2T/m, e))) + 0.72 * log(10.4 / alpha) )

    * ``typ='gm'``::

          radius = sqrt( 2 (T rho^2 + 1) / (T rho^2) * log( sqrt(T rho^2 + 1) / alpha ) )

    Both expressions equal ``sqrt(T)`` times the per-step radius delta_T. The
    per-step radius has an extra ``1/T`` inside the square root. The values
    returned here are therefore on the scale of a sqrt(T)-normalised statistic.
    NOTE(review): confirm callers expect this sqrt(T) scaling.

    Parameters
    ----------
    T : array_like
        Number of steps; scalar, list or np.ndarray. Must be non-zero for
        ``typ='gm'``, because the formula divides by ``T * rho**2``.
    rho : float, optional
        Hyperparameter rho for 'gm' (default is 0.001).
    n_samples : int, optional
        Unused; kept only for backward compatibility of the signature.
    alpha : float, optional
        Significance level for confidence interval (default is 0.05).
    m : array_like, optional
        Parameter m for 'ub'; scalar or array broadcastable against T
        (default is 1).
    typ : str, optional
        Type of boundary calculation ('ub' or 'gm').

    Returns
    -------
    np.ndarray or numpy scalar
        The radius for each T. The shape is the broadcast shape of T (and m
        for 'ub'); a numpy scalar is returned for scalar inputs.

    Raises
    ------
    ValueError
        If ``typ`` is not 'ub' or 'gm'.
    """
    # Coerce to a float array so Python lists/ints behave like ndarrays.
    # Without this, `2 * [..]` would repeat the list instead of scaling it.
    T_values = np.asarray(T, dtype=float)

    if typ == 'ub':
        m_values = np.asarray(m, dtype=float)
        # Clamping the ratio at e makes log(.) >= 1, so the outer log is >= 0.
        clamped_ratio = np.maximum(2.0 * T_values / m_values, np.exp(1))
        log_term = np.log(np.log(clamped_ratio))
        alpha_term = 0.72 * np.log(10.4 / alpha)
        return 1.7 * np.sqrt(log_term + alpha_term)

    if typ == 'gm':
        rho_sq = rho ** 2
        t_rho_sq = T_values * rho_sq
        scale = 2.0 * (t_rho_sq + 1.0) / t_rho_sq
        log_part = np.log(np.sqrt(t_rho_sq + 1.0) / alpha)
        return np.sqrt(scale * log_part)

    raise ValueError(f"Unsupported boundary type: {typ}")

class GADPQuantile_kchange_log(MultiChainDPQuantile_kchange_log):
    """
    Gaussian-approximation DP quantile estimator with a growing number of
    active chains (inherits from MultiChainDPQuantile_kchange_log).

    After post-burn-in updates it appends the weighted cross-chain mean and
    dispersion of the chains' running averages (``chain.Q_avg``) to
    ``global_means`` / ``global_vars``.
    """

    def __init__(self, alpha=0.05, radius_typ='gm', **kwargs):
        """
        Initialize the multi-chain DP quantile estimator.

        Parameters
        ----------
        alpha : float
            Significance level for confidence interval.
        radius_typ : str
            Boundary type stored as ``self.radius_typ``. It is not read inside
            this class; presumably it is passed as ``compute_radius(typ=...)``
            ('ub' or 'gm'). TODO confirm.
        kwargs : dict
            Other keyword arguments for the base class.
        """
        super().__init__(**kwargs)
        self.alpha = alpha
        self.radius_typ = radius_typ
        # Multi-chain statistics
        self.global_means = []      # Cross-chain means at each recorded step
        self.global_vars = []       # Cross-chain variances at each recorded step
        self.total_t_cur = 0        # Current total step count

    def _calculate_weighted_variance_formula(self, current_means, weights):
        """
        Calculate the weighted variance estimate.

        Formula: σ̂²_T = Σ[ω_k * (m_k)²] - [Σ(ω_k * m_k)]²
        where m_k is the k-th element of current_means and ω_k is its weight.

        Parameters
        ----------
        current_means : array_like
            1D array containing means m_k.
        weights : array_like
            1D array containing weights ω_k; must be the same length as
            current_means. Callers pass weights normalised to sum to 1.

        Returns
        -------
        float
            The formula's value, clamped at 0.

        Raises
        ------
        ValueError
            If the inputs are not 1D arrays of the same length.
        """
        means = np.asarray(current_means)
        w = np.asarray(weights)
        if means.shape != w.shape or means.ndim != 1:
            raise ValueError("Input arrays 'current_means' and 'weights' must be 1D arrays of the same length.")
        term1 = np.sum(w * (means ** 2))
        term2 = np.sum(w * means) ** 2
        # With normalised weights this is a variance, so it is >= 0. Subtracting
        # two nearly equal terms can still give a tiny negative value through
        # floating-point cancellation, which would turn sqrt(var) into NaN.
        return max(term1 - term2, 0.0)

    def _compute_global_stats(self, Knum, weights):
        """
        Append cross-chain statistics for the first ``Knum`` chains.

        Parameters
        ----------
        Knum : int
            Number of chains to consider.
        weights : np.ndarray
            Weights for each of those chains (same length as Knum).
        """
        current_means = [chain.Q_avg for chain in self.chains[:Knum]]
        mean = np.average(current_means, weights=weights)
        if len(current_means) > 1:
            var = self._calculate_weighted_variance_formula(current_means, weights)
        else:
            var = 0.0
        self.global_means.append(mean)
        self.global_vars.append(var)

    def _normalized_chain_weights(self, Knum):
        """
        Return the first ``Knum`` chains' ``n`` counts normalised to sum to 1.

        Returns None (after printing a warning) when the counts sum to zero,
        because ``np.average`` cannot use zero-sum weights.
        """
        weights = np.array([c.n for c in self.chains[:Knum]])
        total_weight = np.sum(weights)
        if total_weight == 0:
            print("Warning: Total weight is zero, cannot normalize.")
            return None
        return weights / total_weight

    def _record_global_stats(self, Knum):
        """Record cross-chain stats over the first ``Knum`` chains, if they can be weighted."""
        weights = self._normalized_chain_weights(Knum)
        # Skip this step rather than crash inside np.average
        # (ZeroDivisionError on zero-sum weights).
        if weights is not None:
            self._compute_global_stats(Knum, weights)

    def _sgd_step_on_chain(self, chain_idx):
        """
        Feed the next data point of chain ``chain_idx`` through one update.

        Returns the updated chain, or None if the chain's data stream is
        exhausted.
        """
        x = self._get_x(chain_idx)
        if x is None:
            return None
        chain = self.chains[chain_idx]
        delta = chain._compute_gradient(x)
        lr = lr_schedule(chain.step + 1, c0=self.c0, a=self.a)
        chain._update_estimator(delta, lr)
        return chain

    def _get_x(self, chain_idx):
        """
        Get the next data point from a chain's data stream.

        Parameters
        ----------
        chain_idx : int
            Index of the chain.

        Returns
        -------
        object or None
            Next data point for the chain, or None if data is exhausted.
        """
        try:
            return next(self.data_streams[chain_idx])
        except StopIteration:
            return None

    def fit(self, dist_type, tau, n_samples):
        """
        Multi-chain training loop with a growing number of active chains.

        Data are generated once, split into ``self.K_all`` streams (one per
        chain), and one point is consumed per iteration. Chain selection:

        * Once ``total_t_cur`` reaches the smallest pending entry of
          ``self.incr_k_with_t``, every iteration feeds the newest chain
          (index ``self.Knum_cur``) until it is marked as added.
        * Otherwise chains ``0 .. Knum_cur-1`` are visited round-robin.

        The loop returns early as soon as a chain's stream is exhausted.

        Parameters
        ----------
        dist_type : str
            The type of distribution.
        tau : float
            The quantile to compute.
        n_samples : int
            Number of samples.
        """
        self.reset()
        # NOTE(review): total_t_cur / global_means / global_vars are set in
        # __init__. Confirm reset() clears them; otherwise a second fit()
        # continues counting from the previous run.

        # Burn-in phase: calculate number of burn-in samples
        burn_in = int(self.n_samples * (1.0 / self.r ** 2) * (1.0 / (100.0 * self.burn_in_ratio)))

        # Trim to a multiple of K_all, presumably so distribute_data can split it evenly.
        sample_quzhen = (n_samples // self.K_all) * self.K_all
        burn_in_per_chain_quzhen = burn_in // self.Knum_cur
        print(f"Number of burn-in samples per chain: {burn_in_per_chain_quzhen}, total burn-in samples: {burn_in}, total samples: {sample_quzhen}")
        data_stream, _true_q = generate_data(dist_type, tau, sample_quzhen)
        data_streams = distribute_data(data_stream, self.K_all)
        self.max_steps = len(data_streams[0])
        self.data_streams = [iter(data) for data in data_streams]
        del data_streams

        for ii in range(n_samples):
            self.total_t_cur += 1
            growing = len(self.incr_k_with_t) > 0 and self.total_t_cur >= min(self.incr_k_with_t)

            if not growing:
                # Fixed set of active chains: round-robin over them.
                chain = self._sgd_step_on_chain(ii % self.Knum_cur)
                if chain is None:
                    return
                if chain.step > burn_in_per_chain_quzhen:
                    chain._update_stats()
                    if chain.n > 1:
                        self._record_global_stats(self.Knum_cur)
                continue

            # Growth phase: feed only the chain being added.
            chain = self._sgd_step_on_chain(self.Knum_cur)
            if chain is None:
                return
            if chain.step > burn_in_per_chain_quzhen:
                # The new chain is past burn-in, so it joins the statistics.
                chain._update_stats()
                self._record_global_stats(self.Knum_cur + 1)
            else:
                self._record_global_stats(self.Knum_cur)

            # This branch runs on every iteration from t = first_t onward. After
            # int(first_t / Knum_cur) iterations the new chain is marked as added.
            first_t = min(self.incr_k_with_t)
            if self.total_t_cur == (int(first_t) + int(first_t / self.Knum_cur) - 1):
                print(f"Delete the first element {first_t}")
                # NOTE(review): this deletes by key, which is correct if
                # incr_k_with_t is a dict keyed by step. On a list it would
                # delete by index; confirm the base-class type.
                del self.incr_k_with_t[first_t]
                self.Knum_cur += 1

class GADPQuantile_kfix(MultiChainDPQuantile_kchange_log):
    """
    Multi-chain Differential Privacy Quantile Estimator with a fixed number of
    chains ``self.K`` (inherits from MultiChainDPQuantile_kchange_log).

    After post-burn-in updates it appends the weighted cross-chain mean and
    the T/K-scaled cross-chain variance of the chains' running averages
    (``chain.Q_avg``) to ``global_means`` / ``global_vars``.
    """

    def __init__(self, alpha=0.05, **kwargs):
        """
        Initialize the multi-chain DP quantile estimator.

        Parameters
        ----------
        alpha : float
            Significance level for confidence interval.
        kwargs : dict
            Other keyword arguments for the base class.
        """
        super().__init__(**kwargs)
        self.alpha = alpha

        # Multi-chain statistics
        self.global_means = []      # Cross-chain means at each recorded step
        self.global_vars = []       # Cross-chain variances at each recorded step
        self.total_t_cur = 0        # Current total step count

    def _calculate_weighted_variance_formula(self, current_means, weights):
        """
        Calculate the weighted variance estimate.

        Formula: σ̂²_T = Σ[ω_k * (m_k)²] - [Σ(ω_k * m_k)]²
        where m_k is the k-th element of current_means and ω_k is its weight.

        Parameters
        ----------
        current_means : array_like
            1D array containing means m_k.
        weights : array_like
            1D array containing weights ω_k; must be the same length as
            current_means. Callers pass weights normalised to sum to 1.

        Returns
        -------
        float
            The formula's value, clamped at 0.

        Raises
        ------
        ValueError
            If the inputs are not 1D arrays of the same length.
        """
        means = np.asarray(current_means)
        w = np.asarray(weights)
        if means.shape != w.shape or means.ndim != 1:
            raise ValueError("Input arrays 'current_means' and 'weights' must be 1D arrays of the same length.")
        term1 = np.sum(w * (means ** 2))
        term2 = np.sum(w * means) ** 2
        # With normalised weights this is a variance, so it is >= 0. Subtracting
        # two nearly equal terms can still give a tiny negative value through
        # floating-point cancellation, which would turn sqrt(var) into NaN.
        return max(term1 - term2, 0.0)

    def _compute_global_stats(self, Knum, weights):
        """
        Append cross-chain statistics for the first ``Knum`` chains.

        The variance is multiplied by ``total_t_cur / Knum``.

        Parameters
        ----------
        Knum : int
            Number of chains to consider.
        weights : np.ndarray
            Weights for each of those chains (same length as Knum).
        """
        current_means = [chain.Q_avg for chain in self.chains[:Knum]]
        mean = np.average(current_means, weights=weights)
        if len(current_means) > 1:
            var = self._calculate_weighted_variance_formula(current_means, weights) * (self.total_t_cur / Knum)
        else:
            var = 0.0
        self.global_means.append(mean)
        self.global_vars.append(var)

    def _normalized_chain_weights(self, Knum):
        """
        Return the first ``Knum`` chains' ``n`` counts normalised to sum to 1.

        Returns None (after printing a warning) when the counts sum to zero,
        because ``np.average`` cannot use zero-sum weights.
        """
        weights = np.array([c.n for c in self.chains[:Knum]])
        total_weight = np.sum(weights)
        if total_weight == 0:
            print("Warning: Total weight is zero, cannot normalize.")
            return None
        return weights / total_weight

    def _record_global_stats(self, Knum):
        """Record cross-chain stats over the first ``Knum`` chains, if they can be weighted."""
        # Weights and means are taken from the same chains[:Knum] slice, so
        # their lengths always match inside np.average.
        weights = self._normalized_chain_weights(Knum)
        if weights is not None:
            self._compute_global_stats(Knum, weights)

    def _sgd_step_on_chain(self, chain_idx):
        """
        Feed the next data point of chain ``chain_idx`` through one update.

        Returns the updated chain, or None if the chain's data stream is
        exhausted.
        """
        x = self._get_x(chain_idx)
        if x is None:
            return None
        chain = self.chains[chain_idx]
        delta = chain._compute_gradient(x)
        lr = lr_schedule(chain.step + 1, c0=self.c0, a=self.a)
        chain._update_estimator(delta, lr)
        return chain

    def _get_x(self, chain_idx):
        """
        Get the next data point from a chain's data stream.

        Parameters
        ----------
        chain_idx : int
            Index of the chain.

        Returns
        -------
        object or None
            Next data point for the chain, or None if data is exhausted.
        """
        try:
            return next(self.data_streams[chain_idx])
        except StopIteration:
            return None

    def fit(self, dist_type, tau, n_samples):
        """
        Multi-chain training loop over a fixed set of ``self.K`` chains.

        Data are generated once and split into ``self.K`` streams. Chains are
        visited round-robin, one point per iteration. The loop returns early
        as soon as a chain's stream is exhausted.

        Parameters
        ----------
        dist_type : str
            The type of distribution.
        tau : float
            The quantile to compute.
        n_samples : int
            Number of samples.
        """
        self.reset()
        # NOTE(review): total_t_cur / global_means / global_vars are set in
        # __init__. Confirm reset() clears them; otherwise a second fit()
        # continues counting from the previous run.

        # Burn-in phase: calculate number of burn-in samples
        burn_in = int(self.n_samples * (1.0 / self.r ** 2) * (1.0 / (100.0 * self.burn_in_ratio)))

        # Trim to a multiple of K, presumably so distribute_data can split it evenly.
        sample_quzhen = (n_samples // self.K) * self.K
        burn_in_per_chain_quzhen = burn_in // self.K
        print(f"Number of burn-in samples per chain: {burn_in_per_chain_quzhen}, total burn-in samples: {burn_in}, total samples: {sample_quzhen}")
        data_stream, _true_q = generate_data(dist_type, tau, sample_quzhen)
        data_streams = distribute_data(data_stream, self.K)
        self.max_steps = len(data_streams[0])
        self.data_streams = [iter(data) for data in data_streams]
        del data_streams

        for ii in range(n_samples):
            self.total_t_cur += 1
            chain = self._sgd_step_on_chain(ii % self.K)
            if chain is None:  # Data exhausted
                return
            if chain.step > burn_in_per_chain_quzhen:
                chain._update_stats()
                if chain.n > 1:
                    self._record_global_stats(self.K)