import numpy as np
import sys
from util_di import *
from TransferDPQuantile import TransferDPQuantile
from DPQuantile import DPQuantile
import ray
import time
import itertools
import pickle as pkl 
from pathlib import Path
import pandas as pd
import copy
from itertools import product
from MultiChainDPQuantile import MultiChainDPQuantile

def train_target(seed=None, seed_target=None,
                 dist_type='normal', tau=0.5, rs=None,
                 K_base=5, n_samples=1000, n_sites=3,
                 source_prop=1.0,
                 biases=None,
                 burn_in_ratio=0, c0=1, a=0.6, b0=0):
    """Fit a target-only ``DPQuantile`` model on one simulated federated dataset.

    Data for all sites are generated after seeding the global NumPy RNG with
    ``seed``. Only the first site (index 0, the target) is used to fit the
    model. The global RNG is reseeded with ``seed_target`` immediately before
    the model is built and fitted.

    NOTE: this function reseeds the *global* ``np.random`` state twice, so it
    affects any later draws made by the caller.

    Parameters
    ----------
    seed : int or None
        Seed applied to ``np.random`` before data generation.
    seed_target : int or None
        Seed applied to ``np.random`` before building/fitting the target model.
    dist_type : str
        Distribution name forwarded to ``generate_federate_data``.
    tau : float
        Quantile level forwarded to ``generate_federate_data`` and ``DPQuantile``.
    rs : sequence
        Required. Only ``rs[0]`` is used; it is passed as ``r`` to
        ``DPQuantile`` (presumably the target's privacy parameter — confirm).
    K_base, c0, a, b0
        Accepted but not used by this function (presumably kept so all
        simulation drivers share one signature — verify against callers).
    n_samples, n_sites, source_prop
        Forwarded to ``get_n_n_sample`` to obtain the per-site sample sizes.
    biases
        Forwarded to ``generate_federate_data``.
    burn_in_ratio
        Forwarded to ``DPQuantile``.

    Returns
    -------
    dict
        ``'true_qs'``: true quantiles as returned by ``generate_federate_data``;
        ``'target_est'``: the fitted model's ``Q_avg``;
        ``'target_var'``: the fitted model's ``get_variance()``.

    Raises
    ------
    ValueError
        If ``rs`` is None or empty. This is checked before any RNG reseeding
        or data generation.
    """
    # Fail fast: the default rs=None would otherwise crash at rs[0] only after
    # the RNG was reseeded and all site data had been generated.
    if rs is None or len(rs) == 0:
        raise ValueError("rs must be a non-empty sequence; rs[0] is used for the target model")

    np.random.seed(seed)

    n_n_samples = get_n_n_sample(n_sites, n_samples, source_prop)
    datas, true_qs = generate_federate_data(dist_type, tau, n_n_samples, biases)

    # Separate seed for the model so its stochasticity is controlled
    # independently of the data-generation seed.
    np.random.seed(seed_target)
    model_target_only = DPQuantile(r=rs[0], tau=tau,
                                   true_q=true_qs[0], burn_in_ratio=burn_in_ratio)
    model_target_only.fit(datas[0])

    target_est = model_target_only.Q_avg
    target_var = model_target_only.get_variance()

    return {
        'true_qs': true_qs,
        'target_est': target_est,
        'target_var': target_var,
    }

@ray.remote
def train_target_remote(**kwargs):
    """Ray task wrapper that runs ``train_target`` with the given keyword arguments.

    Only keyword arguments are accepted. The dict produced by
    ``train_target`` is returned unchanged.
    """
    replicate_result = train_target(**kwargs)
    return replicate_result

def run_simulation_target(n_simu=100, base_seed=2025,
                          base_seed_target=100,
                          dist_type='normal', tau=0.5, rs=None,
                          K_base=5,
                          n_samples=1000, n_sites=3,
                          source_prop=1.0,
                          biases=None,
                          burn_in_ratio=0, c0=1, a=0.6, b0=0):
    """Run ``n_simu`` independent target-only replicates in parallel with Ray.

    Replicate ``i`` is submitted as a ``train_target_remote`` task with
    ``seed=base_seed + i`` (data generation) and
    ``seed_target=base_seed_target + i`` (model fitting). All other
    arguments are forwarded unchanged.

    Returns
    -------
    dict
        ``true_qs``: ``np.asarray`` of the first replicate's true quantiles.
            This assumes the true quantiles do not vary across replicates,
            i.e. that they depend only on ``dist_type``/``tau``/``biases``
            and not on the seed — TODO confirm.
        ``target_est``: per-replicate estimates stacked with ``np.asarray``
            (first axis = replicate, in submission order).
        ``target_var``: per-replicate variance estimates, stacked the same way.

    Raises
    ------
    ValueError
        If ``n_simu < 1`` or ``rs`` is None/empty. Both are checked before
        any Ray task is submitted, so a bad call fails fast instead of
        launching ``n_simu`` tasks that would each fail remotely.
    """
    if n_simu < 1:
        raise ValueError(f"n_simu must be >= 1, got {n_simu}")
    if rs is None or len(rs) == 0:
        raise ValueError("rs must be a non-empty sequence; rs[0] is used for the target model")

    futures = [
        train_target_remote.remote(
            seed=base_seed + i,
            seed_target=base_seed_target + i,
            dist_type=dist_type, tau=tau, rs=rs,
            K_base=K_base, n_samples=n_samples, n_sites=n_sites,
            source_prop=source_prop,
            biases=biases,
            burn_in_ratio=burn_in_ratio, c0=c0, a=a, b0=b0,
        )
        for i in range(n_simu)
    ]
    # ray.get on a list preserves the order of the submitted futures.
    results = ray.get(futures)

    true_qs = np.asarray(results[0]['true_qs'])
    target_est = np.asarray([r['target_est'] for r in results])
    target_var = np.asarray([r['target_var'] for r in results])

    return dict(true_qs=true_qs,
                target_est=target_est,
                target_var=target_var)