import argparse
import copy
import csv
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset

from fedavg_utils import DatasetSplit, equal_split_indices, dirichlet_noniid, fedavg_weighted

# wandb is an optional dependency: if importing it fails for any reason the
# name is bound to None, and main() checks `wandb is not None` before using it.
try:
    import wandb
except Exception:
    wandb = None


def build_resnet20(num_classes: int):
    """Return a torchvision ResNet-18 with its input stem adapted for CIFAR.

    Despite the function name, the backbone is ``resnet18``: torchvision does
    not ship a CIFAR ResNet-20, so ResNet-18 is used as a stand-in.

    Args:
        num_classes: Forwarded to ``torchvision.models.resnet18`` as
            ``num_classes``.

    Returns:
        The modified model (still on its default device).
    """
    net = torchvision.models.resnet18(num_classes=num_classes)
    # Replace max-pooling with a no-op and swap the first convolution for a
    # 3x3, stride-1, padding-1 convolution without bias.
    net.maxpool = nn.Identity()
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    return net


@torch.no_grad()
def eval_loss_acc(model, loader, device):
    """Evaluate ``model`` over ``loader`` without tracking gradients.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Module called as ``model(x)`` to produce per-class logits.
        loader: Iterable yielding ``(x, y)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy_percent)``: summed cross-entropy divided by the
        number of labels seen, and the percentage of argmax predictions equal
        to the label. Both are ``0.0`` when the loader yields nothing.
    """
    model.eval()
    # Sum (not mean) per batch so the final average is per-sample even when
    # batches have different sizes.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        outputs = model(inputs)
        loss_sum += criterion(outputs, targets).item()
        n_correct += (outputs.argmax(dim=1) == targets).sum().item()
        n_seen += targets.numel()

    # Guard against division by zero for an empty loader.
    denom = max(n_seen, 1)
    return loss_sum / denom, 100.0 * n_correct / denom


def local_train(model, loader, device, lr, momentum, weight_decay, local_ep):
    """Train ``model`` in place with SGD and return a copy of its weights.

    Args:
        model: Module to optimise; it is put in train mode and modified.
        loader: Iterable yielding ``(x, y)`` batches; iterated once per epoch.
        device: Device each batch is moved to before the forward pass.
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        local_ep: Number of passes over ``loader``.

    Returns:
        A deep copy of ``model.state_dict()`` taken after training, so later
        changes to ``model`` do not affect the returned tensors.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )
    model.train()

    for _epoch in range(local_ep):
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

    trained_state = copy.deepcopy(model.state_dict())
    return trained_state


def main():
    """Run FedAvg training on CIFAR-10/100 driven by command-line arguments.

    Parses the CLI, builds the datasets and the per-client index split, then
    for every round: trains a copy of the global model on each selected
    client, combines the returned state dicts with ``fedavg_weighted`` using
    weights proportional to each client's sample count, and evaluates the
    result on a fixed subset of the training set and on the test set.
    Per-round metrics are printed, appended to a CSV under ``runs/`` and,
    when ``--wandb`` is given and wandb is importable, logged to wandb.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "cifar100"])
    # Default None means "derive from --dataset"; an explicit value is honoured.
    p.add_argument("--num_classes", type=int, default=None,
                   help="number of output classes (default: 10 for cifar10, 100 for cifar100)")
    p.add_argument("--data_root", type=str, default="./data")

    p.add_argument("--rounds", type=int, default=120)
    p.add_argument("--workers", type=int, default=0)
    p.add_argument("--eval_bs", type=int, default=256)

    p.add_argument("--num_users", type=int, default=32)
    p.add_argument("--client_frac", type=float, default=1.0)
    p.add_argument("--local_ep", type=int, default=1)
    p.add_argument("--local_bs", type=int, default=64)
    p.add_argument("--local_lr", type=float, default=0.05)
    p.add_argument("--momentum", type=float, default=0.9)
    p.add_argument("--weight_decay", type=float, default=5e-4)

    p.add_argument("--split_type", type=str, default="dirichlet", choices=["equal", "dirichlet"])
    p.add_argument("--dir_alpha", type=float, default=0.1)

    p.add_argument("--seed", type=int, default=40)

    p.add_argument("--wandb", action="store_true")
    p.add_argument("--wandb_project", type=str, default="fedavg")
    p.add_argument("--wandb_entity", type=str, default="")

    args = p.parse_args()

    # Keep the model head and the Dirichlet split consistent with the dataset:
    # a 10-way head on CIFAR-100 labels would fail in the loss.
    dataset_classes = 100 if args.dataset == "cifar100" else 10
    if args.num_classes is None:
        args.num_classes = dataset_classes
    elif args.num_classes != dataset_classes:
        print(f"[WARN] --num_classes={args.num_classes} does not match "
              f"{args.dataset} ({dataset_classes} classes)")

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] device={device}")
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    pin_memory = device == "cuda"

    # Per-channel normalisation constants for the chosen dataset.
    if args.dataset == "cifar100":
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    else:
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)

    train_tf = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    test_tf = T.Compose([T.ToTensor(), T.Normalize(mean, std)])

    # train_eval_data is the training split with the (augmentation-free) test
    # transform, used only for measuring training metrics.
    dataset_cls = torchvision.datasets.CIFAR100 if args.dataset == "cifar100" else torchvision.datasets.CIFAR10
    train_data = dataset_cls(args.data_root, train=True, download=True, transform=train_tf)
    train_eval_data = dataset_cls(args.data_root, train=True, download=True, transform=test_tf)
    test_data = dataset_cls(args.data_root, train=False, download=True, transform=test_tf)

    test_loader = DataLoader(test_data, batch_size=args.eval_bs, shuffle=False,
                             num_workers=args.workers, pin_memory=pin_memory)

    # proxy train-eval (faster): a fixed random subset of at most 10000
    # training samples, drawn once with the seeded generator.
    rng = np.random.default_rng(args.seed)
    proxy_n = min(10000, len(train_eval_data))
    proxy_idx = rng.choice(len(train_eval_data), size=proxy_n, replace=False)
    train_eval_loader = DataLoader(Subset(train_eval_data, proxy_idx.tolist()),
                                   batch_size=args.eval_bs, shuffle=False,
                                   num_workers=args.workers, pin_memory=pin_memory)

    # split to clients; dict_users[k] is used below as client k's sample indices
    # (exact container type is defined by fedavg_utils).
    if args.split_type == "dirichlet":
        dict_users = dirichlet_noniid(train_data, args.num_users, alpha=args.dir_alpha,
                                      seed=args.seed, num_classes=args.num_classes, min_size=10)
        print(f"[INFO] split=dirichlet alpha={args.dir_alpha}")
    else:
        dict_users = equal_split_indices(len(train_data), args.num_users, seed=args.seed)
        print("[INFO] split=equal")

    model = build_resnet20(args.num_classes).to(device)
    global_sd = copy.deepcopy(model.state_dict())

    # The CSV header is written only when the file does not exist yet, so
    # re-running the same configuration appends rows to the earlier file.
    os.makedirs("runs", exist_ok=True)
    csv_path = os.path.join("runs", f"fedavg_{args.dataset}_u{args.num_users}_{args.split_type}_seed{args.seed}.csv")
    if not os.path.exists(csv_path):
        with open(csv_path, "w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["round","used","train_acc","test_acc","train_loss","test_loss","lr"])

    use_wandb = args.wandb and wandb is not None
    if args.wandb and wandb is None:
        print("[WARN] --wandb was requested but the wandb package could not be imported; "
              "wandb logging is disabled")
    if use_wandb:
        # An empty --wandb_entity means "not specified": pass None rather than "".
        wandb.init(project=args.wandb_project, entity=args.wandb_entity or None,
                   config=vars(args), name=f"FedAvg_{args.dataset}_u{args.num_users}_{args.split_type}_seed{args.seed}")

    # Clients sampled per round: at least 1, and never more than exist
    # (sampling without replacement cannot draw more than num_users).
    clients_per_round = min(args.num_users,
                            max(1, int(round(args.client_frac * args.num_users))))

    t0 = time.time()

    for r in range(args.rounds):
        # Step schedule: the local LR is multiplied by 0.1 after round 90.
        lr = float(args.local_lr)
        if (r + 1) > 90:
            lr *= 0.1

        selected = rng.choice(args.num_users, size=clients_per_round, replace=False).tolist()

        local_sds, local_ns = [], []
        for k in selected:
            ds_k = DatasetSplit(train_data, dict_users[k])
            loader_k = DataLoader(ds_k, batch_size=args.local_bs, shuffle=True,
                                  num_workers=args.workers, pin_memory=pin_memory)

            # Every client starts from the current global weights on a fresh
            # copy, so training one client never affects another.
            local_model = copy.deepcopy(model).to(device)
            local_model.load_state_dict(global_sd, strict=True)

            sd_new = local_train(local_model, loader_k, device,
                                 lr=lr, momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 local_ep=args.local_ep)
            local_sds.append(sd_new)
            local_ns.append(len(dict_users[k]))

        # server: weighted average by client data size
        nsum = float(sum(local_ns))
        weights = [n / nsum for n in local_ns]
        global_sd = fedavg_weighted(local_sds, weights)
        model.load_state_dict(global_sd, strict=True)

        train_loss, train_acc = eval_loss_acc(model, train_eval_loader, device)
        test_loss, test_acc = eval_loss_acc(model, test_loader, device)

        mins = (time.time() - t0) / 60.0
        print(
            f"Round {r+1:04d}/{args.rounds:04d} | used={len(selected):02d}/{args.num_users:02d} "
            f"| TrainAcc {train_acc:6.2f} | TestAcc {test_acc:6.2f} "
            f"| TrainLoss {train_loss:7.4f} | TestLoss {test_loss:7.4f} "
            f"| lr={lr:.4f} | {mins:5.1f}m"
        )

        # Append after every round so partial results survive an interruption.
        with open(csv_path, "a", newline="") as f:
            w = csv.writer(f)
            w.writerow([r+1, len(selected), train_acc, test_acc, train_loss, test_loss, lr])

        if use_wandb:
            wandb.log(
                {"Round": r+1,
                 "Train/Accuracy": train_acc,
                 "Test/Accuracy": test_acc,
                 "Train/Loss": train_loss,
                 "Test/Loss": test_loss,
                 "Optimization/LR": lr,
                 "System/used_clients": len(selected)},
                step=r+1
            )


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()