import os, csv, json, time, hashlib, base64
import os, random, numpy as np, torch

def config_id(cfg) -> str:
    """Return a short, deterministic fingerprint of a configuration.

    The config is first canonicalised: dict keys are sorted recursively and
    tuples become lists. It is then serialised to compact JSON (non-ASCII kept
    as-is) and hashed with SHA-256. The first 8 digest bytes are returned as
    URL-safe base64 text.

    Dict keys at every level must be mutually orderable (they are passed to
    ``sorted``), and all values must be JSON-serialisable.
    """
    import json as _json

    def _canonical(node):
        if isinstance(node, dict):
            return {key: _canonical(node[key]) for key in sorted(node)}
        if isinstance(node, (list, tuple)):
            return list(map(_canonical, node))
        return node

    payload = _json.dumps(
        _canonical(cfg),
        ensure_ascii=False,
        separators=(",", ":"),
    ).encode("utf-8")
    prefix = hashlib.sha256(payload).digest()[:8]
    return base64.urlsafe_b64encode(prefix).decode("ascii")

def cfg_get(cfg, path, default=None):
    """Look up a dotted key path (e.g. ``"data.dataset"``) in nested dicts.

    Each dot-separated segment indexes one level deeper. If any level is not a
    dict, or lacks the requested key, ``default`` is returned. A key that is
    present with the value ``None`` returns ``None``, not ``default``.
    """
    node = cfg
    for part in path.split("."):
        # Guard clause: stop as soon as we cannot descend further.
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node

def results_paths(cfg, results_dir="./results"):
    """Compute an experiment's output file paths, creating the directory.

    The experiment id is ``{exp_name}_{model_name}_{dataset}_{config_id(cfg)}``.
    - ``exp_name`` and ``model_name`` are read from the top level of ``cfg``
      (defaults ``"exp"`` and ``"model"``).
    - ``dataset`` is read from ``cfg["data"]["dataset"]`` (default ``"data"``).

    Args:
        cfg: experiment configuration dict; it must be JSON-serialisable
            because it is hashed by ``config_id``.
        results_dir: directory that receives the output files. It is created
            if missing. Defaults to ``"./results"``, which is relative to the
            current working directory.

    Returns:
        dict with keys ``exp_id``, ``dir``, ``round_csv`` (``<id>_rounds.csv``)
        and ``meta_json`` (``<id>_meta.json``).
    """
    exp_name = cfg.get("exp_name", "exp")
    model_name = cfg.get("model_name", "model")
    # cfg_get tolerates `data` being absent, None, or not a dict; the chained
    # `cfg.get('data', {}).get(...)` raised AttributeError for `data: None`.
    dataset = cfg_get(cfg, "data.dataset", "data")
    eid = f"{exp_name}_{model_name}_{dataset}_{config_id(cfg)}"

    os.makedirs(results_dir, exist_ok=True)
    return {
        "exp_id": eid,
        "dir": results_dir,
        "round_csv": os.path.join(results_dir, f"{eid}_rounds.csv"),
        "meta_json": os.path.join(results_dir, f"{eid}_meta.json"),
    }

def atomic_write_round_csv(path, rows, header=None):
    """Replace ``path`` with a CSV made of ``header`` followed by ``rows``.

    The data is written to ``path + ".tmp"``, flushed and fsync'd, and then
    moved over ``path`` with ``os.replace``. Readers therefore see either the
    old file or the complete new one. If anything fails, the temporary file
    is removed and the exception is re-raised; ``path`` is left unchanged.

    Args:
        path: destination CSV path.
        rows: iterable of row sequences, passed to ``csv.writer.writerows``.
        header: column names. When ``None``, defaults to
            ``["round", "loss", "acc", "round_time_sec", "cum_time_sec"]``.

    NOTE(review): the default header lacks the ``round_client_time_sec`` /
    ``round_server_time_sec`` / ``cum_*`` columns that ``load_round_csv``
    reads, so a file written with it cannot be loaded by that function.
    Callers presumably pass ``header`` explicitly; confirm.
    """
    if header is None:
        header = ["round", "loss", "acc", "round_time_sec", "cum_time_sec"]
    tmp = path + ".tmp"
    try:
        with open(tmp, "w", newline="") as f:
            w = csv.writer(f)
            w.writerow(header)
            w.writerows(rows)
            # Make sure the bytes are on disk before the rename publishes them;
            # otherwise a crash could leave an empty or truncated "atomic" file.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Don't leave a half-written temp file behind (also on Ctrl-C).
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise

import os, csv

def load_round_csv(path):
    """Load a per-round metrics CSV into a dict of column lists.

    The file must have a header row. Every data row must provide the columns
    listed below; ``round`` is parsed as ``int`` and all other columns as
    ``float``. Extra columns are ignored. An empty or header-only file yields
    empty lists.

    Returns:
        dict mapping each column name to a list of values, in file order. The
        keys appear in this order: ``round``, ``loss``, ``acc``,
        ``round_client_time_sec``, ``round_each_client_time_sec``,
        ``round_server_time_sec``, ``cum_client_time_sec``,
        ``cum_each_client_time_sec``, ``cum_server_time_sec``,
        ``cum_time_sec``.

    Raises:
        KeyError: a data row lacks one of the required columns; the message
            names the column and the file.
        ValueError: a cell cannot be parsed as int/float.
    """
    # Column name -> parser; order defines both parse order and result key order.
    columns = (
        ("round", int),
        ("loss", float),
        ("acc", float),
        ("round_client_time_sec", float),
        ("round_each_client_time_sec", float),
        ("round_server_time_sec", float),
        ("cum_client_time_sec", float),
        ("cum_each_client_time_sec", float),
        ("cum_server_time_sec", float),
        ("cum_time_sec", float),
    )
    out = {name: [] for name, _ in columns}
    # newline="" is what the csv module requires for correct line handling.
    with open(path, "r", newline="") as f:
        for row in csv.DictReader(f):
            for name, parse in columns:
                try:
                    raw = row[name]
                except KeyError:
                    raise KeyError(f"{path}: missing required column {name!r}") from None
                out[name].append(parse(raw))
    return out

def print_round_logs_from_csv(csv_path, log_every=1):
    """Print a one-line summary for every ``log_every``-th round in a CSV.

    If ``csv_path`` does not exist, a notice is printed and the function
    returns. Otherwise the file is read with ``load_round_csv``.
    - A round is printed when its round number is divisible by
      ``max(1, int(log_every))``.
    - The reported round time is client time plus server time for that round.
    - ``t_cum`` comes from the ``cum_time_sec`` column.
    """
    if not os.path.exists(csv_path):
        print(f"[Server] no CSV to print: {csv_path}")
        return

    rec = load_round_csv(csv_path)
    if not rec["round"]:
        return
    # Computed only once there is at least one row, so an empty log never
    # evaluates `log_every`.
    step = max(1, int(log_every))

    per_round = zip(
        rec["round"],
        rec["loss"],
        rec["acc"],
        rec["round_client_time_sec"],
        rec["round_server_time_sec"],
        rec["cum_time_sec"],
    )
    for rnd, loss, acc, client_t, server_t, cum_t in per_round:
        if rnd % step != 0:
            continue
        round_t = client_t + server_t
        print(f">> round {rnd} | loss={loss:.4f} | acc={acc:.2f}% | "
              f"t_round={round_t:.2f}s (client={client_t:.2f}s, server={server_t:.2f}s) | t_cum={cum_t:.2f}s")


def set_global_seed(seed: int, deterministic: bool = True):
    """Seed Python, NumPy and PyTorch RNGs and restrict nondeterministic kernels.

    Always:
    - seeds ``random``, ``np.random``, ``torch`` (CPU) and all CUDA devices
      with ``seed``;
    - disables TF32 for CUDA matmul and cuDNN.

    When ``deterministic`` is true, it also sets
    ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``. When it is
    false, those two cuDNN flags are left untouched, keeping whatever value
    they already had.

    NOTE(review): this does not call ``torch.use_deterministic_algorithms``,
    so some CUDA ops may still be nondeterministic.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # TF32 is disabled regardless of `deterministic`.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False