import copy
import csv
import math
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb

from ddro_utils import (
    DatasetSplit,
    stratified_equal_split_indices,
    average_state_dicts,
    bn_calibrate,
    ds_server_update_x,
    ds_server_update_y,
    local_step_ds_feddro,
    make_balanced_sampler,
    server_rehearsal,
)


@torch.no_grad()
def eval_loss_acc(model, loader, device):
    """Evaluate ``model`` over ``loader`` and return ``(mean_loss, accuracy_pct)``.

    The model is switched to eval mode and is left in eval mode on return.
    Gradients are disabled for the whole call by the decorator.

    Args:
        model: Callable module mapping an input batch to per-class logits.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(loss, acc)``: the summed cross-entropy divided by the number
        of targets seen, and the top-1 accuracy in percent. Both are ``0.0``
        when the loader yields nothing.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction="sum")

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        outputs = model(inputs)

        loss_sum += float(criterion(outputs, targets).item())
        n_correct += int((outputs.argmax(dim=1) == targets).sum().item())
        n_seen += int(targets.numel())

    # Guard the division so an empty loader yields zeros instead of raising.
    denom = max(n_seen, 1)
    return loss_sum / denom, 100.0 * n_correct / denom


@torch.no_grad()
def eval_avg_client_train_acc(model, train_data, dict_users, device, num_clients_eval=32, bs=256, workers=2):
    """Return the unweighted mean accuracy of ``model`` over client shards.

    Only the first ``num_clients_eval`` keys of ``dict_users`` (in the
    mapping's iteration order) are evaluated. Each shard is wrapped in a
    ``DatasetSplit`` and scored with ``eval_loss_acc``.

    Args:
        model: Model to evaluate.
        train_data: Dataset that the per-client entries of ``dict_users``
            are passed to ``DatasetSplit`` with.
        dict_users: Mapping from client id to that client's shard spec.
        device: Device used for evaluation.
        num_clients_eval: Slice bound on the number of clients evaluated.
        bs: Batch size of each per-client loader.
        workers: Number of DataLoader worker processes.

    Returns:
        Mean of the per-client accuracies (percent) as a float, or ``0.0``
        when no client is selected.
    """
    client_ids = list(dict_users.keys())[:num_clients_eval]
    if not client_ids:
        return 0.0

    per_client_acc = []
    for client_id in client_ids:
        shard_loader = DataLoader(
            DatasetSplit(train_data, dict_users[client_id]),
            batch_size=bs,
            shuffle=False,
            num_workers=workers,
            pin_memory=True,
        )
        # Only the accuracy component of (loss, acc) is needed here.
        per_client_acc.append(eval_loss_acc(model, shard_loader, device)[1])

    return float(np.mean(per_client_acc))


def cosine_round_lr(round_idx, total_rounds, base_lr, min_lr=0.001, warmup_rounds=5):
    """Learning rate for a round: linear warm-up followed by cosine decay.

    Args:
        round_idx: Current round index.
        total_rounds: Total number of rounds in the schedule.
        base_lr: Peak learning rate, reached when ``round_idx == warmup_rounds``.
        min_lr: Floor the cosine phase decays to; also the warm-up start value.
        warmup_rounds: Number of rounds spent in the linear warm-up.

    Returns:
        The learning rate for ``round_idx``. Rounds past ``total_rounds``
        return ``min_lr`` because the decay progress is clamped to 1.
    """
    lr_span = base_lr - min_lr

    # Warm-up phase: interpolate linearly from min_lr up to base_lr.
    if round_idx <= warmup_rounds:
        warm_frac = round_idx / max(warmup_rounds, 1)
        return min_lr + lr_span * warm_frac

    # Decay phase: fraction of the post-warm-up schedule completed, in [0, 1].
    decay_len = max(total_rounds - warmup_rounds, 1)
    frac = (round_idx - warmup_rounds) / decay_len
    if 0.0 > frac:
        frac = 0.0
    if 1.0 < frac:
        frac = 1.0

    cosine_term = 1.0 + math.cos(math.pi * frac)
    return min_lr + 0.5 * lr_span * cosine_term


def scheduled_mixup_alpha(t, start_round=25, full_round=80, max_alpha=0.15):
    """Mixup alpha for round ``t``: zero, then a linear ramp, then a plateau.

    Args:
        t: Current round index.
        start_round: Rounds before this one return ``0.0``.
        full_round: Rounds at or after this one return ``max_alpha``.
        max_alpha: Plateau value reached at ``full_round``.

    Returns:
        A float: ``0.0`` for ``t < start_round``, ``max_alpha`` for
        ``t >= full_round``, and a linear interpolation in between.
    """
    if t < start_round:
        return 0.0
    if t >= full_round:
        return float(max_alpha)

    # Fraction of the ramp completed; the max() guards a zero-length ramp.
    ramp_len = max(full_round - start_round, 1)
    ramp_frac = (t - start_round) / ramp_len
    return float(max_alpha * ramp_frac)


def DS_FedDRO_M(args, model, train_data, train_eval_data, test_loader):
    """Run the DS-FedDRO-M training loop and log per-round metrics.

    Each round every client takes a local step from its own persisted state
    ``(client_x[cid], client_y[cid])``. Every ``args.I`` rounds the client
    results are averaged (weighted by shard size), folded into the global
    state by the server update helpers, optionally refined by server
    rehearsal, BN-calibrated, and then broadcast back to all clients.
    After every round the global model is evaluated and the metrics are
    printed, appended to ``runs/<run_name>.csv`` and, when ``args.wandb`` is
    truthy, logged to Weights & Biases.

    Args:
        args: Hyper-parameter namespace (must support ``vars(args)`` when
            ``args.wandb`` is set). Attributes read here: ``device``,
            ``num_users``, ``num_classes``, ``random_seed``, ``batch_size``,
            ``workers``, ``y_clip``, ``dataset``, ``model_name``, ``K`` (run
            name only), ``I``, ``wandb``, ``wandb_project``, ``wandb_entity``,
            ``epochs``, ``base_lr``, ``min_lr``, ``warmup_rounds``,
            ``mixup_start_round``, ``mixup_full_round``, ``mixup_alpha``,
            ``local_bs``, ``lamda``, ``local_ep``, ``momentum``,
            ``weight_decay``, ``grad_clip``, ``beta_y``, ``label_smoothing``,
            ``gamma_x``, ``gamma_y``, ``server_rehearsal_start_round``,
            ``server_rehearsal_steps``, ``server_rehearsal_lr``,
            ``server_label_smoothing``, ``eval_num_clients``.
        model: Global model; its weights are overwritten in place via
            ``load_state_dict``. NOTE(review): this function never moves
            ``model`` itself to ``args.device``, so it is assumed to already
            live there -- confirm with the caller.
        train_data: Dataset the clients train on; partitioned across clients
            by ``stratified_equal_split_indices``.
        train_eval_data: Dataset used for evaluation, server rehearsal and BN
            calibration. It is indexed with the same per-client index lists
            as ``train_data``, so both are assumed to share sample ordering.
        test_loader: Loader used for the per-round test evaluation.

    Returns:
        None. Results are reported through stdout, the CSV file and wandb.
    """
    device = args.device

    # Client id -> shard spec consumed by DatasetSplit further down.
    dict_users = stratified_equal_split_indices(
        train_data,
        num_users=args.num_users,
        num_classes=args.num_classes,
        seed=int(args.random_seed),
    )
    print(f"[INFO] split=stratified_equal clients={args.num_users}", flush=True)

    # Sequential loader: initial y estimate, BN calibration, global train eval.
    train_eval_loader = DataLoader(
        train_eval_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    # Sampler-driven loader used only by the server rehearsal step.
    balanced_sampler = make_balanced_sampler(train_eval_data, args.num_classes)
    balanced_server_loader = DataLoader(
        train_eval_data,
        batch_size=args.batch_size,
        sampler=balanced_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )

    x_global = copy.deepcopy(model.state_dict())

    # Initialise y_global from the mean CE loss of (at most) the first five
    # batches. No backward pass follows, so run under no_grad to avoid
    # building autograd graphs and holding activations for nothing.
    ce = nn.CrossEntropyLoss(reduction="mean")
    init_vals = []
    model.eval()
    with torch.no_grad():
        for bi, (x, y) in enumerate(train_eval_loader):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            loss = ce(model(x), y)
            if torch.isfinite(loss):
                init_vals.append(float(loss.item()))
            if bi + 1 >= 5:
                break

    # Fall back to 1.0 when no finite loss was observed, then clip to range.
    y_global = float(np.mean(init_vals)) if init_vals else 1.0
    y_global = float(np.clip(y_global, 0.0, float(args.y_clip)))

    # Per-client state persists across rounds; it is only reset to the global
    # state on synchronisation rounds.
    client_x = {cid: copy.deepcopy(x_global) for cid in range(args.num_users)}
    client_y = {cid: float(y_global) for cid in range(args.num_users)}

    tau = 0  # number of synchronisation rounds completed so far
    os.makedirs("runs", exist_ok=True)
    run_name = (
        f"DS_FedDRO_M_{args.dataset}_{args.model_name}"
        f"_u{args.num_users}_K{args.K}_I{args.I}_seed{args.random_seed}"
    )
    csv_path = os.path.join("runs", f"{run_name}.csv")

    # The header is written only for a new file; a rerun with the same run
    # name appends its rows to the existing CSV.
    if not os.path.exists(csv_path):
        with open(csv_path, "w", newline="") as f:
            csv.writer(f).writerow([
                "round", "tau", "used",
                "train_acc_local", "train_acc_global", "test_acc",
                "train_loss_global", "test_loss", "y_global", "lr"
            ])

    if args.wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=run_name,
            config=vars(args),
        )

    t0 = time.time()

    for t in range(1, args.epochs + 1):
        lr = cosine_round_lr(
            round_idx=t,
            total_rounds=args.epochs,
            base_lr=float(args.base_lr),
            min_lr=float(args.min_lr),
            warmup_rounds=int(args.warmup_rounds),
        )

        mixup_alpha = scheduled_mixup_alpha(
            t,
            start_round=int(args.mixup_start_round),
            full_round=int(args.mixup_full_round),
            max_alpha=float(args.mixup_alpha),
        )

        # Full participation: every client trains in every round.
        selected = list(range(args.num_users))
        local_sds = []
        local_yvals = []
        local_weights = []

        for cid in selected:
            ds_k = DatasetSplit(train_data, dict_users[cid])
            loader_k = DataLoader(
                ds_k,
                batch_size=args.local_bs,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=False,
            )

            # Fresh copy of the architecture, loaded with this client's state.
            local_model = copy.deepcopy(model).to(device)
            local_model.load_state_dict(client_x[cid], strict=True)

            x_new, y_new = local_step_ds_feddro(
                model=local_model,
                trainloader=loader_k,
                y_scalar=client_y[cid],
                eta=lr,
                lamda=float(args.lamda),
                local_ep=int(args.local_ep),
                momentum=float(args.momentum),
                weight_decay=float(args.weight_decay),
                grad_clip=float(args.grad_clip),
                beta_y=float(args.beta_y),
                label_smoothing=float(args.label_smoothing),
                mixup_alpha=float(mixup_alpha),
                y_clip=float(args.y_clip),
            )

            client_x[cid] = copy.deepcopy(x_new)
            client_y[cid] = float(np.clip(y_new, 0.0, float(args.y_clip)))

            local_sds.append(x_new)
            local_yvals.append(client_y[cid])
            # Aggregation weight = size of the client's shard.
            local_weights.append(len(dict_users[cid]))

        if t % int(args.I) == 0:
            # Synchronisation round: aggregate, update the global state, then
            # broadcast it back to the clients.
            avg_local_sd = average_state_dicts(local_sds, local_weights)
            avg_local_y = float(np.average(local_yvals, weights=np.array(local_weights, dtype=np.float64)))

            x_global = ds_server_update_x(
                old_global_sd=x_global,
                avg_local_sd=avg_local_sd,
                gamma_x=float(args.gamma_x),
            )
            y_global = ds_server_update_y(
                y_old=y_global,
                avg_local_y=avg_local_y,
                gamma_y=float(args.gamma_y),
                y_clip=float(args.y_clip),
            )

            model.load_state_dict(x_global, strict=True)

            if t >= int(args.server_rehearsal_start_round) and int(args.server_rehearsal_steps) > 0:
                server_rehearsal(
                    model=model,
                    loader=balanced_server_loader,
                    device=device,
                    lr=float(args.server_rehearsal_lr),
                    steps=int(args.server_rehearsal_steps),
                    weight_decay=float(args.weight_decay),
                    grad_clip=float(args.grad_clip),
                    label_smoothing=float(args.server_label_smoothing),
                )

            bn_calibrate(model, train_eval_loader, device=device, max_batches=20)
            # Re-snapshot so x_global reflects whatever rehearsal and
            # calibration changed on the model.
            x_global = copy.deepcopy(model.state_dict())

            for cid in selected:
                client_x[cid] = copy.deepcopy(x_global)
                client_y[cid] = float(y_global)

            tau += 1
        else:
            # Non-sync round: evaluate the last synchronised global state.
            model.load_state_dict(x_global, strict=True)

        train_acc_local = eval_avg_client_train_acc(
            model=model,
            train_data=train_eval_data,
            dict_users=dict_users,
            device=device,
            num_clients_eval=int(args.eval_num_clients),
            bs=args.batch_size,
            workers=args.workers,
        )
        train_loss_global, train_acc_global = eval_loss_acc(model, train_eval_loader, device)
        test_loss, test_acc = eval_loss_acc(model, test_loader, device)

        print(
            f"Round {t:03d}/{args.epochs:03d} | tau={tau:03d} "
            f"| used={len(selected):02d}/{args.num_users:02d} "
            f"| TrainAcc(Local) {train_acc_local:6.2f} | TrainAcc(Global) {train_acc_global:6.2f} "
            f"| TestAcc {test_acc:6.2f} | TrainLoss(Global) {train_loss_global:7.4f} "
            f"| TestLoss {test_loss:7.4f} | y_global {y_global:7.4f} | lr={lr:.5f} "
            f"| {(time.time() - t0)/60.0:5.1f}m",
            flush=True,
        )

        # Re-open per round so each row reaches disk even if a later round fails.
        with open(csv_path, "a", newline="") as f:
            csv.writer(f).writerow([
                t, tau, len(selected),
                train_acc_local, train_acc_global, test_acc,
                train_loss_global, test_loss, y_global, lr
            ])

        if args.wandb:
            wandb.log(
                {
                    "Round": t,
                    "Tau": tau,
                    "Train/Acc_Local": train_acc_local,
                    "Train/Acc_Global": train_acc_global,
                    "Test/Acc": test_acc,
                    "Train/Loss_Global": train_loss_global,
                    "Test/Loss": test_loss,
                    "DRO/y_global": y_global,
                    "Optimization/LR": lr,
                    "Optimization/MixupAlpha": mixup_alpha,
                },
                step=t,
            )

    if args.wandb:
        wandb.finish()


# Alias: expose DS_FedDRO_M under the shorter name DS_FedDRO as well.
DS_FedDRO = DS_FedDRO_M