import numpy as np
import sys
from util_di import *
from TransferDPQuantile import TransferDPQuantile,cv_select_lambda_lasso
from DPQuantile import DPQuantile
import ray
import time
import itertools
import pickle as pkl
from pathlib import Path
import pandas as pd 
import copy
from itertools import product

def train_transfer(seed=None,
                   dist_type='normal', tau=0.5, rs=None,
                   K_base=5, n_samples=1000, n_sites=3,
                   source_prop=1.0,           # sample-size ratio (source vs. target)
                   biases=None,
                   lambd_grid=None,
                   burn_in_ratio=0, c0=1, a=0.6, b0=0):
    """
    End-to-end training pipeline for **one** replication.

    Steps
    -----
    1. Generate a data stream for each site (``datas``) together with its
       population quantile (``true_qs``) via ``generate_federate_data``.
       Per-site sample sizes come from ``get_n_n_sample`` and per-site
       ``K`` values from ``get_prop_K``.  Site index 0 is treated as the
       target site: its quantile ``true_qs[0]`` is what the CV routine and
       the model receive as ``true_q``.  How ``biases`` shifts the source
       sites is defined by ``generate_federate_data`` (NOTE(review):
       presumably the target keeps bias 0 — confirm there).

    2. Run ``cv_select_lambda_lasso`` over ``lambd_grid``.  It returns two
       penalties: one used for the ``'opt'`` aggregation and one used for
       the ``'lasso'`` aggregation.

    3. Fit a ``TransferDPQuantile`` model on all sites, then aggregate the
       site-specific estimates twice (``method='opt'`` and
       ``method='lasso'``), each with its own CV-selected λ.

    Parameters
    ----------
    seed : int or None
        Seed for NumPy's global RNG (``np.random.seed``), set before any
        data is generated.
    dist_type, tau, n_samples, n_sites, source_prop, biases :
        Data-generation settings forwarded to ``get_n_n_sample`` /
        ``generate_federate_data``.  See *caller* for details.
    rs :
        Forwarded unchanged to ``cv_select_lambda_lasso`` and
        ``TransferDPQuantile``.
    K_base : int
        Forwarded to ``get_prop_K`` to derive the per-site ``K_list``.
    lambd_grid : sequence of float or None
        Grid of λ values tested in cross-validation.  ``None`` (the
        default) means ``[0.1]``.
    burn_in_ratio, c0, a, b0 :
        Learning-rate hyper-parameters shared by all LDP chains; forwarded
        unchanged to both the CV routine and the final model.

    Returns
    -------
    dict with keys
        'true_qs'       : population quantiles per site, as returned by
                          ``generate_federate_data``
        'opt_weights'   : site weights from ``aggregate(method='opt')``
        'opt_est'       : aggregated estimate from the 'opt' method
        'opt_var'       : aggregated variance from the 'opt' method
        'lasso_weights' : site weights from ``aggregate(method='lasso')``
        'lasso_est'     : aggregated estimate from the 'lasso' method
        'lasso_var'     : aggregated variance from the 'lasso' method
    """
    # Avoid a shared mutable default; [0.1] preserves the old default grid.
    if lambd_grid is None:
        lambd_grid = [0.1]

    # Reseeds NumPy's *global* RNG, so every draw below that goes through
    # np.random becomes reproducible for this replication.
    np.random.seed(seed)

    # ------------------------------------------------------------------
    # (1) Generate federated data
    # ------------------------------------------------------------------
    n_n_samples = get_n_n_sample(n_sites, n_samples, source_prop)

    datas, true_qs = generate_federate_data(dist_type, tau, n_n_samples,
                                            biases)

    K_list = get_prop_K(datas, K_base=K_base)

    # ------------------------------------------------------------------
    # (2) Cross-validation for λ (one value per aggregation method)
    # ------------------------------------------------------------------
    opt_lambd, lasso_lambd = cv_select_lambda_lasso(datas, K_list=K_list, rs=rs,
                                                    lambd_grid=lambd_grid,
                                                    tau=tau, true_q=true_qs[0],
                                                    burn_in_ratio=burn_in_ratio,
                                                    c0=c0, a=a, b0=b0)

    # ------------------------------------------------------------------
    # (3) Train once, aggregate with both methods
    # ------------------------------------------------------------------
    model = TransferDPQuantile(K_list=K_list,
                               rs=rs, tau=tau, true_q=true_qs[0],
                               burn_in_ratio=burn_in_ratio,
                               c0=c0, a=a, b0=b0)

    model.fit(datas)

    opt_weights, opt_est, opt_var = model.aggregate(lambd=opt_lambd,
                                                    method='opt')
    lasso_weights, lasso_est, lasso_var = model.aggregate(lambd=lasso_lambd,
                                                          method='lasso')

    return {
        'true_qs': true_qs,
        'opt_weights': opt_weights,
        'opt_est': opt_est,
        'opt_var': opt_var,
        'lasso_weights': lasso_weights,
        'lasso_est': lasso_est,
        'lasso_var': lasso_var,
    }

@ray.remote
def train_transfer_remote(**kwargs):
    """
    Ray task wrapper around ``train_transfer``.

    Every keyword argument is forwarded unchanged, and the dictionary that
    ``train_transfer`` builds for this replication is returned as-is.
    """
    replication_result = train_transfer(**kwargs)
    return replication_result
# --------------------------------------------------------------------------- #
# 2. Parallel simulation over `n_simu` runs                                   #
# --------------------------------------------------------------------------- #
def run_simulation_transfer(n_simu=100, base_seed=2025,
                            dist_type='normal', tau=0.5, rs=None,
                            K_base=5,
                            n_samples=1000, n_sites=3,
                            source_prop=1.0,
                            biases=None,
                            lambd_grid=None,
                            burn_in_ratio=0, c0=1, a=0.6, b0=0):
    """
    Run ``n_simu`` independent replications of ``train_transfer`` as Ray
    tasks and stack their results.

    Replication ``i`` uses seed ``base_seed + i``.  All other arguments are
    shared by every replication and forwarded unchanged to
    ``train_transfer`` (see its docstring for their meaning).

    Parameters
    ----------
    n_simu : int
        Number of replications; must be at least 1.
    base_seed : int
        Seed of the first replication.
    lambd_grid : sequence of float or None
        Grid of λ values for cross-validation.  ``None`` (the default)
        means ``[0.1]``.

    Returns
    -------
    dict with keys
        'true_qs' : ndarray built from the *first* replication's
                    ``true_qs`` (NOTE(review): assumes population
                    quantiles do not depend on the seed — confirm).
        'opt_weights', 'opt_est', 'opt_var',
        'lasso_weights', 'lasso_est', 'lasso_var' :
                    ndarrays whose first axis has length ``n_simu``, one
                    entry per replication, in seed order.

    Raises
    ------
    ValueError
        If ``n_simu`` is less than 1 (previously surfaced later as an
        ``IndexError`` on an empty result list).
    """
    if n_simu < 1:
        raise ValueError(f"n_simu must be a positive integer, got {n_simu!r}")
    # Avoid a shared mutable default; [0.1] preserves the old default grid.
    if lambd_grid is None:
        lambd_grid = [0.1]

    # Arguments identical for every replication; only the seed varies.
    shared_kwargs = dict(
        dist_type=dist_type, tau=tau, rs=rs,
        K_base=K_base, n_samples=n_samples, n_sites=n_sites,
        source_prop=source_prop,
        biases=biases,
        lambd_grid=lambd_grid,
        burn_in_ratio=burn_in_ratio, c0=c0, a=a, b0=b0,
    )

    futures = [train_transfer_remote.remote(seed=base_seed + i, **shared_kwargs)
               for i in range(n_simu)]

    # ray.get on a list preserves the submission order, so row i of each
    # stacked array corresponds to seed base_seed + i.
    results = ray.get(futures)

    summary = {'true_qs': np.asarray(results[0]['true_qs'])}
    for key in ('opt_weights', 'opt_est', 'opt_var',
                'lasso_weights', 'lasso_est', 'lasso_var'):
        summary[key] = np.asarray([r[key] for r in results])
    return summary