import copy
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, WeightedRandomSampler


class DatasetSplit(Dataset):
    """View over ``dataset`` that exposes only the samples listed in ``indices``."""

    def __init__(self, dataset, indices):
        # Materialise the indices once so the view has a fixed order and length.
        self.indices = list(indices)
        self.dataset = dataset

    def __len__(self):
        """Return how many samples this view exposes."""
        return len(self.indices)

    def __getitem__(self, i):
        """Return the wrapped dataset's sample at the ``i``-th selected index."""
        # Coerce to a plain int so NumPy / tensor index scalars are accepted.
        source_position = int(self.indices[i])
        return self.dataset[source_position]


def get_targets_from_dataset(ds):
    """Return the labels of ``ds`` as a NumPy array.

    Two layouts are understood:

    * an object exposing ``targets``: the attribute is converted with
      ``np.array`` and returned;
    * a subset-style wrapper exposing both ``dataset`` and ``indices``: the
      labels of the wrapped dataset are resolved (recursively, so wrappers
      may be nested) and then restricted to ``indices``.

    Raises:
        ValueError: if ``ds`` (or a dataset it wraps) matches neither layout.
    """
    if hasattr(ds, "targets"):
        return np.array(ds.targets)

    if hasattr(ds, "dataset") and hasattr(ds, "indices"):
        # Recurse rather than reading ``ds.dataset.targets`` directly: the
        # wrapped object may itself be a wrapper without a ``targets``
        # attribute, which previously surfaced as an AttributeError.
        base_targets = get_targets_from_dataset(ds.dataset)
        return base_targets[np.array(ds.indices, dtype=np.int64)]

    raise ValueError("Cannot extract targets from dataset.")


def stratified_equal_split_indices(dataset, num_users, num_classes=100, seed=0):
    """Partition sample indices across users, class by class.

    For every label in ``range(num_classes)`` the matching sample positions
    are shuffled and cut into ``num_users`` shards with ``np.array_split``;
    shard ``u`` goes to user ``u``. Each user's collected indices are then
    shuffled once more. Samples whose label lies outside
    ``range(num_classes)`` are not assigned to anyone.

    Returns:
        dict mapping each user id in ``range(num_users)`` to an ``int64``
        array of sample positions. The result is reproducible for a given
        ``seed``.
    """
    targets = get_targets_from_dataset(dataset)
    rng = np.random.default_rng(seed)

    per_user = [[] for _ in range(num_users)]
    for label in range(num_classes):
        members = np.nonzero(targets == label)[0]
        rng.shuffle(members)
        # array_split always yields exactly ``num_users`` shards, so zip pairs
        # every user with its (possibly empty) share of this class.
        for bucket, shard in zip(per_user, np.array_split(members, num_users)):
            bucket.extend(shard.tolist())

    assignment = {}
    for user, bucket in enumerate(per_user):
        shuffled = np.array(bucket, dtype=np.int64)
        # Mix the classes together so a user's indices are not grouped by label.
        rng.shuffle(shuffled)
        assignment[user] = shuffled
    return assignment


def average_state_dicts(state_dicts, weights=None):
    """Return the weighted average of several state dicts.

    Args:
        state_dicts: non-empty sequence of mappings that share the keys of
            the first entry.
        weights: optional per-entry weights. They are normalised by their
            sum (floored at ``1e-12``); ``None`` means a uniform average.

    Returns:
        A new mapping with the keys of ``state_dicts[0]``. Floating-point
        entries hold the weighted sum. Integer and boolean tensors are
        averaged, rounded, and cast back to their original dtype so the
        result keeps the dtypes of the inputs.

    Raises:
        ValueError: if ``state_dicts`` is empty or ``weights`` does not have
            exactly one entry per state dict.
    """
    if len(state_dicts) == 0:
        raise ValueError("state_dicts must be non-empty")

    if weights is None:
        weights = [1.0 / len(state_dicts)] * len(state_dicts)
    else:
        weights = [float(w) for w in weights]
        if len(weights) != len(state_dicts):
            # zip() below would otherwise drop entries silently while the
            # normalisation still used the full weight sum.
            raise ValueError("weights must have one entry per state dict")
        total = sum(weights)
        weights = [w / max(total, 1e-12) for w in weights]

    reference = state_dicts[0]
    # Copy the first dict so any extra attributes attached to the mapping
    # itself are carried over; every value is overwritten below.
    out = copy.deepcopy(reference)
    for k, ref_value in reference.items():
        # Multiplying an integer/bool tensor by a Python float promotes it to
        # a floating dtype; remember to undo that after accumulation.
        restore_dtype = torch.is_tensor(ref_value) and not (
            torch.is_floating_point(ref_value) or torch.is_complex(ref_value)
        )
        acc = None
        for sd, w in zip(state_dicts, weights):
            term = sd[k] * w
            acc = term if acc is None else acc + term
        if restore_dtype:
            # Round first so e.g. 4.9999 becomes 5 rather than truncating to 4.
            acc = torch.round(acc).to(ref_value.dtype)
        out[k] = acc
    return out


@torch.no_grad()
def bn_calibrate(model, loader, device, max_batches=30):
    """Run forward passes in train mode, without gradients, then switch to eval.

    Modules that update running statistics in train mode (such as batch
    normalisation) refresh them from the data; no parameters are optimised.

    Args:
        model: module to calibrate; it is left in eval mode on return.
        loader: iterable yielding ``(inputs, _)`` pairs.
        device: device the inputs are moved to before the forward pass.
        max_batches: upper bound on the number of batches consumed. A value
            of zero or less performs no forward pass at all.
    """
    limit = int(max_batches)
    if limit <= 0:
        # The limit is otherwise only checked after a forward pass, which
        # would let one batch through even when none were requested.
        model.eval()
        return

    model.train()
    try:
        for bi, (x, _) in enumerate(loader):
            x = x.to(device, non_blocking=True)
            model(x)
            # Check after the pass so no extra batch is fetched from the loader.
            if bi + 1 >= limit:
                break
    finally:
        # Leave the model in eval mode even if a forward pass raised.
        model.eval()


def ds_server_update_x(old_global_sd, avg_local_sd, gamma_x):
    """Move the global state dict towards the averaged local one.

    ``gamma_x`` is clipped to ``[0, 1]``. Floating-point tensors become
    ``(1 - gamma_x) * old + gamma_x * avg``; every other entry is taken
    from ``avg_local_sd`` as-is. Neither input mapping is modified.
    """
    step = float(np.clip(gamma_x, 0.0, 1.0))
    blended = copy.deepcopy(old_global_sd)

    for name, previous in old_global_sd.items():
        averaged = avg_local_sd[name]
        if not (torch.is_tensor(previous) and torch.is_floating_point(previous)):
            # Integer buffers and non-tensor entries cannot be interpolated
            # meaningfully, so adopt the averaged value directly.
            blended[name] = averaged
            continue
        blended[name] = (1.0 - step) * previous + step * averaged
    return blended


def ds_server_update_y(y_old, avg_local_y, gamma_y, y_clip=10.0):
    """Blend the scalar ``y`` towards the averaged local value.

    ``gamma_y`` is clipped to ``[0, 1]`` and used as the interpolation
    factor; the blended value is then clipped to ``[0, y_clip]`` and
    returned as a Python float.
    """
    step = float(np.clip(gamma_y, 0.0, 1.0))
    previous = float(y_old)
    target = float(avg_local_y)
    upper = float(y_clip)

    blended = (1.0 - step) * previous + step * target
    return float(np.clip(blended, 0.0, upper))


def make_balanced_sampler(dataset, num_classes):
    """Build a sampler that weights every sample by ``1 / (size of its class)``.

    The per-sample weights of each class present in the dataset therefore
    sum to the same total. The sampler draws ``len(dataset labels)`` samples
    per pass, with replacement.
    """
    labels = get_targets_from_dataset(dataset)
    counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    # Give absent classes a count of one so the reciprocal below stays finite.
    counts = np.where(counts == 0, 1.0, counts)

    sample_weights = torch.as_tensor(1.0 / counts[labels], dtype=torch.double)
    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )


def mixup_data(x, y, alpha=0.0):
    """Apply mixup to a batch.

    With ``alpha <= 0`` the batch is returned untouched as
    ``(x, y, y, 1.0)``. Otherwise a coefficient is drawn from
    ``Beta(alpha, alpha)`` using NumPy's global RNG and folded into
    ``[0.5, 1]``, and each sample is blended with a randomly permuted
    partner from the same batch.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` is ``y`` itself, ``y_b``
        holds the partners' labels, and ``lam`` is the weight of the
        original sample.
    """
    if alpha <= 0.0:
        return x, y, y, 1.0

    sampled = np.random.beta(alpha, alpha)
    # Keep the original sample dominant: use whichever of lam / 1 - lam is larger.
    lam = float(max(sampled, 1.0 - sampled))

    partner = torch.randperm(x.size(0), device=x.device)
    blended = lam * x + (1.0 - lam) * x[partner]
    return blended, y, y[partner], lam


class LabelSmoothingCE(nn.Module):
    """Cross-entropy with optional label smoothing, returned per sample.

    With smoothing ``s > 0`` the target distribution puts ``1 - s`` on the
    true class and spreads ``s`` evenly over the remaining classes.
    """

    def __init__(self, smoothing=0.0):
        super().__init__()
        self.smoothing = float(smoothing)

    def forward(self, logits, target):
        """Return one loss value per row of ``logits`` (no reduction)."""
        eps = self.smoothing
        if eps <= 0.0:
            return nn.functional.cross_entropy(logits, target, reduction="none")

        log_probs = torch.log_softmax(logits, dim=1)
        # Share of the smoothing mass per wrong class; the max() guards the
        # single-class case against division by zero.
        off_value = eps / max(logits.size(1) - 1, 1)
        with torch.no_grad():
            soft_target = torch.full_like(log_probs, off_value)
            soft_target.scatter_(1, target.unsqueeze(1), 1.0 - eps)
        return (-soft_target * log_probs).sum(dim=1)


def local_step_ds_feddro(
    model,
    trainloader,
    y_scalar,
    eta,
    lamda,
    local_ep=5,
    momentum=0.9,
    weight_decay=5e-4,
    grad_clip=5.0,
    beta_y=0.03,
    label_smoothing=0.0,
    mixup_alpha=0.0,
    y_clip=10.0,
):
    """Run the local DS-FedDRO update on one client and return its result.

    The model is trained in place with SGD for ``local_ep`` passes over
    ``trainloader``. For every batch:

    1. optional mixup is applied and a per-sample (optionally label-smoothed)
       cross-entropy is computed;
    2. per-sample weights are derived from the detached losses as
       ``softmax((loss - y) / lamda)`` over the batch, with the argument
       clamped to ``[-2.5, 2.5]``, rescaled by the batch size, and clamped
       to ``[0.6, 2.0]``;
    3. the weighted mean loss is back-propagated, gradients are clipped to
       norm ``grad_clip``, and the optimizer steps;
    4. the local scalar ``y`` is updated as an exponential moving average
       (rate ``beta_y``) of the batch mean loss, kept inside ``[0, y_clip]``.

    Batches whose loss is not finite are skipped without an update.

    Args:
        model: module to train; must have at least one parameter, whose
            device is used for the batches.
        trainloader: iterable of ``(inputs, labels)`` batches.
        y_scalar: starting value of ``y`` (clipped to ``[0, y_clip]``).
        eta: SGD learning rate.
        lamda: divisor of the centred loss inside the softmax (floored at
            ``1e-8``); smaller values make the weights less uniform.
        local_ep: number of passes over ``trainloader``.
        momentum: SGD momentum; Nesterov momentum is used when it is
            positive.
        weight_decay: SGD weight decay.
        grad_clip: maximum gradient norm.
        beta_y: moving-average rate for ``y``.
        label_smoothing: smoothing passed to ``LabelSmoothingCE``.
        mixup_alpha: Beta parameter passed to ``mixup_data``.
        y_clip: upper bound applied to ``y`` and to the batch loss fed
            into it.

    Returns:
        ``(state_dict, y_local)``: a deep copy of the trained model's state
        dict and the final value of ``y`` as a float.
    """
    device = next(model.parameters()).device
    model.train()

    momentum = float(momentum)
    opt = torch.optim.SGD(
        model.parameters(),
        lr=float(eta),
        momentum=momentum,
        weight_decay=float(weight_decay),
        # torch.optim.SGD rejects nesterov=True unless momentum is positive,
        # so fall back to plain SGD when the caller disables momentum.
        nesterov=momentum > 0.0,
    )

    ce_none = LabelSmoothingCE(smoothing=float(label_smoothing))

    # Loop-invariant conversions, hoisted out of the batch loop.
    y_max = float(y_clip)
    temperature = max(float(lamda), 1e-8)
    ema_rate = float(beta_y)
    max_norm = float(grad_clip)
    alpha = float(mixup_alpha)

    y_local = float(np.clip(y_scalar, 0.0, y_max))

    for _ in range(int(local_ep)):
        for x, y in trainloader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            x, y_a, y_b, lam = mixup_data(x, y, alpha=alpha)

            opt.zero_grad(set_to_none=True)
            logits = model(x)

            loss_a = ce_none(logits, y_a)
            if y_b is y_a:
                # Mixup is disabled: both targets are the same tensor and
                # lam is 1.0, so the second loss would be a duplicate.
                per_sample_loss = loss_a
            else:
                loss_b = ce_none(logits, y_b)
                per_sample_loss = lam * loss_a + (1.0 - lam) * loss_b

            if not torch.isfinite(per_sample_loss).all():
                continue

            # Weights are built from detached losses so they act as constants
            # during back-propagation.
            scaled = (per_sample_loss.detach() - y_local) / temperature
            scaled = torch.clamp(scaled, min=-2.5, max=2.5)

            # Multiplying by the batch size makes uniform weights equal 1.
            weights = torch.softmax(scaled, dim=0) * per_sample_loss.numel()
            weights = torch.clamp(weights, 0.6, 2.0)

            loss = (weights * per_sample_loss).mean()
            if not torch.isfinite(loss):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            opt.step()

            with torch.no_grad():
                batch_loss = float(per_sample_loss.mean().item())
                batch_loss = float(np.clip(batch_loss, 0.0, y_max))
                y_local = (1.0 - ema_rate) * y_local + ema_rate * batch_loss
                y_local = float(np.clip(y_local, 0.0, y_max))

    return copy.deepcopy(model.state_dict()), float(y_local)


def server_rehearsal(
    model,
    loader,
    device,
    lr=0.004,
    steps=6,
    weight_decay=5e-4,
    grad_clip=5.0,
    label_smoothing=0.0,
):
    """Fine-tune ``model`` in place for a few SGD steps, then switch to eval.

    Uses Nesterov SGD with momentum 0.9 on the mean (optionally
    label-smoothed) cross-entropy, clipping gradients to norm ``grad_clip``.
    Batches with a non-finite loss are skipped and do not count as a step.
    Training stops after ``steps`` updates or when ``loader`` is exhausted,
    whichever comes first.

    Args:
        model: module to update; it is left in eval mode on return.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device the batches are moved to.
        lr: SGD learning rate.
        steps: maximum number of optimizer steps; zero or less performs no
            update at all.
        weight_decay: SGD weight decay.
        grad_clip: maximum gradient norm.
        label_smoothing: smoothing passed to ``LabelSmoothingCE``.
    """
    max_steps = int(steps)
    if max_steps <= 0:
        # The step budget is otherwise only checked after an update, which
        # would let one optimizer step through even when none were requested.
        model.eval()
        return

    model.train()
    opt = torch.optim.SGD(
        model.parameters(),
        lr=float(lr),
        momentum=0.9,
        weight_decay=float(weight_decay),
        nesterov=True,
    )
    ce = LabelSmoothingCE(smoothing=float(label_smoothing))
    max_norm = float(grad_clip)

    used = 0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        logits = model(x)
        loss = ce(logits, y).mean()

        if not torch.isfinite(loss):
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        opt.step()

        # Only successful updates count towards the step budget.
        used += 1
        if used >= max_steps:
            break

    model.eval()