import itertools            
import time                 
from pathlib import Path    

import ray                  
import numpy as np          
import pandas as pd          

from method import generate_rs_s,analyze_results
from method_lasso import run_simulation_transfer

# Start Ray. The current directory is shipped to workers as the working dir,
# and the parent directory is exposed as an importable module path.
ray.init(runtime_env=dict(working_dir=".", py_modules=[".."]))

# Maps each target sample size to the K_base passed to run_simulation_transfer.
N_TO_K = {20_000: 8, 50_000: 10, 200_000: 20}

seed = 2022                 # forwarded as base_seed
dist_type = 'normal'        # also embedded in the output CSV file name
n_sample_targets = [20000]  # every entry must be a key of N_TO_K
n_sim = 1000                # forwarded as n_simu
n_sites = 4

taus = [0.25, 0.5, 0.75]
target = 1
target_norm = 0.25
# Pass n_sites (not a literal) so the generated rs configurations always agree
# with the number of sites simulated and with the length of `biases` below.
rs_s = generate_rs_s(target=target, r_start=target_norm,
                     r_end=0.9, n=9, n_sites=n_sites)

biases = [0] * n_sites      # one entry per site, all zero here
source_props = [3]

# Forwarded unchanged to run_simulation_transfer.
# NOTE(review): document their meaning next to that function.
c0 = 1
a = 0.6
b0 = 0

# Six log-spaced values, one per decade: 1e-4, 1e-3, 1e-2, 1e-1, 1, 10.
lambd_grid = np.logspace(-4, 1, 6)

# Per-configuration accumulators filled by the sweep below.
records = []     # nested result dicts (kept in memory only)
flat_rows = []   # flattened rows that make up the summary CSV

# Number of configurations in the sweep (size of the Cartesian product).
total = (len(n_sample_targets)
         * len(source_props)
         * len(taus)
         * len(rs_s))

# Summary CSV location; the directory is created if missing.
out_dir = Path("output")
out_dir.mkdir(exist_ok=True)
sum_path = out_dir / f"case_r_dist_{dist_type}_r_{target}.csv"

# Sweep every (n_samples, source_prop, tau, rs) combination.
# Each rs is paired with its 1-based position in rs_s up front. The previous
# rs_s.index(rs) lookup was O(len(rs_s)) per iteration, returned the first
# *equal* entry (wrong id for duplicates), and raises for NumPy-array entries.
loop_start = time.perf_counter()
for idx, (n_samples, source_prop, tau, (rs_id, rs)) in enumerate(
        itertools.product(n_sample_targets, source_props, taus,
                          enumerate(rs_s, start=1)),
        start=1):
    K_base = N_TO_K[n_samples]  # KeyError if n_samples is not in N_TO_K
    t0 = time.perf_counter()

    output = run_simulation_transfer(
        n_simu      = n_sim,
        base_seed   = seed,
        dist_type   = dist_type,
        tau         = tau,
        rs          = rs,
        K_base      = K_base,
        n_samples   = n_samples,
        n_sites     = n_sites,
        source_prop = source_prop,
        biases      = biases,
        lambd_grid  = lambd_grid,
        c0=c0, a=a, b0=b0)

    # Summarise the same simulation output for both estimators.
    res_opt   = analyze_results(output, est_key="opt")
    res_lasso = analyze_results(output, est_key="lasso")

    rec = {
        "tau"        : tau,
        "n_samples"  : n_samples,
        "rs"         : rs,
        "rs_id"      : rs_id,
        "source_prop": source_prop,
        "biases"     : biases,
        "opt"        : res_opt,
        "lasso"      : res_lasso}
    records.append(rec)

    # Flatten both result dicts into prefixed columns for the CSV.
    row = {
        "n_samples"  : n_samples,
        "source_prop": source_prop,
        "tau"        : tau,
        "rs"         : rs,
        **{f"opt_{k}"  : v for k, v in res_opt.items()},
        **{f"lasso_{k}": v for k, v in res_lasso.items()}}
    flat_rows.append(row)
    # Rewrite the whole summary after every configuration so partial results
    # are on disk even if the sweep is interrupted.
    df = pd.DataFrame(flat_rows)
    df.to_csv(sum_path, index=False)

    elapsed = time.perf_counter() - t0
    # Extrapolate from the mean time per configuration so far, which is far
    # less noisy than using only the most recent iteration.
    mean_per_cfg = (time.perf_counter() - loop_start) / idx
    eta = mean_per_cfg * (total - idx)
    print(f"[{idx:02d}/{total}] n={n_samples:<6} τ={tau:<4} "
          f"src={source_prop:<2} rs={rs}  "
          f"{elapsed:5.1f}s, left: {eta/60:5.1f} min")

# Final summary pointing at the CSV written during the sweep.
print(f"\nAll done! Results saved to →\n  • {sum_path}")