import itertools            
import time                 
from pathlib import Path    

import ray                  
import numpy as np           
import pandas as pd         

from method import generate_biases_cont,analyze_results
from method_lasso import run_simulation_transfer
from comp import run_simulation_inv

# Runtime environment handed to Ray at start-up.
# NOTE(review): per the Ray runtime-environment docs, "working_dir" is the
# directory made available to workers and "py_modules" lists extra module
# paths to make importable there; here those are the current directory and
# its parent -- confirm this matches the intended launch directory.
_RAY_RUNTIME_ENV = {
    "working_dir": ".",
    "py_modules": [".."],
}

# Initialise Ray before any simulation is started.
ray.init(runtime_env=_RAY_RUNTIME_ENV)

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Lookup from target sample size to the K_base value passed to the
# simulation runners; every entry of n_sample_targets must be a key here.
N_TO_K = {20_000: 8, 50_000: 10, 200_000: 20}

seed = 2022                 # passed as base_seed to both simulation runners
dist_type = 'normal'
n_sample_targets = [20000]
n_sim = 1000                # passed as n_simu to both simulation runners
n_sites = 4

taus = [0.25, 0.5]
rs = [0.5] * n_sites        # one entry per site, all equal to 0.5
source_props = [3]

# Bias configurations to sweep over; produced by the project helper from
# (end_val, num_points, n_sites).
# NOTE(review): the sweep takes len(biases_s) and iterates it, so this is
# assumed to be a sized sequence -- confirm against generate_biases_cont.
end_val = 1
num_points = 11
biases_s = generate_biases_cont(end_val, num_points, n_sites)

# Penalty grid: six log-spaced values 1e-4, 1e-3, ..., 1e1, with 3.0 inserted
# just before the last element, i.e. [1e-4, 1e-3, 1e-2, 1e-1, 1, 3, 10].
lambd_grid = np.logspace(-4, 1, 6)
lambd_grid = np.insert(lambd_grid, -1, 3.0)

# ---------------------------------------------------------------------------
# Output bookkeeping
# ---------------------------------------------------------------------------

records = []                # one nested result dict per configuration
# Number of configurations in the sweep (used for progress reporting).
total = len(n_sample_targets) * len(source_props) * len(taus) * len(biases_s)
out_dir = Path("output")
out_dir.mkdir(exist_ok=True)
sum_path = out_dir / "gbias.csv"
flat_rows = []              # one flattened row per configuration, for the CSV

# Main sweep: for every (n_samples, source_prop, tau, biases) combination run
# both simulation drivers, summarise each estimator with analyze_results, and
# checkpoint the flattened summaries to `sum_path` after every configuration.
#
# The bias id is the 1-based position of the bias configuration in
# `biases_s`. It comes from enumerate() rather than `biases_s.index(...)`:
# index() rescans the list on every iteration, returns the first match when
# two configurations compare equal, and relies on `==`, which is ambiguous
# for multi-element NumPy arrays.
sweep = itertools.product(
    n_sample_targets, source_props, taus, enumerate(biases_s, 1))

for idx, (n_samples, source_prop, tau, (bias_id, biases)) in enumerate(sweep, 1):
    K_base = N_TO_K[n_samples]
    # perf_counter is monotonic, so the timing is unaffected by wall-clock
    # adjustments during a long run.
    t0 = time.perf_counter()

    # Keyword arguments shared by both simulation drivers, so the two runs
    # are guaranteed to use the same settings.
    sim_kwargs = dict(
        n_simu=n_sim,
        base_seed=seed,
        dist_type=dist_type,
        tau=tau,
        rs=rs,
        K_base=K_base,
        n_samples=n_samples,
        n_sites=n_sites,
        source_prop=source_prop,
        biases=biases,
        c0=1, a=0.6, b0=0,
    )

    # Only the transfer driver takes the penalty grid.
    output = run_simulation_transfer(lambd_grid=lambd_grid, **sim_kwargs)
    output_comp = run_simulation_inv(**sim_kwargs)

    # Per-estimator summaries; insertion order fixes the CSV column order.
    # "opt" and "lasso" are read from the transfer output, the rest from the
    # comparison output.
    results = {
        "opt": analyze_results(output, est_key="opt"),
        "lasso": analyze_results(output, est_key="lasso"),
        "inv": analyze_results(output_comp, est_key="inv"),
        "mse": analyze_results(output_comp, est_key="mse"),
        "cons": analyze_results(output_comp, est_key="cons"),
    }

    # Nested record kept in memory (settings plus one summary per estimator).
    rec = {
        "tau": tau,
        "n_samples": n_samples,
        "rs": rs,
        "source_prop": source_prop,
        "biases": biases,
        "bias_id": bias_id,
        **results,
    }
    records.append(rec)

    # Flat row for the CSV: each summary entry becomes "<estimator>_<key>".
    # analyze_results is used here as returning a mapping (its .items() is
    # iterated).
    row = {
        "n_samples": n_samples,
        "source_prop": source_prop,
        "tau": tau,
        "bias_id": bias_id,
    }
    for est_name, summary in results.items():
        row.update({f"{est_name}_{k}": v for k, v in summary.items()})
    flat_rows.append(row)

    # Rewrite the whole CSV after every configuration so partial results
    # are on disk if the run is interrupted.
    df = pd.DataFrame(flat_rows)
    df.to_csv(sum_path, index=False)

    # Progress line; the remaining-time estimate extrapolates from the
    # duration of the configuration that just finished.
    elapsed = time.perf_counter() - t0
    eta = elapsed * (total - idx)
    print(f"[{idx:02d}/{total}] n={n_samples:<6} τ={tau:<4} src={source_prop:<2} "
          f"bias#{bias_id}   {elapsed:6.1f}s, Left: {eta/60:5.1f} min")

print("\nAll done! Results saved to →")
print(f"  • {sum_path}")