# main_monogp_ms.py
import os, csv
import numpy as np
from multiprocessing import freeze_support
from LOC.seed import set_global_seed
from LOC.ratio import cost_ratio, acc_ratio
from MonotonicGP.mpc_monotonicGP import ExperimentCfgMulti, mpc_run_multi
from oracle.cifar10_oracle import get_cifar10_oracle_multi
from LOC.config import load_config_csv
from LOC.curve import load_curve_csv_strict_multi

def _mean_and_sample_std(values):
    """Return ``(mean, sample standard deviation)`` of *values* as floats.

    The sample standard deviation (``ddof=1``) is undefined for fewer than
    two values, so NaN is returned in that case instead of letting NumPy
    emit a RuntimeWarning.
    """
    arr = np.asarray(values, dtype=float)
    spread = float(arr.std(ddof=1)) if arr.size > 1 else float("nan")
    return float(arr.mean()), spread


def main_from_csv(config_path: str, num_runs: int = 10, base_seed: int = 42) -> None:
    """Run the experiment ``num_runs`` times and write a per-run summary CSV.

    The configuration is read with ``load_config_csv`` and the curve points
    with ``load_curve_csv_strict_multi``. Each run seeds the global RNGs,
    builds a fresh oracle and ``ExperimentCfgMulti``, calls ``mpc_run_multi``
    and records ``metricT``, the accuracy ratio, the cost ratio and ``qT``.
    Mean and sample standard deviation over all runs are printed, and one row
    per run is written to ``./result/summary_monogp_cifar10_<K>D_<T>_<Vstar>_cost.csv``.

    Args:
        config_path: Path of the CSV configuration file.
        num_runs: Number of independent runs; must be at least 1.
        base_seed: Seed of the first run; run ``i`` (0-based) uses
            ``base_seed + i``.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    # Kept as the first call so frozen Windows executables can bootstrap
    # multiprocessing children before any other work happens.
    freeze_support()

    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    cfg = load_config_csv(config_path)
    K = cfg["K"]
    points_multi = load_curve_csv_strict_multi(cfg["curve_csv"], K=K)

    seeds = [base_seed + i for i in range(num_runs)]

    metrics = []
    accs = []
    costs = []
    rows = []

    for i, seed in enumerate(seeds, start=1):
        # Seed before building the oracle/config so everything created in
        # this iteration happens after the RNG state has been reset.
        set_global_seed(seed, deterministic=True)
        print(f"\n========== Run {i}/{num_runs} ==========")
        oracle = get_cifar10_oracle_multi(K, epochs=cfg["epochs"], lr=cfg["lr"])
        expm = ExperimentCfgMulti(
            T=cfg["T"], Vstar=cfg["Vstar"], P=cfg["P"], epochs=cfg["epochs"],
            q0_vec=cfg["q0_vec"], q_cap_vec=cfg["qcap_vec"], c_vec=cfg["c_vec"],
            j_steps=cfg["joint_steps"], j_lr=cfg["joint_lr"],
            j_beta1=cfg["joint_beta1"], j_beta2=cfg["joint_beta2"],
            j_eps=cfg["joint_eps"], j_tol=cfg["joint_tol"],
            virt_bins=cfg["virt_bins"],
            minibatch_size=cfg["minibatch_size"],
            num_inducing=cfg["num_inducing"],
            num_direction=cfg["num_direction"],
            mu=cfg["mu"],
        )

        result = mpc_run_multi(points_multi, oracle, expm, curve_out_csv=cfg["curve_out_csv"])
        qT = result["qT"]
        mT = float(result["metricT"])

        cr = cost_ratio(qT, cfg["q0_vec"], cfg["qcap_vec"], cfg["c_vec"])
        ar = acc_ratio(mT, cfg["Vstar"])

        print(f"[Run {i:02d}] seed={seed} qT={qT} metricT={mT:.2f} acc_ratio={ar:.4f} cost_ratio={cr:.4f}")

        metrics.append(mT)
        accs.append(ar)
        costs.append(cr)
        # One CSV row per run; the qT entries are flattened into trailing columns.
        rows.append([i, seed, mT, ar, cr, *qT])

    metric_mean, metric_std = _mean_and_sample_std(metrics)
    acc_mean, acc_std = _mean_and_sample_std(accs)
    cost_mean, cost_std = _mean_and_sample_std(costs)

    print(f"\n--- Statistics (mean ± std over {num_runs} seeds) ---")
    print(f"metricT    mean ± std = {metric_mean:.2f} ± {metric_std:.2f}")
    print(f"acc_ratio  mean ± std = {acc_mean:.4f} ± {acc_std:.4f}")
    print(f"cost_ratio mean ± std = {cost_mean:.4f} ± {cost_std:.4f}")

    out = f"./result/summary_monogp_cifar10_{cfg['K']}D_{cfg['T']}_{cfg['Vstar']}_cost.csv"
    os.makedirs(os.path.dirname(out), exist_ok=True)
    # newline="" lets the csv module control line endings itself.
    with open(out, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["run", "seed", "metricT", "acc_ratio", "cost_ratio", *[f"q{j+1}T" for j in range(K)]])
        w.writerows(rows)
    print("Saved:", out)

if __name__ == "__main__":
    # Run with the default config path when executed as a script.
    default_config_path = "./data/cifar10/mono_config_5D.csv"
    main_from_csv(default_config_path)
