import copy
import csv
import os
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb

from fedavg_utils import (
    DatasetSplit,
    equal_split_indices,
    weighted_avg_state_dict,
    step_round_lr,
    local_train_fedavg,
)


@torch.no_grad()
def eval_loss_acc(model, loader, device):
    """Evaluate a classifier and return ``(mean_loss, accuracy_pct)``.

    Args:
        model: Module whose forward returns logits; predictions are taken as
            ``argmax`` over dim 1 (assumes ``(batch, num_classes)`` logits --
            confirm for non-standard heads). The model is NOT moved here; it
            must already live on ``device``.
        loader: Iterable of ``(x, y)`` batches with integer class targets ``y``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(loss, acc)`` where ``loss`` is the summed cross-entropy divided by
        the number of targets and ``acc`` is the percentage of correct argmax
        predictions. Both are ``0.0`` when the loader yields no samples.

    The model's train/eval mode is restored on exit (even if an exception is
    raised). Otherwise, evaluating between training rounds would leave the
    model, and any deep copies made from it afterwards, stuck in eval mode.
    """
    was_training = model.training
    model.eval()
    # Sum-reduced so that batches of different sizes are weighted correctly
    # when dividing by the total sample count at the end.
    ce = nn.CrossEntropyLoss(reduction="sum")
    total_loss, correct, total = 0.0, 0, 0

    try:
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            logits = model(x)
            total_loss += ce(logits, y).item()
            correct += int((logits.argmax(dim=1) == y).sum().item())
            total += int(y.numel())
    finally:
        model.train(was_training)

    # Guard against an empty loader (avoids ZeroDivisionError).
    denom = max(total, 1)
    return total_loss / denom, 100.0 * correct / denom


def _make_client_loaders(train_data, dict_users, num_users, batch_size, workers):
    """Return one shuffled DataLoader per client, indexed by client id.

    Built once per run: a shuffling DataLoader draws a new permutation each
    time it is iterated, so reusing it across rounds still yields a fresh
    order every round.
    """
    return [
        DataLoader(
            DatasetSplit(train_data, dict_users[cid]),
            batch_size=batch_size,
            shuffle=True,
            num_workers=workers,
            pin_memory=True,
            drop_last=False,
        )
        for cid in range(num_users)
    ]


def _append_csv_row(csv_path, row):
    """Append one row to ``csv_path``, creating the file if it is missing.

    The file is reopened per call so completed rounds are on disk even if the
    run crashes later.
    """
    with open(csv_path, "a", newline="") as f:
        csv.writer(f).writerow(row)


def FedAvgTrain(args, model, train_data, train_eval_data, test_loader):
    """Run FedAvg with all clients participating in every round.

    Each round, every client starts from the current global weights and trains
    locally via ``local_train_fedavg``. The resulting state dicts are averaged
    with weights proportional to each client's number of samples, and the
    average becomes the new global model. That model is then evaluated on
    ``train_eval_data`` and ``test_loader``. Metrics are printed, appended to
    ``runs/<run_name>.csv``, and logged to Weights & Biases when
    ``args.wandb`` is set.

    Args:
        args: Run configuration namespace (device, split, optimisation,
            logging and wandb settings are read from it).
        model: Global model. It is moved to ``args.device`` and updated in
            place with the aggregated weights every round.
        train_data: Training dataset, split across ``args.num_users`` clients
            by ``equal_split_indices``.
        train_eval_data: Dataset used to report training loss/accuracy
            (iterated without shuffling).
        test_loader: Iterable of test batches.

    Returns:
        None.
    """
    device = args.device
    # eval_loss_acc moves batches to `device` but never moves the model, so
    # make sure the global model lives there. Module.to is in place and a
    # no-op when the model is already on `device`.
    model.to(device)

    dict_users = equal_split_indices(
        len(train_data),
        num_users=args.num_users,
        seed=int(args.random_seed),
    )
    print(f"[INFO] split=equal_random clients={args.num_users}", flush=True)

    train_eval_loader = DataLoader(
        train_eval_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    # The client partition is fixed for the whole run, so the per-client
    # loaders, the size-proportional aggregation weights and the local
    # training hyper-parameters are all round-invariant.
    client_loaders = _make_client_loaders(
        train_data, dict_users, args.num_users, args.local_bs, args.workers
    )
    client_sizes = [len(dict_users[cid]) for cid in range(args.num_users)]
    size_total = max(sum(client_sizes), 1e-12)
    agg_weights = [n / size_total for n in client_sizes]
    local_train_kwargs = dict(
        local_ep=int(args.local_ep),
        momentum=float(args.momentum),
        weight_decay=float(args.weight_decay),
        grad_clip=float(args.grad_clip),
    )

    global_sd = copy.deepcopy(model.state_dict())

    os.makedirs("runs", exist_ok=True)
    run_name = f"FedAvg_{args.dataset}_{args.model_name}_u{args.num_users}_equal_seed{args.random_seed}"
    csv_path = os.path.join("runs", f"{run_name}.csv")

    # NOTE(review): an existing CSV is appended to, not truncated, and
    # training always restarts at round 1. Re-running the same configuration
    # therefore accumulates rows from several runs in one file.
    if not os.path.exists(csv_path):
        _append_csv_row(csv_path, [
            "round",
            "train_acc",
            "test_acc",
            "train_loss",
            "test_loss",
            "lr",
        ])

    if args.wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=run_name,
            config=vars(args),
        )

    t0 = time.time()

    for t in range(1, args.epochs + 1):
        lr = step_round_lr(
            round_idx=t,
            base_lr=float(args.base_lr),
            drop_epoch=int(args.lr_drop_epoch),
            drop_factor=float(args.lr_drop_factor),
        )

        local_sds = []
        for loader_k in client_loaders:
            # Fresh model object per client: the state dict returned by
            # local_train_fedavg may alias the local model's tensors, so a
            # single model must not be reused across clients.
            local_model = copy.deepcopy(model).to(device)
            local_model.load_state_dict(global_sd, strict=True)

            local_sds.append(
                local_train_fedavg(
                    model=local_model,
                    trainloader=loader_k,
                    lr=lr,
                    **local_train_kwargs,
                )
            )

        global_sd = weighted_avg_state_dict(local_sds, agg_weights)
        model.load_state_dict(global_sd, strict=True)

        train_loss, train_acc = eval_loss_acc(model, train_eval_loader, device)
        test_loss, test_acc = eval_loss_acc(model, test_loader, device)

        print(
            f"Round {t:03d}/{args.epochs:03d} "
            f"| TrainAcc {train_acc:7.4f} "
            f"| TestAcc {test_acc:7.4f} "
            f"| TrainLoss {train_loss:7.4f} "
            f"| TestLoss {test_loss:7.4f} "
            f"| lr={lr:.5f} "
            f"| {(time.time() - t0) / 60.0:5.1f}m",
            flush=True,
        )

        _append_csv_row(csv_path, [
            t,
            train_acc,
            test_acc,
            train_loss,
            test_loss,
            lr,
        ])

        if args.wandb:
            wandb.log(
                {
                    "Round": t,
                    "Train/Acc": train_acc,
                    "Test/Acc": test_acc,
                    "Train/Loss": train_loss,
                    "Test/Loss": test_loss,
                    "Optimization/LR": lr,
                },
                step=t,
            )

    if args.wandb:
        wandb.finish()