import os, csv
from typing import Dict, List

def _as_int_list(s: str) -> List[int]:
    return [int(tok.strip()) for tok in s.split(",") if tok.strip()]

def _as_float_list(s: str) -> List[float]:
    return [float(tok.strip()) for tok in s.split(",") if tok.strip()]

def load_config_csv_ms(path: str) -> Dict:
    """Load the multi-source configuration from a ``param,value`` CSV file.

    Lines whose first non-blank character is ``#`` are dropped before CSV
    parsing. The first remaining non-empty row must be the header
    ``param,value`` (surrounding whitespace in the header cells and a leading
    byte-order mark are tolerated). Each later row contributes one
    ``param -> value`` pair: both cells are whitespace-stripped and a repeated
    key overwrites the earlier value. Rows with fewer than two cells, or with
    an empty or ``#``-prefixed key, are ignored.

    A file that contains no non-comment rows at all produces the defaults
    without raising a header error.

    Args:
        path: Location of the config CSV.

    Returns:
        A dict of typed settings. Keys absent from the file take the defaults
        spelled out in the return expression below; ``K`` is ``None`` when
        absent, and each ``*_vec`` entry is ``None`` when absent or blank.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the header is not ``param,value``, or a value cannot be
            converted to its numeric type.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"config.csv not found: {path}")

    def _content_lines(handle):
        # Filter comment lines before the csv module sees them. A byte-order
        # mark is removed from the first line so that a BOM-prefixed comment
        # or header is still recognised.
        for lineno, line in enumerate(handle):
            if lineno == 0:
                line = line.lstrip("\ufeff")
            if not line.lstrip().startswith("#"):
                yield line

    raw = {}
    with open(path, "r", newline="") as f:
        header = None
        for r in csv.reader(_content_lines(f)):
            if not r:
                continue
            if header is None:
                # Header cells get the same whitespace tolerance as data cells.
                header = [cell.strip() for cell in r]
                if header[:2] != ["param", "value"]:
                    raise ValueError("config.csv must have header 'param,value'")
                continue
            if len(r) < 2:
                continue
            k, v = r[0].strip(), r[1].strip()
            if not k or k.startswith("#"):
                continue
            raw[k] = v

    return {
        "curve_csv": raw.get("curve_csv", "./curve_ms.csv"),
        "out_csv":   raw.get("out_csv",   "./cache/curve_ms_log.csv"),

        "T":     int(raw.get("T", 3)),
        "Vstar": float(raw.get("Vstar", 90.0)),
        "P":     float(raw.get("P", 1e6)),
        "steps": int(raw.get("steps", 400)),
        "lr":    float(raw.get("lr", 200.0)),
        "beta1": float(raw.get("beta1", 0.9)),
        "beta2": float(raw.get("beta2", 0.999)),
        "eps":   float(raw.get("eps", 1e-8)),
        "tol":   float(raw.get("tol", 1e-3)),

        "epochs":   int(raw.get("epochs", 50)),
        "train_lr": float(raw.get("train_lr", 0.1)),
        "gmm_K":    int(raw.get("gmm_K", 3)),

        "K": int(raw["K"]) if "K" in raw else None,

        # Blank vector values are treated the same as missing ones.
        "q0_vec":   _as_int_list(raw["q0_vec"])   if raw.get("q0_vec", "").strip() else None,
        "qcap_vec": _as_int_list(raw["qcap_vec"]) if raw.get("qcap_vec", "").strip() else None,
        "c_vec":    _as_float_list(raw["c_vec"])  if raw.get("c_vec", "").strip() else None,
    }

def load_config_csv(path: str):
    """Load a ``param,value`` CSV file and return the typed configuration.

    Lines whose first non-blank character is ``#`` are dropped before CSV
    parsing. The first remaining non-empty row must be the header
    ``param,value`` (surrounding whitespace in the header cells and a leading
    byte-order mark are tolerated). Each later row contributes one
    ``param -> value`` pair: both cells are whitespace-stripped and a repeated
    key overwrites the earlier value. Rows with fewer than two cells, or with
    an empty or ``#``-prefixed key, are ignored.

    Args:
        path: Location of the config CSV.

    Returns:
        Whatever ``_coerce_config_types`` returns for the collected
        string-to-string mapping.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file has no header row or the header is not
            ``param,value``. Errors raised by ``_coerce_config_types``
            propagate unchanged.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"config.csv not found: {path}")

    def _content_lines(handle):
        # Filter comment lines before the csv module sees them. A byte-order
        # mark is removed from the first line so that a BOM-prefixed comment
        # or header is still recognised.
        for lineno, line in enumerate(handle):
            if lineno == 0:
                line = line.lstrip("\ufeff")
            if not line.lstrip().startswith("#"):
                yield line

    header = None
    rows = []
    with open(path, "r", newline="") as f:
        for r in csv.reader(_content_lines(f)):
            if not r:
                continue
            if header is None:
                # Header cells get the same whitespace tolerance as data cells.
                header = [cell.strip() for cell in r]
            else:
                rows.append(r)

    # Validation happens after the read so an empty file is reported too.
    if header is None or header[:2] != ["param", "value"]:
        raise ValueError("config.csv must have header 'param,value'")

    cfg = {}
    for r in rows:
        if len(r) < 2:
            continue
        k, v = r[0].strip(), r[1].strip()
        if not k or k.startswith("#"):
            continue
        cfg[k] = v
    return _coerce_config_types(cfg)



def _coerce_config_types(cfg: dict):
    # defaults (single-source)
    out = {
        "T":            int(cfg.get("T", 3)),
        "Vstar":        float(cfg.get("Vstar", 80.0)),
        "epochs":       int(cfg.get("epochs", 200)),
        "lr":           float(cfg.get("lr", 0.1)),
        "c":            float(cfg.get("c", 1.0)),
        "P":            float(cfg.get("P", 1e6)),
        "q0":           int(cfg.get("q0", 5000)),
        "qcap":         int(cfg.get("qcap", 50000)),
        "curve_csv":    os.path.normpath(cfg.get("curve_csv", "./curve_points.csv")),
        "curve_out_csv":os.path.normpath(cfg.get("curve_out_csv", "./cache/curve_log.csv")),
        "minibatch_size": int(cfg.get("minibatch_size", 5)),
        "virt_bins":      int(cfg.get("virt_bins", 10)),
        "num_inducing":   int(cfg.get("num_inducing", 40)),
        "num_direction":  int(cfg.get("num_direction", 2)),
        "mu":             float(cfg.get("mu", 1e-2)),
        "joint_steps":  int(cfg.get("joint_steps", 400)),
        "joint_lr":     float(cfg.get("joint_lr", 200.0)),
        "joint_beta1":  float(cfg.get("joint_beta1", 0.9)),
        "joint_beta2":  float(cfg.get("joint_beta2", 0.999)),
        "joint_eps":    float(cfg.get("joint_eps", 1e-8)),
        "joint_tol":    float(cfg.get("joint_tol", 1e-3)),
        "learning_rate_hypers": float(cfg.get("learning_rate_hypers", 0.01)),
        "learning_rate_ngd":    float(cfg.get("learning_rate_ngd", 0.1)),
        "gamma":          float(cfg.get("gamma", 10.0)),
    }
    # Optional multi-source
    if "K" in cfg:
        K = int(cfg["K"])
        out["K"] = K
        out["q0_vec"]   = _as_int_list(cfg.get("q0_vec", ""))   or [out["q0"]]*K
        out["qcap_vec"] = _as_int_list(cfg.get("qcap_vec",""))  or [out["qcap"]]*K
        out["c_vec"]    = _as_float_list(cfg.get("c_vec",""))   or [out["c"]]*K
        # length checks
        if not (len(out["q0_vec"]) == len(out["qcap_vec"]) == len(out["c_vec"]) == K):
            raise ValueError("Lengths of q0_vec, qcap_vec, c_vec must all equal K")
    return out