import argparse
import os
import csv
from datetime import datetime

from fedavg_erm_st import run_fedavg_one_lr


def parse_lr_list(s: str):
    """Parse a comma- and/or whitespace-separated string of learning rates.

    Args:
        s: Text such as ``"0.001,0.01 0.05"``; commas and any whitespace are
            treated as separators.

    Returns:
        list[float]: The parsed learning rates, in input order.

    Raises:
        ValueError: If the string contains no values, or a token is not a
            valid float.
    """
    tokens = s.replace(",", " ").split()
    lrs = list(map(float, tokens))
    if len(lrs) == 0:
        raise ValueError("Empty lr_list")
    return lrs


def default_lr_list(dataset: str):
    """Return the default learning-rate sweep for a dataset.

    Args:
        dataset: ``"cifar10"`` or ``"cifar100"`` (case-insensitive).

    Returns:
        list[float]: A new list of candidate learning rates.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Built per call so callers always receive a fresh, independent list.
    sweeps = {
        "cifar10": [0.001, 0.01, 0.05, 0.1],
        "cifar100": [0.001, 0.01, 0.05],
    }
    key = dataset.lower()
    if key not in sweeps:
        raise ValueError("dataset must be cifar10 or cifar100")
    return sweeps[key]


def write_summary_csv(path, rows):
    """Write one summary row per LR run to a CSV file at ``path``.

    Only the fixed summary columns are written. Keys missing from a row
    become empty cells, and any extra keys in a row are ignored.

    Args:
        path: Destination CSV path. Missing parent directories are created;
            a bare filename is written to the current working directory.
        rows: Iterable of dict-like result records (anything with ``.get``).
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a parent
    # directory when the path actually has one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    fieldnames = ["dataset", "split", "lr", "best_test_acc", "best_round", "csv_path", "seed"]
    # newline="" is required by the csv module; explicit UTF-8 avoids
    # platform-dependent default encodings for non-ASCII paths/values.
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows({k: r.get(k, "") for k in fieldnames} for r in rows)


def build_parser():
    """Build the command-line parser for the FedAvg ERM LR-tuning driver.

    Returns:
        argparse.ArgumentParser: Parser covering dataset, federated training,
        optimizer, CIFAR-ST, output-directory and W&B options.
    """
    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name shown in usage/error lines), so the title must go in ``description``.
    p = argparse.ArgumentParser(description="FedAvg ERM + CIFAR-ST option + LR tuning + CSV + W&B")

    # data
    p.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "cifar100"])
    p.add_argument("--data_root", type=str, default="./data")

    # federation
    p.add_argument("--rounds", type=int, default=120)
    p.add_argument("--num_users", type=int, default=8)

    # local training / evaluation
    p.add_argument("--local_ep", type=int, default=1)
    p.add_argument("--local_bs", type=int, default=16)
    p.add_argument("--eval_bs", type=int, default=256)

    # learning rate: a single value, or a sweep when --tune is set
    p.add_argument("--lr", type=float, default=0.05, help="Single LR if not tuning")
    p.add_argument("--tune", action="store_true", help="Run LR sweep like paper and pick best by test acc")
    p.add_argument(
        "--lr_list",
        type=str,
        default="",
        help='Override the sweep list (only used with --tune), e.g. "0.001,0.01,0.05,0.1"',
    )

    # optimizer
    p.add_argument("--weight_decay", type=float, default=5e-4)
    p.add_argument("--momentum", type=float, default=0.9)

    # runtime
    p.add_argument("--seed", type=int, default=40)
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--grad_clip", type=float, default=0.0)
    p.add_argument("--device", type=str, default="cuda")

    # ST option (to make FedAvg worse like paper setting)
    p.add_argument("--st", action="store_true", help="Use CIFAR-ST (imbalanced train, unchanged test)")
    p.add_argument("--st_keep", type=int, default=100, help="Keep N samples per minority class (train only)")

    # output
    p.add_argument("--out_dir", type=str, default="./TrainingResults_FedAvgERM")
    p.add_argument("--run_name", type=str, default="")

    # wandb
    p.add_argument("--wandb", action="store_true")
    p.add_argument("--wandb_project", type=str, default="avg")
    p.add_argument("--wandb_entity", type=str, default="hq1351-wayne-state-university")

    return p


def main():
    """CLI entry point: run FedAvg for one LR or an LR sweep, then summarize.

    Results go under ``<out_dir>/<run_name>/<dataset>/iid_equal_<normal|ST>/``,
    which receives one ``fedavg_lr<lr>.csv`` per learning rate plus a
    ``summary.csv`` with one row per run. The reported "best" run is the one
    whose ``best_test_acc`` (as returned by ``run_fedavg_one_lr``) is highest;
    ties keep the earliest LR in the list.
    """
    import sys

    args = build_parser().parse_args()

    run_name = args.run_name.strip() or datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    tag = "ST" if args.st else "normal"
    out_root = os.path.join(args.out_dir, run_name, args.dataset, f"iid_equal_{tag}")
    os.makedirs(out_root, exist_ok=True)

    if args.tune:
        lrs = parse_lr_list(args.lr_list) if args.lr_list.strip() else default_lr_list(args.dataset)
        # Drop repeated LRs (first occurrence kept, order preserved): a repeat
        # would re-run the same config with the same seed and overwrite its
        # per-LR CSV, and would duplicate rows in summary.csv.
        lrs = list(dict.fromkeys(lrs))
    else:
        if args.lr_list.strip():
            # Previously this option was dropped silently; tell the user.
            print("WARNING: --lr_list is ignored without --tune; using --lr only.", file=sys.stderr)
        lrs = [float(args.lr)]

    results = []
    best = None

    for lr in lrs:
        out_csv = os.path.join(out_root, f"fedavg_lr{lr:g}.csv")
        r = run_fedavg_one_lr(
            dataset=args.dataset,
            data_root=args.data_root,
            rounds=args.rounds,
            num_users=args.num_users,
            local_ep=args.local_ep,
            local_bs=args.local_bs,
            eval_bs=args.eval_bs,
            lr=float(lr),
            weight_decay=args.weight_decay,
            momentum=args.momentum,
            seed=args.seed,
            workers=args.workers,
            grad_clip=args.grad_clip,
            device=args.device,
            out_csv_path=out_csv,
            use_st=bool(args.st),
            st_keep=int(args.st_keep),
            wandb_on=bool(args.wandb),
            wandb_project=args.wandb_project,
            wandb_entity=(args.wandb_entity.strip() or None),
        )
        results.append(r)
        # Strict ">" keeps the earliest LR on ties.
        if best is None or r["best_test_acc"] > best["best_test_acc"]:
            best = r

    summary_path = os.path.join(out_root, "summary.csv")
    write_summary_csv(summary_path, results)

    print("\n====================")
    if args.tune:
        print(f"LR TUNING DONE for {args.dataset.upper()} (IID equal, {tag})")
        print(f"Best LR: {best['lr']} | Best Test Acc: {best['best_test_acc']:.2f} at round {best['best_round']}")
    else:
        print(f"SINGLE RUN DONE for {args.dataset.upper()} (IID equal, {tag})")
        print(f"LR: {best['lr']} | Best Test Acc: {best['best_test_acc']:.2f} at round {best['best_round']}")
    print(f"CSVs + summary saved under: {out_root}")
    print("====================\n")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()