import itertools
import time
from pathlib import Path

import ray
import numpy as np 
import pandas as pd

from method import generate_biases_cont, analyze_results
from comp_target import run_simulation_target

# Start Ray with the current directory as working_dir and the parent directory
# as an extra module path. NOTE(review): nothing in this script calls Ray
# directly; presumably run_simulation_target dispatches Ray tasks — confirm.
ray.init(runtime_env={"working_dir": ".", "py_modules": [".."]})

# Target-site sample size -> K_base argument passed to run_simulation_target.
# Every entry of n_sample_targets must be a key here (looked up in the loop).
N_TO_K = {20_000: 8, 50_000: 10, 200_000: 20}

seed = 511                   # passed as base_seed_target
dist_type = 'normal'         # passed as dist_type
n_sample_targets = [20000]   # target-site sample sizes to sweep
n_sim = 1000                 # passed as n_simu
n_sites = 4

taus = [0.25, 0.5]
rs = [0.5] * n_sites         # one value per site (length n_sites)
source_props = [3]

# Grid of bias settings to sweep, produced by generate_biases_cont.
end_val = 1
num_points = 11
biases_s = generate_biases_cont(end_val, num_points, n_sites)

records = []     # full nested results, one dict per combination
total = len(n_sample_targets) * len(source_props) * len(taus) * len(biases_s)
out_dir = Path("output")
out_dir.mkdir(exist_ok=True)
sum_path = out_dir / "gbias_target.csv"   # flat summary CSV
flat_rows = []   # one flat summary row per combination, written to sum_path

# Sweep every (sample size, source proportion, tau, bias setting) combination.
# bias_id is the 1-based position of the bias setting in biases_s. It comes from
# enumerate rather than biases_s.index(), which would (a) give repeated settings
# the id of their first occurrence and (b) fail with elementwise `==` when the
# settings are array-like.
loop_start = time.time()
for idx, (n_samples, source_prop, tau, (bias_id, biases)) in enumerate(
        itertools.product(n_sample_targets, source_props, taus,
                          enumerate(biases_s, 1)), 1):
    K_base = N_TO_K[n_samples]
    t0 = time.time()

    output = run_simulation_target(
        n_simu=n_sim,
        base_seed=2022,
        base_seed_target=seed,
        dist_type=dist_type,
        tau=tau,
        rs=rs,
        K_base=K_base,
        n_samples=n_samples,
        n_sites=n_sites,
        source_prop=source_prop,
        biases=biases,
        c0=1, a=0.6, b0=0)

    res_target = analyze_results(output, est_key="target")

    # Full nested record, keeping the raw analyze_results output.
    rec = {
        "tau": tau,
        "n_samples": n_samples,
        "rs": rs,
        "source_prop": source_prop,
        "biases": biases,
        "bias_id": bias_id,
        "target": res_target,
    }
    records.append(rec)

    # Flat summary row: analyze_results entries prefixed with "tgt_".
    row = {
        "n_samples": n_samples,
        "source_prop": source_prop,
        "tau": tau,
        "bias_id": bias_id,
        **{f"tgt_{k}": v for k, v in res_target.items()},
    }
    flat_rows.append(row)
    # Rewrite the whole CSV after every combination (presumably as a checkpoint
    # so completed results are on disk if a later run fails).
    pd.DataFrame(flat_rows).to_csv(sum_path, index=False)

    elapsed = time.time() - t0
    # Estimate remaining time from the mean duration of all combinations so
    # far, rather than extrapolating from the last one only.
    mean_per_combo = (time.time() - loop_start) / idx
    eta = mean_per_combo * (total - idx)
    print(f"[{idx:02d}/{total}] "
          f"n={n_samples:<6} "
          f"τ={tau:<4} src={source_prop:<2} bias#{bias_id}  "
          f" {elapsed:6.1f}s, left: {eta/60:5.1f} min")

# Final summary: report where the flat results CSV was written.
print("\nAll done! Results saved to →", f"  • {sum_path}", sep="\n")