import itertools
import time
from pathlib import Path

import ray
import numpy as np
import pandas as pd 

from method import generate_biases, analyze_results
from comp_target import run_simulation_target

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Maps n_samples -> K_base argument passed to run_simulation_target.
N_TO_K = {20_000: 8, 50_000: 10, 200_000: 20}
# Expected number of bias configurations per setting; used only for the
# progress counter / ETA.
N_BIASES = 10

seed = 511                        # passed as base_seed_target
dist_type = 'normal'
n_sample_targets = [20000, 200000]
n_sim = 1000                      # passed as n_simu
n_sites = 4

taus = [0.25]
rs = [0.5] * n_sites              # one entry per site
source_props = [3]

# Fail fast: every configured sample size needs a K_base. Otherwise the sweep
# would crash with a KeyError only after earlier (expensive) runs finished.
_missing_k = [n for n in n_sample_targets if n not in N_TO_K]
if _missing_k:
    raise ValueError(f"No K_base configured in N_TO_K for n_samples={_missing_k}")

ray.init(runtime_env={"working_dir": ".", "py_modules": [".."]})

records = []    # one dict per (setting, bias) including the raw analysis result

# NOTE(review): assumes generate_biases yields N_BIASES bias vectors per
# setting; it is not passed N_BIASES, so confirm this stays in sync.
total = len(n_sample_targets) * len(source_props) * len(taus) * N_BIASES

out_dir = Path("output")
out_dir.mkdir(exist_ok=True)
sum_path = out_dir / "case_bias_target.csv"
flat_rows = []  # flattened per-run rows written to sum_path

# ---------------------------------------------------------------------------
# Sweep over (n_samples, source_prop, tau) x bias configurations
# ---------------------------------------------------------------------------
idx = 0
sweep_start = time.time()
for n_samples, source_prop, tau in itertools.product(n_sample_targets, source_props, taus):

    # Bias step sizes scale as 1/sqrt(n_samples).
    scl = 1 / np.sqrt(n_samples)
    biases_s = generate_biases(step_small=0.1 * scl,
                               step_big=100 * scl,
                               n_sites=n_sites)

    # bias_id comes straight from enumerate. Recovering it via list.index
    # would compare bias vectors with ==, which fails for numpy arrays and
    # mislabels duplicates.
    for bias_id, biases in enumerate(biases_s, 1):
        idx += 1
        t0 = time.time()

        output = run_simulation_target(
            n_simu=n_sim,
            base_seed=2022,
            base_seed_target=seed,
            dist_type=dist_type,
            tau=tau,
            rs=rs,
            K_base=N_TO_K[n_samples],
            n_samples=n_samples,
            n_sites=n_sites,
            source_prop=source_prop,
            biases=biases,
            c0=1, a=0.6, b0=0)

        res_target = analyze_results(output, est_key="target")

        rec = {
            "tau": tau,
            "n_samples": n_samples,
            "rs": rs,
            "source_prop": source_prop,
            "biases": biases,
            "bias_id": bias_id,
            "target": res_target,
        }
        records.append(rec)

        # res_target is a mapping; its entries are flattened as tgt_<key> columns.
        flat_rows.append({
            "n_samples": n_samples,
            "source_prop": source_prop,
            "tau": tau,
            "bias_id": bias_id,
            **{f"tgt_{k}": v for k, v in res_target.items()},
        })
        # Rewrite the whole CSV after every run so partial results survive
        # an interrupted sweep.
        pd.DataFrame(flat_rows).to_csv(sum_path, index=False)

        elapsed = time.time() - t0
        # ETA from the mean time per run so far, which is more stable than the
        # last run alone. Clamp at 0 in case more than N_BIASES biases are produced.
        mean_per_run = (time.time() - sweep_start) / idx
        eta = mean_per_run * max(total - idx, 0)
        print(f"[{idx:02d}/{total}] "
              f"n={n_samples:<6} "
              f"τ={tau:<4} src={source_prop:<2} bias#{bias_id}  "
              f" {elapsed:6.1f}s, left: {eta/60:5.1f} min")
print("\nAll done! Results saved to →")
print(f"  • {sum_path}")