import numpy as np
import sys
from util_di import *
from TransferDPQuantile import TransferDPQuantile,cv_select_lambda
from DPQuantile import DPQuantile
import ray
import time
import itertools
import pickle as pkl
from pathlib import Path
import pandas as pd
import copy 
from itertools import product

def train_transfer(seed=None,
                   dist_type='normal', tau=0.5, rs=None,
                   K_base=5, n_samples=1000, n_sites=3,
                   source_prop=1.0,            # sample-size ratio (source vs. target)
                   biases=None,
                   lambd_grid=None,
                   burn_in_ratio=0, c0=1, a=0.6, b0=0):
    """
    End-to-end training pipeline for **one** replication.

    Steps
    -----
    1. Generate a data stream for each site (`datas`) together with its
       population quantile (`true_qs`) via ``generate_federate_data``.
       Site 0 is treated as the target: ``true_qs[0]`` is the reference
       quantile handed to both the CV routine and the model.
    2. Run cross-validation over `lambd_grid` (``cv_select_lambda``) to
       select λ.
    3. Fit a ``TransferDPQuantile`` model on all sites and aggregate the
       site-specific estimates with ``method='opt'`` at the CV-selected λ.

    Parameters
    ----------
    seed : int or None
        Seed for the *global* NumPy RNG (``np.random.seed``), set before any
        data are generated.
    dist_type, tau, rs, K_base, n_samples, n_sites, source_prop :
        Forwarded to the data-generation / model helpers; see *caller* for
        details (`rs` holds per-site privacy levels, cf. ``generate_rs_s``).
    biases : sequence of float or None
        Per-site shifts forwarded to ``generate_federate_data``.  By
        convention ``biases[0] == 0`` for the target site (not enforced here).
    lambd_grid : sequence of float or None
        Grid of λ values tested in cross-validation.  ``None`` (the default)
        means ``[0.1]``.
    burn_in_ratio, c0, a, b0 :
        Learning-rate hyper-parameters shared by all LDP chains.

    Returns
    -------
    dict with keys
        'true_qs'     : population quantiles per site
        'opt_weights' : optimal aggregation weights (sum = 1)
        'opt_est'     : aggregated estimate
        'opt_var'     : aggregated variance
    """
    # A list literal as a default would be one object shared by every call in
    # this process; build a fresh grid per call instead.
    if lambd_grid is None:
        lambd_grid = [0.1]

    np.random.seed(seed)

    # ------------------------------------------------------------------
    # (1) Generate federated data
    # ------------------------------------------------------------------
    n_n_samples = get_n_n_sample(n_sites, n_samples, source_prop)
    datas, true_qs = generate_federate_data(dist_type, tau, n_n_samples,
                                            biases)
    K_list = get_prop_K(datas, K_base=K_base)

    # ------------------------------------------------------------------
    # (2) Cross-validation for λ
    # ------------------------------------------------------------------
    opt_lambd = cv_select_lambda(datas, K_list=K_list, rs=rs,
                                 lambd_grid=lambd_grid,
                                 tau=tau, true_q=true_qs[0],
                                 burn_in_ratio=burn_in_ratio,
                                 c0=c0, a=a, b0=b0)

    # ------------------------------------------------------------------
    # (3) Train and aggregate
    # ------------------------------------------------------------------
    model = TransferDPQuantile(K_list=K_list,
                               rs=rs, tau=tau, true_q=true_qs[0],
                               burn_in_ratio=burn_in_ratio,
                               c0=c0, a=a, b0=b0)
    model.fit(datas)

    opt_weights, opt_est, opt_var = model.aggregate(lambd=opt_lambd,
                                                    method='opt')

    return {
        'true_qs': true_qs,
        'opt_weights': opt_weights,
        'opt_est': opt_est,
        'opt_var': opt_var,
    }

@ray.remote
def train_transfer_remote(**kwargs):
    """
    Ray task wrapper around ``train_transfer``.

    Every keyword argument is forwarded unchanged, and the dict produced by
    ``train_transfer`` for that single replication is returned as-is.
    """
    replication_result = train_transfer(**kwargs)
    return replication_result

# --------------------------------------------------------------------------- #
# 2. Parallel simulation over `n_simu` runs                                   #
# --------------------------------------------------------------------------- #

def run_simulation_transfer(n_simu=100, base_seed=2025,
                            dist_type='normal', tau=0.5, rs=None,
                            K_base=5,
                            n_samples=1000, n_sites=3,
                            source_prop=1.0,
                            biases=None,
                            lambd_grid=None,
                            burn_in_ratio=0, c0=1, a=0.6, b0=0):
    """
    Launch `n_simu` independent replications in parallel (Ray).

    Replication ``i`` is run with ``seed = base_seed + i``; all other
    arguments are forwarded unchanged to ``train_transfer_remote``.

    Parameters
    ----------
    n_simu : int
        Number of replications; must be >= 1.
    base_seed : int
        Seed of the first replication.
    lambd_grid : sequence of float or None
        λ grid for cross-validation.  ``None`` (the default) means ``[0.1]``.
    Remaining parameters :
        See ``train_transfer``.

    Returns
    -------
    dict with stacked NumPy arrays:
        'true_qs'     shape (n_sites,)  – taken from the first replication
        'opt_weights' shape (n_simu, n_sites)
        'opt_est'     shape (n_simu,)
        'opt_var'     shape (n_simu,)

    Raises
    ------
    ValueError
        If ``n_simu < 1`` (there would be no replication to read
        ``true_qs`` from).
    """
    if n_simu < 1:
        raise ValueError(f"n_simu must be a positive integer, got {n_simu!r}")
    # Resolved here (not in train_transfer) so the remote tasks always
    # receive a concrete grid.
    if lambd_grid is None:
        lambd_grid = [0.1]

    futures = [
        train_transfer_remote.remote(
            seed=base_seed + i,
            dist_type=dist_type, tau=tau, rs=rs,
            K_base=K_base, n_samples=n_samples, n_sites=n_sites,
            source_prop=source_prop,
            biases=biases,
            lambd_grid=lambd_grid,
            burn_in_ratio=burn_in_ratio, c0=c0, a=a, b0=b0,
        )
        for i in range(n_simu)
    ]

    results = ray.get(futures)

    # NOTE(review): true_qs is read from the first replication only; this
    # presumes the population quantiles do not depend on the seed — confirm.
    return dict(
        true_qs=np.asarray(results[0]['true_qs']),
        opt_weights=np.asarray([r['opt_weights'] for r in results]),
        opt_est=np.asarray([r['opt_est'] for r in results]),
        opt_var=np.asarray([r['opt_var'] for r in results]),
    )

# --------------------------------------------------------------------------- #
#  Helper routines to build bias / r-value grids                            #
# --------------------------------------------------------------------------- #
def generate_biases(step_small=0.005, step_big=0.5, n_sites=4):
    """
    Generate *monotone* bias patterns for experiments.

    Rules
    -----
    * First element (target site) is always 0.
    * Source-site biases are non-decreasing and can only take
      values {0, step_small, step_big}.
    * Coinciding levels (e.g. ``step_small == 0`` or
      ``step_small == step_big``) are merged, so every returned pattern
      is distinct.

    Ordering
    --------
    Patterns are sorted by the level of the *last* source site, then the
    second-to-last, and so on (levels ranked by value), e.g. for
    ``n_sites=3``::

        [0, 0, 0], [0, 0, s], [0, s, s], [0, 0, b], [0, s, b], [0, b, b]

    Parameters
    ----------
    step_small, step_big : float
        The two non-zero bias levels.
    n_sites : int
        Total number of sites (target + sources); must be >= 1.

    Returns
    -------
    list[list[float]]
        One bias vector of length `n_sites` per pattern.
    """
    # Dedupe and rank by value so the monotonicity constraint and the sort
    # key below use the same ordering.
    levels = sorted({0.0, step_small, step_big})
    n_source = n_sites - 1

    # combinations_with_replacement over an ascending sequence yields exactly
    # the non-decreasing tuples, each once.
    combos = list(itertools.combinations_with_replacement(levels, n_source))

    rank = {v: i for i, v in enumerate(levels)}
    combos.sort(key=lambda c: tuple(rank[v] for v in reversed(c)))

    return [[0.0] + list(c) for c in combos]

def generate_biases_cont(end_val, num_points, n_sites, start_val=1e-5):
    """
    Continuous bias grid – log-spaced between `start_val` (default 1e-5)
    and `end_val`.

    Note: an earlier implementation called
    ``np.logspace(-5, np.log(end_val), base=np.e)``, whose first point is
    e**-5 ≈ 6.7e-3 rather than the documented 1e-5; ``np.geomspace`` takes
    the endpoints directly and avoids that base mix-up.

    Parameters
    ----------
    end_val : float
        Largest bias (inclusive); must be > 0.
    num_points : int
        Number of grid points.
    n_sites : int
        Total number of sites; every source site gets the same bias.
    start_val : float, optional
        Smallest bias (inclusive); must be > 0.

    Returns
    -------
    list[list[float]]
        ``[0.0, x, ..., x]`` (length `n_sites`) for each grid value ``x``.

    Raises
    ------
    ValueError
        If `start_val` or `end_val` is not strictly positive.
    """
    if start_val <= 0 or end_val <= 0:
        raise ValueError(
            f"start_val and end_val must be > 0, got {start_val!r}, {end_val!r}"
        )
    seq = np.geomspace(start_val, end_val, num_points)
    return [[0.0] + [float(x)] * (n_sites - 1) for x in seq]

def generate_rs_s(target=0.5, r_start=0.2, r_end=0.9, n=8, n_sites=4):
    """
    Build a list of `rs` vectors that vary the source-site privacy level.

    Each vector has length `n_sites`: position 0 (the target site) is always
    `target`, and every source site shares one r value.  The shared value
    runs over ``np.linspace(r_start, r_end, n)``, so `n` vectors are returned.
    """
    n_source = n_sites - 1
    grid = np.linspace(r_start, r_end, n)

    rs_list = []
    for r_val in grid:
        shared = float(r_val)
        rs_list.append([target] + [shared] * n_source)
    return rs_list