import numpy as np
import sys
from util_di import *
from TransferDPQuantile import TransferDPQuantile
from DPQuantile import DPQuantile
import ray
import time
import itertools
import pickle as pkl
from pathlib import Path 
import pandas as pd
import copy
from itertools import product
from MultiChainDPQuantile import MultiChainDPQuantile

def train_inv(seed=None,
              dist_type='normal', tau=0.5, rs=None,
              K_base=5, n_samples=1000, n_sites=3,
              source_prop=1.0,
              biases=None,
              burn_in_ratio=0, c0=1, a=0.6, b0=0):
    """Run one replication: simulate federated data, fit, and aggregate three ways.

    Parameters
    ----------
    seed : int or None
        Passed to ``np.random.seed`` before any data is generated.
    dist_type, tau, biases :
        Forwarded to ``generate_federate_data``; ``tau`` is also given to the
        ``TransferDPQuantile`` model.
    n_sites, n_samples, source_prop :
        Forwarded to ``get_n_n_sample`` to obtain the per-site sample sizes.
    K_base :
        Forwarded to ``get_prop_K`` to build the model's ``K_list``.
    rs, burn_in_ratio, c0, a, b0 :
        Forwarded unchanged to the ``TransferDPQuantile`` constructor.

    Returns
    -------
    dict
        ``true_qs`` from the data generator, followed by the
        ``(weights, est, var)`` triple returned by ``model.aggregate`` for
        each scheme. The schemes are ``inv`` (lambd=0, method='opt'), ``mse``
        (lambd=1, method='opt') and ``cons`` (lambd=1, method='cons').
    """
    # Reseed the global NumPy RNG so each replication is reproducible.
    np.random.seed(seed)

    per_site_sizes = get_n_n_sample(n_sites, n_samples, source_prop)
    site_data, true_quantiles = generate_federate_data(
        dist_type, tau, per_site_sizes, biases)
    k_per_site = get_prop_K(site_data, K_base=K_base)

    # The model is given the first site's true quantile;
    # NOTE(review): presumably site 0 is the target site — confirm.
    estimator = TransferDPQuantile(
        K_list=k_per_site,
        rs=rs,
        tau=tau,
        true_q=true_quantiles[0],
        burn_in_ratio=burn_in_ratio,
        c0=c0,
        a=a,
        b0=b0,
    )
    estimator.fit(site_data)

    # Each entry: (output keys for weights/est/var, lambd, method).
    # Kept in the original call order.
    schemes = (
        (('inv_weights', 'inv_est', 'inv_var'), 0, 'opt'),
        (('mse_weights', 'mse_est', 'mse_var'), 1, 'opt'),
        (('cons_weights', 'cons_est', 'cons_var'), 1, 'cons'),
    )

    summary = {'true_qs': true_quantiles}
    for out_keys, lambd, method in schemes:
        weights, estimate, variance = estimator.aggregate(lambd=lambd,
                                                          method=method)
        summary.update(zip(out_keys, (weights, estimate, variance)))
    return summary

@ray.remote
def train_inv_remote(**kwargs):
    """Ray task wrapper around ``train_inv``.

    Invoke as ``train_inv_remote.remote(...)``. Every keyword argument is
    forwarded unchanged to ``train_inv``. The task's result is the dict that
    ``train_inv`` returns.
    """
    replication = train_inv(**kwargs)
    return replication

def run_simulation_inv(n_simu=100, base_seed=2025,
                       dist_type='normal', tau=0.5, rs=None,
                       K_base=5,
                       n_samples=1000, n_sites=3,
                       source_prop=1.0,
                       biases=None,
                       burn_in_ratio=0, c0=1, a=0.6, b0=0):
    """Run ``n_simu`` independent replications of ``train_inv`` as Ray tasks.

    Replication ``i`` uses seed ``base_seed + i``. All other arguments are
    forwarded unchanged to ``train_inv``.

    Returns
    -------
    dict
        ``true_qs`` taken from the first replication. For each aggregation
        scheme (``inv``, ``mse``, ``cons``), the per-replication ``*_weights``,
        ``*_est`` and ``*_var`` values are stacked with ``np.asarray`` along a
        new leading axis of length ``n_simu``.

    Raises
    ------
    ValueError
        If ``n_simu`` is less than 1. In that case there would be no
        replication to read results from; previously this surfaced as an
        ``IndexError`` on ``results[0]``.
    """
    if n_simu < 1:
        raise ValueError(f"n_simu must be at least 1, got {n_simu!r}")

    futures = [
        train_inv_remote.remote(
            seed=base_seed + i,
            dist_type=dist_type, tau=tau, rs=rs,
            K_base=K_base, n_samples=n_samples, n_sites=n_sites,
            source_prop=source_prop,
            biases=biases,
            burn_in_ratio=burn_in_ratio, c0=c0, a=a, b0=b0,
        )
        for i in range(n_simu)
    ]

    # Blocks until every task finishes; results come back in `futures` order.
    results = ray.get(futures)

    # NOTE(review): true_qs is read from the first replication only. This
    # presumes it does not depend on the seed — confirm against
    # generate_federate_data.
    summary = {'true_qs': np.asarray(results[0]['true_qs'])}

    # Same key order as before: inv_*, mse_*, cons_*; each as weights, est, var.
    for scheme in ('inv', 'mse', 'cons'):
        for stat in ('weights', 'est', 'var'):
            key = f'{scheme}_{stat}'
            summary[key] = np.asarray([r[key] for r in results])
    return summary