import numpy as np
from multiprocessing import freeze_support
import os
import csv
from LOC.seed import set_global_seed
from LOC.config import load_config_csv_ms
from LOC.curve import load_curve_ms_csv_or_die
from LOC.loc import ExperimentCfg, mpc_run_ms
from oracle.cifar10_oracle import get_cifar10_oracle_multi

def _fill_default_vectors(cfg: dict, S: int) -> None:
    """Fill missing per-source vectors in ``cfg`` in place and validate their lengths.

    Any of ``q0_vec`` / ``qcap_vec`` / ``c_vec`` whose value is ``None`` is replaced
    by a list of length ``S`` filled with 5000 / 50000 / 1.0 respectively.

    Raises:
        ValueError: if the three vectors do not all have length ``S``.
    """
    defaults = {"q0_vec": 5000, "qcap_vec": 50000, "c_vec": 1.0}
    for key, value in defaults.items():
        if cfg[key] is None:
            cfg[key] = [value] * S
    if not (len(cfg["q0_vec"]) == len(cfg["qcap_vec"]) == len(cfg["c_vec"]) == S):
        raise ValueError("q0_vec, qcap_vec, c_vec must all have length S (from curve).")


def _mean_std(values) -> tuple:
    """Return ``(mean, sample std)`` of ``values``; the std is 0.0 for a single value."""
    arr = np.asarray(values, dtype=float)
    # ddof=1 with one sample divides by zero (nan + RuntimeWarning).
    std = float(arr.std(ddof=1)) if arr.size > 1 else 0.0
    return float(arr.mean()), std


def main_from_csv(config_path: str, *, base_seed: int = 42, n_cases: int = 10):
    """Run the multi-source LOC experiment ``n_cases`` times and save a summary CSV.

    The config is loaded with ``load_config_csv_ms`` and the curve with
    ``load_curve_ms_csv_or_die(cfg["curve_csv"])``. The number of sources ``S`` is
    taken from the curve data. Run ``i`` (0-based) uses seed ``base_seed + i``.
    Each run does the following:

    * builds a fresh oracle via ``get_cifar10_oracle_multi(K=S, ...)``;
    * calls ``mpc_run_ms``;
    * records the final metric, ``acc_ratio = metricT / Vstar`` and
      ``cost_ratio = c·(qT - q0) / c·(qcap - q0)``.

    Mean and sample std over all runs are printed. Per-run rows are written to
    ``./result/summary_cifar10_loc_{K}D_{T}_{Vstar}.csv``. ``K`` is ``cfg["K"]``
    when present, otherwise ``S``.

    Args:
        config_path: path of the experiment config CSV.
        base_seed: seed of the first run; subsequent runs increment it by one.
        n_cases: number of independent runs (must be >= 1).

    Raises:
        ValueError: on ``n_cases < 1``, mismatched vector lengths, a non-positive
            cost denominator, or ``Vstar == 0``.
    """
    freeze_support()
    if n_cases < 1:
        raise ValueError(f"n_cases must be >= 1, got {n_cases}.")

    cfg = load_config_csv_ms(config_path)

    points = load_curve_ms_csv_or_die(cfg["curve_csv"])
    S = len(points[0][0])  # number of sources, as given by the curve data

    _fill_default_vectors(cfg, S)

    exp = ExperimentCfg(
        T=cfg["T"], Vstar=cfg["Vstar"], P=cfg["P"],
        q0_vec=tuple(cfg["q0_vec"]), qcap_vec=tuple(cfg["qcap_vec"]), c_vec=tuple(cfg["c_vec"]),
        steps=cfg["steps"], lr=cfg["lr"], beta1=cfg["beta1"], beta2=cfg["beta2"],
        eps=cfg["eps"], tol=cfg["tol"], gmm_K=cfg["gmm_K"]
    )

    # ----- cost_ratio denominator: total cost of going from q0 to qcap -----
    q0 = np.array(cfg["q0_vec"], dtype=float)
    qcap = np.array(cfg["qcap_vec"], dtype=float)
    cvec = np.array(cfg["c_vec"], dtype=float)
    denom = float(np.dot(cvec, (qcap - q0)))
    if denom <= 0:
        raise ValueError(f"Invalid denom c·(qcap-q0)={denom}. Check qcap_vec > q0_vec and c_vec >= 0.")

    # Validate before the expensive runs rather than failing after the first one.
    vstar = float(cfg["Vstar"])
    if vstar == 0:
        raise ValueError("Vstar must be non-zero (it is the acc_ratio denominator).")

    # Resolve the output path up front so a config problem (e.g. a missing key)
    # cannot discard the results of all runs at the very end.
    out_sum = f"./result/summary_cifar10_loc_{cfg.get('K', S)}D_{cfg['T']}_{cfg['Vstar']}.csv"

    case_seeds = [base_seed + i for i in range(n_cases)]

    rows = []
    metrics, acc_ratios, cost_ratios = [], [], []

    for case_idx, seed in enumerate(case_seeds, start=1):
        print(f"\n========== Run {case_idx}/{len(case_seeds)} ==========")
        set_global_seed(seed, deterministic=True)

        oracle = get_cifar10_oracle_multi(K=S, epochs=cfg["epochs"], lr=cfg["train_lr"])

        # NOTE(review): every run passes the same curve_out_csv; unless
        # mpc_run_ms appends, only the last run's curve survives — confirm.
        result = mpc_run_ms(points, oracle, exp, curve_out_csv=cfg["out_csv"], seed=seed)

        qT = np.array(result["qT"], dtype=float)
        metricT = float(result["metricT"])

        acc_ratio = metricT / vstar
        cost_ratio = float(np.dot(cvec, (qT - q0))) / denom

        print(f"[Case {case_idx:02d}] seed={seed} qT={result['qT']} "
              f"metricT={metricT:.2f} acc_ratio={acc_ratio:.4f} cost_ratio={cost_ratio:.4f}")

        rows.append([case_idx, seed, metricT, acc_ratio, cost_ratio, *map(int, result["qT"])])
        metrics.append(metricT)
        acc_ratios.append(acc_ratio)
        cost_ratios.append(cost_ratio)

    m_mean, m_std = _mean_std(metrics)
    a_mean, a_std = _mean_std(acc_ratios)
    c_mean, c_std = _mean_std(cost_ratios)

    print(f"\n--- Statistics over {n_cases} cases ---")
    print(f"metricT    mean ± std = {m_mean:.2f} ± {m_std:.2f}")
    print(f"acc_ratio  mean ± std = {a_mean:.4f} ± {a_std:.4f}")
    print(f"cost_ratio mean ± std = {c_mean:.4f} ± {c_std:.4f}")

    # save summary csv
    os.makedirs(os.path.dirname(out_sum) or ".", exist_ok=True)
    with open(out_sum, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["case", "seed", "metricT", "acc_ratio", "cost_ratio", *[f"q{i+1}T" for i in range(S)]])
        w.writerows(rows)
    print("Saved summary to:", out_sum)

if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the default config path,
    # so other configs can be run without editing this file.
    main_from_csv(sys.argv[1] if len(sys.argv) > 1 else "./data/cifar10/config_5D.csv")