import argparse
import random
import numpy as np
import torch

from fed_dro import FedDROTrainer


def set_all_seeds(seed: int, deterministic: bool = False) -> None:
    """Seed every random number generator the training run relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's
    generator. When CUDA is available, it also seeds the generators of all
    visible CUDA devices.

    Args:
        seed: Seed value applied to every generator.
        deterministic: When True, also force cuDNN to use deterministic
            kernels and disable its benchmark autotuner. Otherwise the
            autotuner may pick different algorithms between runs, so GPU
            results can differ even with identical seeds. Defaults to False,
            so existing callers keep the previous behaviour.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        # Seeding alone does not pin cuDNN's algorithm choice; these flags do.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def build_parser():
    """Create the command-line parser for the CIFAR100-ST FedDRO run.

    All options are declared in ordered tables and registered in that same
    order, so ``--help`` lists them in the declared sequence.

    Returns:
        argparse.ArgumentParser: parser whose parsed namespace is what the
        script hands to ``FedDROTrainer``.
    """
    # Each entry is (flag, type, default). A type of None marks a boolean
    # switch registered with action="store_true", which defaults to False.
    io_and_runtime = [
        ("--data_root", str, "./data"),
        ("--epochs", int, 120),
        ("--workers", int, 2),
        ("--pin_memory", None, None),
        ("--cpu", None, None),
    ]
    # paper-like setup
    paper_setup = [
        ("--num_users", int, 8),
        ("--I", int, 1),
        ("--batch_size", int, 128),
        ("--local_bs", int, 16),
        ("--base_lr", float, 0.1),
        ("--lr_drop_round", int, 90),
        ("--momentum", float, 0.9),
        ("--weight_decay", float, 5e-4),
        ("--grad_clip", float, 10.0),
    ]
    # robust objective
    robust_objective = [
        ("--lamda", float, 2.0),
        ("--beta_y", float, 0.1),
        ("--y_init", float, 2.0),
        ("--y_min", float, 1.0),
        ("--y_max", float, 50.0),
        ("--exp_clip", float, 5.0),
    ]
    pretrain_and_eval = [
        ("--pretrain_epochs", int, 2),
        ("--pretrain_lr", float, 0.05),
        ("--train_eval_batches", int, 200),
        ("--random_seed", int, 40),
    ]
    # Output directory and artifact file names.
    outputs = [
        ("--out_dir", str, "runs/feddro_cifar100st"),
        ("--csv_name", str, "feddro_cifar100st_metrics.csv"),
        ("--grad_plot_name", str, "grad_norm_cifar100st_upto70.png"),
        ("--ckpt_name", str, "best_feddro_cifar100st.pt"),
    ]
    wandb_opts = [
        ("--wandb", None, None),
        ("--wandb_project", str, "avg"),
        ("--wandb_entity", str, "hq1351-wayne-state-university"),
        ("--wandb_run_name", str, "FedDRO_CIFAR100ST_practical"),
    ]

    parser = argparse.ArgumentParser(
        description="CIFAR100-ST FedDRO practical reproduction"
    )
    every_option = (
        io_and_runtime
        + paper_setup
        + robust_objective
        + pretrain_and_eval
        + outputs
        + wandb_opts
    )
    for flag, kind, default in every_option:
        if kind is None:
            parser.add_argument(flag, action="store_true")
        else:
            parser.add_argument(flag, type=kind, default=default)
    return parser


def main(argv=None):
    """Parse options, seed all RNGs, run FedDRO training and report outputs.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour. Passing a list lets the script be
            driven from Python code or tests.
    """
    args = build_parser().parse_args(argv)
    # Seed before constructing the trainer, so that any randomness it uses
    # during construction is reproducible too (presumably model
    # initialisation — confirm in fed_dro).
    set_all_seeds(args.random_seed)

    trainer = FedDROTrainer(args)
    trainer.train()

    # NOTE(review): csv_path / plot_path / ckpt_path are attributes of
    # FedDROTrainer (defined in fed_dro). They presumably point at the files
    # written during train(); confirm there.
    print(f"Saved CSV to: {trainer.csv_path}")
    print(f"Saved grad plot to: {trainer.plot_path}")
    print(f"Saved checkpoint to: {trainer.ckpt_path}")


# Run the CLI entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
