import itertools            
import time                 
from pathlib import Path    

import ray                  
import numpy as np           
import pandas as pd  
import sys

from method import generate_biases,run_simulation_transfer,analyze_results
from comp import run_simulation_inv


# --------------------------------------------------------------------------- #
#  Ray configuration                                                          #
# --------------------------------------------------------------------------- #

ray.init(
    runtime_env=dict(
        working_dir=".",     # current directory is the job's working_dir
        py_modules=[".."],   # expose parent package to workers
    )
)

# --------------------------------------------------------------------------- #
#  Experiment-wide constants                                                  #
# --------------------------------------------------------------------------- #

# Target-site sample size -> K_base passed to both simulation routines.
N_TO_K = {20_000: 8, 50_000: 10, 200_000: 20}
# Number of bias patterns per grid point assumed by the progress counter.
# NOTE(review): generate_biases is not given this value — confirm it yields
# exactly N_BIASES patterns.
N_BIASES = 10


seed = 2022
dist_type = 'normal'
n_sample_targets = [20000, 200000]   # target-site sample sizes (each must be a key of N_TO_K)
n_sim = 1000   # Monte-Carlo replications
n_sites = 4    # total sites = 1 target + 3 sources

taus = [0.25]          # quantile levels
rs = [0.5] * n_sites   # r for each site
source_props = [3]     # source-to-target sample ratio

# lambda grid: 1e-4, 1e-3, 1e-2, 1e-1, 1, 3, 10.
# NOTE(review): an older comment read "0.001→10", but the grid starts at 1e-4;
# confirm the intended lower bound.
lambd_grid = np.logspace(-4, 1, 6)           # 6 log-spaced points, 1e-4 → 10
lambd_grid = np.insert(lambd_grid, -1, 3.0)  # insert 3.0 before the last point (10)

# Fail fast: a sample size without a K_base would otherwise raise KeyError
# only after all earlier configurations had been simulated.
_missing_k = [n for n in n_sample_targets if n not in N_TO_K]
if _missing_k:
    raise ValueError(f"N_TO_K has no K_base for n_samples={_missing_k}")

records = []   # full nested results per combination

# Number of (n_samples, source_prop, tau, bias) configurations; used only for
# the progress/ETA display.
total = len(n_sample_targets) * len(source_props) * len(taus) * N_BIASES

out_dir = Path("output")
out_dir.mkdir(exist_ok=True)
sum_path = out_dir / "case_bias.csv"
flat_rows = []   # flattened rows for CSV output

# --------------------------------------------------------------------------- #
#  Main grid loop                                                             #
# --------------------------------------------------------------------------- #
idx = 0                          # number of configurations started so far
run_start = time.perf_counter()  # origin for the mean-duration ETA estimate
for n_samples, source_prop, tau in itertools.product(n_sample_targets, source_props, taus):

    # Regenerate the bias patterns for each target sample size; both step
    # sizes are scaled by 1/sqrt(n_samples).
    # NOTE(review): `total` assumes this yields N_BIASES patterns — confirm.
    scl = 1 / np.sqrt(n_samples)         # common scaling factor
    biases_s = generate_biases(step_small=0.1 * scl,
                               step_big=100 * scl,
                               n_sites=n_sites)

    for bias_id, biases in enumerate(biases_s, 1):
        idx += 1
        t0 = time.perf_counter()  # monotonic clock, suitable for durations
        # ① LDP transfer estimator
        output = run_simulation_transfer(
            n_simu=n_sim, base_seed=seed, dist_type=dist_type, tau=tau,
            rs=rs, K_base=N_TO_K[n_samples], n_samples=n_samples,
            n_sites=n_sites, source_prop=source_prop, biases=biases,
            lambd_grid=lambd_grid, c0=1, a=0.6, b0=0
        )
        # ② Competing method(s) for comparison, with the same settings and seed
        output_comp = run_simulation_inv(
            n_simu=n_sim, base_seed=seed, dist_type=dist_type, tau=tau,
            rs=rs, K_base=N_TO_K[n_samples], n_samples=n_samples,
            n_sites=n_sites, source_prop=source_prop, biases=biases,
            c0=1, a=0.6, b0=0
        )
        # ③ Summaries: "opt" comes from the transfer run,
        #    "inv" and "mse" from the competing-method run.
        res_opt = analyze_results(output, est_key="opt")
        res_inv = analyze_results(output_comp, est_key="inv")
        res_mse = analyze_results(output_comp, est_key="mse")

        # Store the full record (kept in memory only)
        records.append({
            "tau": tau, "n_samples": n_samples, "rs": rs,
            "source_prop": source_prop, "biases": biases,
            "bias_id": bias_id, "opt": res_opt,
            "inv": res_inv, "mse": res_mse})

        # Flattened row for the CSV, one prefixed column per summary key
        flat_rows.append({
            "n_samples": n_samples, "source_prop": source_prop,
            "tau": tau, "bias_id": bias_id,
            **{f"opt_{k}": v for k, v in res_opt.items()},
            **{f"inv_{k}": v for k, v in res_inv.items()},
            **{f"mse_{k}": v for k, v in res_mse.items()}
        })
        # Rewrite the whole CSV after every configuration so partial results
        # survive an interrupted run.
        pd.DataFrame(flat_rows).to_csv(sum_path, index=False)

        # Progress line. The ETA extrapolates from the mean duration of all
        # configurations so far rather than only the last one; it may still be
        # off if later configurations (e.g. larger n_samples) cost more.
        now = time.perf_counter()
        elapsed = now - t0
        remaining = max(total - idx, 0)  # never negative if extra patterns arrive
        eta = (now - run_start) / idx * remaining
        print(f"[{idx:03d}/{total}] n={n_samples:<6} τ={tau:<4} "
              f"src={source_prop} bias#{bias_id} "
              f"{elapsed:5.1f}s left≈{eta/60:4.1f} min")

print("\nAll done! Results saved to →", sum_path)