import numpy as np
import sys
from MultiChainDPQuantile import MultiChainDPQuantile
from sklearn.model_selection import KFold
import cvxpy as cp
from scipy.stats import chi2, t
from util_di import quantile_plugin 

def compute_lasso_weights(biases, vars_, lambd=0.1):
    """
    Compute simplex-constrained site weights by solving a convex program
    with cvxpy:

        minimize_w  Σ σ_k² w_k²  +  λ Σ b_k² w_k
        subject to  w_k ≥ 0,  Σ w_k = 1

    Parameters
    ----------
    biases : array-like
        Per-site biases b_k; they enter the objective squared.
    vars_ : array-like
        Per-site variances σ_k².
    lambd : float, default=0.1
        Coefficient λ on the linear squared-bias penalty.

    Returns
    -------
    The ``.value`` of the cvxpy weight variable after ``solve()``.
    NOTE(review): solver status is not checked here — presumably the value
    is unusable if the solve fails; confirm against the cvxpy docs.
    """
    sq_bias = np.asarray(biases) ** 2
    variances = np.asarray(vars_)
    n_sites = len(sq_bias)

    weights = cp.Variable(n_sites)

    # Quadratic variance term plus a linear penalty on squared bias.
    variance_term = cp.sum(cp.multiply(variances, cp.square(weights)))
    bias_term = (lambd * sq_bias) @ weights

    problem = cp.Problem(
        cp.Minimize(variance_term + bias_term),
        [weights >= 0, cp.sum(weights) == 1],
    )
    problem.solve()
    return weights.value

def compute_optimal_weights(biases, vars_, lambd=0.1):
    """
    Closed-form inverse-MSE-style weights, normalised to sum to one:

        w_k = ( 1 / (σ_k² + λ b_k²) )  /  Σ_j 1 / (σ_j² + λ b_j²)

    Parameters
    ----------
    biases : array-like
        Per-site biases b_k.
    vars_ : array-like
        Per-site variances σ_k².
    lambd : float, default=0.1
        Coefficient λ multiplying the squared bias.

    Returns
    -------
    ndarray
        Weights with the same length as the inputs.
    """
    bias_arr = np.asarray(biases)
    var_arr = np.asarray(vars_)
    # Reciprocal of the penalised error of each site.
    precision = 1.0 / (var_arr + lambd * bias_arr ** 2)
    return precision / precision.sum()

def compute_consvar_weights(biases, vars_, lambd=1, K_list=None,
                         alpha_var=0.0001):
    """
    Variance-conservative weights.

    Every site except the first (index 0) has its variance inflated to
    the chi-square based upper bound

        σ̃_k² = σ_k² · (K_k − 1) / χ²_{alpha_var}(K_k − 1)

    and the weights are then

        w_k ∝ 1 / (σ̃_k² + λ b_k²),   normalised to sum to one.

    Parameters
    ----------
    biases : array-like
        Per-site biases b_k.
    vars_ : array-like
        Per-site variances σ_k².  Integer input is accepted; it is
        converted to float before the in-place inflation.
    lambd : float, default=1
        Coefficient λ multiplying the squared bias.
    K_list : array-like of int
        Per-site counts K_k used as ``K_k − 1`` degrees of freedom.
        Required despite the ``None`` default kept for signature
        compatibility.
    alpha_var : float, default=0.0001
        Lower-tail probability passed to ``chi2.ppf``.

    Returns
    -------
    ndarray
        Weights with the same length as the inputs.

    Raises
    ------
    ValueError
        If ``K_list`` is not supplied.
    """
    if K_list is None:
        raise ValueError(
            "K_list is required to compute variance-conservative weights")

    b = np.asarray(biases, dtype=float)
    # np.array(...) always copies, so the caller's variances are never
    # mutated, and dtype=float keeps the in-place scaling below legal for
    # integer input.
    vars_t = np.array(vars_, dtype=float)
    K_arr = np.asarray(K_list)

    # (1) Conservative variance upper bound σ̃_k² for the non-first sites.
    dof = K_arr[1:] - 1
    chi2_q = chi2.ppf(alpha_var, dof)
    vars_t[1:] *= dof / chi2_q        # index 0 is left unchanged

    # (2) Final weights.
    inv = 1.0 / (vars_t + lambd * b ** 2)
    return inv / inv.sum()

def compute_cons_weights(biases, vars_, lambd=1, z_q=1.96):
    """
    Bias-conservative weights.

    For every site except the first (index 0) the bias is replaced by

        b̃_k = |b_k| + z_q · sqrt(σ_k² + σ_0²)

    and the weights are then

        w_k ∝ 1 / (σ_k² + λ b̃_k²),   normalised to sum to one.

    Parameters
    ----------
    biases : array-like
        Per-site biases b_k.  Integer input is accepted; it is converted
        to float so the inflated biases are not truncated on assignment.
    vars_ : array-like
        Per-site variances σ_k².
        NOTE(review): presumably these are already variances of the site
        estimates (i.e. divided by sample size) — confirm with the caller.
    lambd : float, default=1
        Coefficient λ multiplying the squared (inflated) bias.
    z_q : float, default=1.96
        Multiplier applied to the standard error of the bias estimate.

    Returns
    -------
    ndarray
        Weights with the same length as the inputs.
    """
    b = np.asarray(biases, dtype=float)
    v = np.asarray(vars_, dtype=float)

    # Standard error of (site k − site 0), treating the two as independent.
    se = np.sqrt(v[1:] + v[0])

    b_t = b.copy()
    b_t[1:] = np.abs(b[1:]) + z_q * se          # index 0 keeps its bias

    inv = 1.0 / (v + lambd * b_t ** 2)
    return inv / inv.sum()

def get_loss(data,q_est,tau=0.5):
    """
    Check loss for a scalar quantile estimate (pinball / check function).
    """
    diff = data - q_est
    loss = np.where(diff >= 0, tau * diff, (1 - tau) * (-diff))
    return loss.mean()



class TransferDPQuantile:
    """
    Fit one ``MultiChainDPQuantile`` estimator per site and combine the
    per-site estimates with bias/variance-aware weights.

    Site index 0 is the reference ("target") site: biases are measured
    relative to its estimate.
    """

    # Weighting schemes understood by ``aggregate``.
    _METHODS = ('opt', 'lasso', 'cons', 'consvar')

    def __init__(self, K_list, rs, **kwargs):
        """
        Parameters
        ----------
        K_list : list[int]
            Numbers of chains, one entry per site; index 0 is the
            *target* site and the remaining entries are *source* sites.
        rs : list[float]
            Randomized-response rates for each site.
        **kwargs :
            Additional keyword arguments passed to ``MultiChainDPQuantile``.
        """
        self.K_list = K_list
        self.n_sites = len(K_list)
        self.estimators = [MultiChainDPQuantile(K=K_list[i],r=rs[i],
                                                **kwargs) for i in range(self.n_sites)]

    def fit(self, datas):
        """
        Fit every site's estimator and cache the per-site summaries.

        Parameters
        ----------
        datas : sequence
            ``datas[k]`` is handed to the k-th estimator's ``fit``.

        Returns
        -------
        self
            With ``global_means``, ``global_biases`` (each site's mean
            minus the target's mean, so entry 0 is zero) and
            ``global_vars`` populated from the fitted estimators.
        """
        for est, site_data in zip(self.estimators, datas[:self.n_sites]):
            est.fit(site_data)

        target_mean = self.estimators[0].global_mean
        self.global_means = [est.global_mean for est in self.estimators]
        self.global_biases = [est.global_mean - target_mean for est in self.estimators]
        self.global_vars = [est.global_var for est in self.estimators]
        return self

    def aggregate(self,lambd=1,method='opt'):
        """
        Aggregate site-specific estimates using one of several weighting
        schemes.

        Parameters
        ----------
        lambd : float, default=1
            Regularization coefficient λ.
        method : {'opt', 'lasso', 'cons', 'consvar'}, default='opt'
            Weight-construction strategy.

        Returns
        -------
        weights : ndarray of shape (n_sites,)
        estimate : float
            Aggregated point estimate (weighted sum of the site means).
        var : float
            Aggregated variance, computed as Σ w_k σ_k².

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported schemes.
        """
        # Reject unknown schemes up front; otherwise ``weights`` would be
        # left unbound and the failure would surface as a confusing
        # UnboundLocalError further down.
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown method {method!r}; expected one of {self._METHODS}")

        if method == 'opt':
            weights = np.array(compute_optimal_weights(self.global_biases, self.global_vars,
                                                       lambd=lambd))
        elif method == 'lasso':
            weights = np.array(compute_lasso_weights(self.global_biases, self.global_vars,
                                                     lambd=lambd))
        elif method == 'cons':
            weights = np.array(compute_cons_weights(self.global_biases, self.global_vars,
                                                    lambd=lambd))
        else:  # 'consvar'
            weights = np.array(compute_consvar_weights(self.global_biases, self.global_vars,
                                                       lambd=lambd,K_list=self.K_list))

        estimate = sum(w * est for w, est in zip(weights, self.global_means))
        # NOTE(review): this is Σ w_k σ_k², not Σ w_k² σ_k² (the variance of
        # a weighted sum of independent estimates).  Kept as-is; confirm
        # whether the more conservative form is intentional.
        var = np.dot(np.array(weights), np.array(self.global_vars))

        return weights,estimate,var


def cv_select_lambda(datas, K_list, rs, lambd_grid,
                     tau=0.5,n_splits=5, **mc_kwargs):
    """
    K-fold cross-validation of λ for the *optimal* ('opt') weighting scheme.

    For each fold, a ``TransferDPQuantile`` model is fitted once on the
    training part of every site; each candidate λ is then scored by the
    check loss of the aggregated estimate on the held-out part of the
    target site (``datas[0]``).

    Parameters
    ----------
    datas : sequence
        Per-site data; index 0 is the target site.  Each entry is indexed
        with integer index arrays, so presumably a NumPy array — confirm
        with the caller.
    K_list, rs :
        Passed through to ``TransferDPQuantile``.
    lambd_grid : sequence of float
        Candidate λ values.
    tau : float, default=0.5
        Quantile level, used both for fitting and for the check loss.
    n_splits : int, default=5
        Number of (unshuffled) folds.
    **mc_kwargs :
        Extra keyword arguments forwarded to ``TransferDPQuantile``.

    Returns
    -------
    float
        λ that minimises the mean validation loss.
    """
    # One list of (train_idx, test_idx) pairs per site.
    splitters = [
        list(KFold(n_splits=n_splits, shuffle=False).split(np.arange(len(d))))
        for d in datas]
    losses_opt = np.zeros((len(lambd_grid), n_splits))

    for fold in range(n_splits):
        # ① training portion of every site
        train_sets = [
            d[splitters[s][fold][0]]
            for s, d in enumerate(datas)]
        # ② held-out portion of the *target* site
        test_tar = datas[0][splitters[0][fold][1]]

        # The fit does not depend on λ and ``aggregate`` only reads the
        # fitted state, so fit once per fold and reuse it for every λ.
        model = TransferDPQuantile(K_list=K_list, rs=rs,
                                   tau=tau, **mc_kwargs)
        model.fit(train_sets)

        for j, lam in enumerate(lambd_grid):
            _, q_hat_opt, _ = model.aggregate(lambd=lam, method='opt')
            losses_opt[j, fold] = get_loss(test_tar, q_hat_opt, tau=tau)

    mean_loss_opt = losses_opt.mean(axis=1)
    return lambd_grid[mean_loss_opt.argmin()]

def cv_select_lambda_lasso(datas, K_list, rs, lambd_grid,
                     tau=0.5,n_splits=5, **mc_kwargs):
    """
    K-fold cross-validation of λ for *both* the optimal ('opt') and the
    LASSO ('lasso') weighting schemes.

    For each fold, a ``TransferDPQuantile`` model is fitted once on the
    training part of every site; each candidate λ is then scored, for both
    schemes, by the check loss of the aggregated estimate on the held-out
    part of the target site (``datas[0]``).

    Parameters
    ----------
    datas : sequence
        Per-site data; index 0 is the target site.  Each entry is indexed
        with integer index arrays, so presumably a NumPy array — confirm
        with the caller.
    K_list, rs :
        Passed through to ``TransferDPQuantile``.
    lambd_grid : sequence of float
        Candidate λ values.
    tau : float, default=0.5
        Quantile level, used both for fitting and for the check loss.
    n_splits : int, default=5
        Number of (unshuffled) folds.
    **mc_kwargs :
        Extra keyword arguments forwarded to ``TransferDPQuantile``.

    Returns
    -------
    tuple of floats
        (best_λ_opt, best_λ_lasso)
    """
    # One list of (train_idx, test_idx) pairs per site.
    splitters = [
        list(KFold(n_splits=n_splits, shuffle=False).split(np.arange(len(d))))
        for d in datas]
    losses_opt = np.zeros((len(lambd_grid), n_splits))
    losses_lasso = np.zeros((len(lambd_grid), n_splits))

    for fold in range(n_splits):
        train_sets = [
            d[splitters[s][fold][0]]
            for s, d in enumerate(datas)]
        test_tar = datas[0][splitters[0][fold][1]]

        # The fit does not depend on λ and ``aggregate`` only reads the
        # fitted state, so fit once per fold and reuse it for every λ.
        model = TransferDPQuantile(K_list=K_list, rs=rs,
                                   tau=tau, **mc_kwargs)
        model.fit(train_sets)

        for j, lam in enumerate(lambd_grid):
            _, q_hat_opt, _ = model.aggregate(lambd=lam, method='opt')
            losses_opt[j, fold] = get_loss(test_tar, q_hat_opt, tau=tau)

            _, q_hat_lasso, _ = model.aggregate(lambd=lam, method='lasso')
            losses_lasso[j, fold] = get_loss(test_tar, q_hat_lasso, tau=tau)

    best_opt = lambd_grid[losses_opt.mean(axis=1).argmin()]
    best_lasso = lambd_grid[losses_lasso.mean(axis=1).argmin()]
    return best_opt,best_lasso