import argparse
import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

from FS_fedavg import FedAvgTrain
from fedavg_utils import ResNet20


def set_all_seeds(seed: int):
    """Seed every random number generator this script relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator and PyTorch. When CUDA is available the CUDA
    generators of all devices are seeded as well.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with every option this script accepts."""
    parser = argparse.ArgumentParser()

    # Dataset, model and global run settings.
    parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "cifar100"])
    parser.add_argument("--model_name", type=str, default="resnet20")
    parser.add_argument("--data_root", type=str, default="./data")
    parser.add_argument("--epochs", type=int, default=120)
    parser.add_argument("--workers", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=256)

    # Federated-learning settings.
    parser.add_argument("--num_users", type=int, default=32)
    parser.add_argument("--local_ep", type=int, default=1)
    parser.add_argument("--local_bs", type=int, default=16)

    # Optimizer settings.
    parser.add_argument("--base_lr", type=float, default=0.05)
    parser.add_argument("--lr_drop_epoch", type=int, default=90)
    parser.add_argument("--lr_drop_factor", type=float, default=0.1)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--grad_clip", type=float, default=5.0)

    parser.add_argument("--random_seed", type=int, default=40)

    # Weights & Biases logging options.
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="fedavg")
    parser.add_argument("--wandb_entity", type=str, default="hq1351-wayne-state-university")

    return parser


def _dataset_spec(name: str):
    """Return ``(num_classes, dataset_cls, mean, std)`` for the dataset ``name``."""
    if name == "cifar10":
        return (
            10,
            torchvision.datasets.CIFAR10,
            (0.4914, 0.4822, 0.4465),
            (0.2470, 0.2435, 0.2616),
        )
    # The parser restricts --dataset to two choices, so any other value is
    # "cifar100".
    return (
        100,
        torchvision.datasets.CIFAR100,
        (0.5071, 0.4867, 0.4408),
        (0.2675, 0.2565, 0.2761),
    )


def main():
    """Parse CLI options, prepare data and model, then start FedAvg training.

    Steps performed:
      1. Parse the command line and seed all random number generators.
      2. Select "cuda" when available, otherwise "cpu", and store the choice
         on ``args.device``.
      3. Build three views of the chosen CIFAR dataset: the training split
         with augmentation, the training split with the evaluation transform,
         and the test split with the evaluation transform. Data is downloaded
         into ``--data_root`` if missing.
      4. Wrap the test split in a ``DataLoader`` and create a ``ResNet20``
         on the selected device.
      5. Hand everything to ``FedAvgTrain``.
    """
    args = _build_arg_parser().parse_args()
    set_all_seeds(args.random_seed)

    use_cuda = torch.cuda.is_available()
    # The device is attached to ``args`` because ``args`` is what gets passed
    # on to FedAvgTrain below.
    args.device = "cuda" if use_cuda else "cpu"
    print(f"[INFO] device={args.device}", flush=True)

    num_classes, dataset_cls, mean, std = _dataset_spec(args.dataset)

    # Training transform: random crop + horizontal flip before normalization.
    train_tf = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    # Evaluation transform: normalization only, no random augmentation.
    eval_tf = T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std),
    ])

    train_data = dataset_cls(root=args.data_root, train=True, download=True, transform=train_tf)
    # Same training split, but seen through the evaluation transform.
    train_eval_data = dataset_cls(root=args.data_root, train=True, download=True, transform=eval_tf)
    test_data = dataset_cls(root=args.data_root, train=False, download=True, transform=eval_tf)

    print(f"[INFO] dataset={args.dataset}", flush=True)
    print(f"[INFO] train size={len(train_data)}", flush=True)
    print(f"[INFO] test size={len(test_data)}", flush=True)

    test_loader = DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only
        # machine PyTorch ignores the request and emits a warning, so it is
        # requested only when CUDA is actually in use.
        pin_memory=use_cuda,
    )

    model = ResNet20(num_classes=num_classes).to(args.device)
    FedAvgTrain(args, model, train_data, train_eval_data, test_loader)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()