import copy
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset


class DatasetSplit(Dataset):
    """Expose a subset of ``dataset`` selected by a collection of indices.

    Item ``i`` of the split is item ``indices[i]`` of the wrapped dataset.
    """

    def __init__(self, dataset, indices):
        # Materialise the indices once so that len() and positional lookup
        # work even when an arbitrary iterable is supplied.
        self.dataset = dataset
        self.indices = [*indices]

    def __len__(self):
        """Return the number of selected items."""
        return len(self.indices)

    def __getitem__(self, i):
        """Return the wrapped dataset's item for the ``i``-th stored index."""
        # int() turns NumPy integer scalars (and similar) into plain ints
        # before indexing the wrapped dataset.
        source_index = int(self.indices[i])
        return self.dataset[source_index]


def get_targets_from_dataset(ds):
    """Return the labels of ``ds`` as a NumPy array.

    Two dataset shapes are supported:

    * an object with a ``targets`` attribute, which is converted with
      ``np.array``;
    * a subset-style wrapper with ``dataset`` and ``indices`` attributes, in
      which case the wrapped dataset's targets are resolved first and then
      selected by ``indices``.

    Wrappers may be nested (a subset of a subset); the lookup recurses through
    ``ds.dataset`` until an object with ``targets`` is found.

    Raises:
        ValueError: if neither ``ds`` nor any dataset it wraps provides a way
            to obtain targets.
    """
    if hasattr(ds, "targets"):
        return np.array(ds.targets)

    if hasattr(ds, "dataset") and hasattr(ds, "indices"):
        # Recurse instead of reading ds.dataset.targets directly so that a
        # wrapper around another wrapper is resolved rather than failing with
        # AttributeError.
        base_targets = get_targets_from_dataset(ds.dataset)
        return base_targets[np.array(ds.indices, dtype=np.int64)]

    raise ValueError("Cannot extract targets from dataset.")


def stratified_equal_split_indices(dataset, num_users, num_classes=100, seed=0):
    """Partition sample indices across users, class by class.

    For every label in ``range(num_classes)`` the positions carrying that
    label are shuffled and cut into ``num_users`` contiguous shards with
    ``np.array_split``; shard ``u`` goes to user ``u``. Each user's collected
    indices are then shuffled once more.

    Samples whose label lies outside ``range(num_classes)`` are not assigned
    to any user.

    Returns:
        dict mapping user id (``0 .. num_users - 1``) to an ``int64`` NumPy
        array of indices into the target array of ``dataset``.
    """
    generator = np.random.default_rng(seed)
    labels = get_targets_from_dataset(dataset)

    per_user = [[] for _ in range(num_users)]
    for label in range(num_classes):
        members = np.nonzero(labels == label)[0]
        generator.shuffle(members)
        # array_split tolerates class sizes that are not a multiple of
        # num_users; shard lengths then differ by at most one.
        for bucket, shard in zip(per_user, np.array_split(members, num_users)):
            bucket.extend(shard.tolist())

    assignment = {}
    for user, bucket in enumerate(per_user):
        shuffled = np.array(bucket, dtype=np.int64)
        # Mix the classes so a user's indices are not grouped by label.
        generator.shuffle(shuffled)
        assignment[user] = shuffled
    return assignment


def average_state_dicts(state_dicts):
    """Return the element-wise mean of a sequence of state dicts.

    Every key of the first state dict is averaged over all entries of
    ``state_dicts``; each entry must therefore contain those keys.

    Floating-point (and complex) tensors are averaged as-is. Tensors of any
    other dtype (for example integer counters) are averaged and then cast back
    to their original dtype, so the result keeps the dtypes of the first state
    dict; the cast truncates a fractional mean.

    Args:
        state_dicts: non-empty sequence of mappings from key to tensor.

    Returns:
        A new mapping (a deep copy of the first state dict with every value
        replaced by the averaged tensor). The inputs are not modified.

    Raises:
        ValueError: if ``state_dicts`` is empty.
    """
    if not state_dicts:
        raise ValueError("state_dicts must be non-empty")

    # Deep-copying the first dict keeps its container type and any attributes
    # attached to it; all values are overwritten below.
    out = copy.deepcopy(state_dicts[0])
    n = float(len(state_dicts))

    for k in out:
        reference = state_dicts[0][k]
        acc = None
        for sd in state_dicts:
            v = sd[k]
            acc = v.clone() if acc is None else acc + v
        mean = acc / n
        # True division promotes integer/bool tensors to float; restore the
        # original dtype so the averaged dict stays dtype-compatible.
        if not (torch.is_floating_point(reference) or torch.is_complex(reference)):
            mean = mean.to(reference.dtype)
        out[k] = mean
    return out


@torch.no_grad()
def bn_calibrate(model, loader, device, max_batches=30):
    """Run forward passes in train mode, then leave the model in eval mode.

    The model is switched to ``train()``, fed at most ``max_batches`` batches
    from ``loader`` with gradients disabled, and finally switched to
    ``eval()``. Outputs are discarded; only the side effects of the forward
    passes (such as updates to normalisation-layer running statistics)
    remain.

    Args:
        model: module to run.
        loader: iterable yielding ``(inputs, labels)`` pairs; labels are
            ignored.
        device: device the inputs are moved to before the forward pass.
        max_batches: maximum number of batches to process. A value ``<= 0``
            processes no batches.
    """
    limit = int(max_batches)
    model.train()
    # Guard up front: checking the limit only after a forward pass would
    # always consume one batch, even when the caller asked for none.
    if limit > 0:
        for seen, (x, _) in enumerate(loader, start=1):
            x = x.to(device, non_blocking=True)
            _ = model(x)
            # Stop right after the last wanted batch so no extra batch is
            # fetched from the loader.
            if seen >= limit:
                break
    model.eval()


def ds_server_update_x(old_global_sd, avg_local_sd, gamma_x):
    """
    Algorithm 3 server update:
        x^{tau+1} = x^tau - gamma_x/K sum_k (x^tau - x_k^{t+1})
                   = x^tau - gamma_x (x^tau - avg_k x_k^{t+1})

    ``gamma_x`` is clipped to ``[0, 2]`` before use.

    Floating-point tensors are updated with the formula above. Every other
    entry is taken from ``avg_local_sd`` unchanged in value; when both the old
    and the averaged entry are tensors, the averaged one is cast to the old
    entry's dtype so the global state dict keeps stable dtypes even if the
    average was computed in floating point.

    Args:
        old_global_sd: current global state dict (not modified).
        avg_local_sd: average of the clients' state dicts; must contain every
            key of ``old_global_sd``.
        gamma_x: server step size.

    Returns:
        A new state dict with the same keys as ``old_global_sd``.
    """
    gx = float(np.clip(gamma_x, 0.0, 2.0))
    new_sd = copy.deepcopy(old_global_sd)

    for k in new_sd:
        old_v = old_global_sd[k]
        avg_v = avg_local_sd[k]
        if torch.is_tensor(old_v) and torch.is_floating_point(old_v):
            new_sd[k] = old_v - gx * (old_v - avg_v)
        elif torch.is_tensor(old_v) and torch.is_tensor(avg_v):
            # Interpolating integer/bool entries is not meaningful, so adopt
            # the averaged value but keep the original dtype.
            new_sd[k] = avg_v.to(dtype=old_v.dtype)
        else:
            new_sd[k] = avg_v
    return new_sd


def ds_server_update_y(y_old, avg_local_y, gamma_y):
    """
    Server-side update of the scalar y (Algorithm 3):
        y^{tau+1} = y^tau - gamma_y/K sum_k (y^tau - y_k^{t+1})
                  = y^tau - gamma_y (y^tau - avg_k y_k^{t+1})

    The step size is limited to [0, 2] and the returned value to [0, 10].
    """
    step = float(np.clip(gamma_y, 0.0, 2.0))
    current = float(y_old)
    # Move the current value towards the clients' average by a fraction `step`.
    gap = current - float(avg_local_y)
    moved = current - step * gap
    bounded = np.clip(moved, 0.0, 10.0)
    return float(bounded)


def compute_gk_value(model, x, y):
    """
    Estimate the inner function g_k(x; zeta).

    In the classification setting the scalar estimate is the cross-entropy
    of ``model(x)`` against labels ``y``, averaged over the batch.
    """
    return nn.functional.cross_entropy(model(x), y, reduction="mean")


def compute_hk_value(model, x, y):
    """
    Estimate the non-compositional term h_k(x; xi).

    As a practical DRO-style surrogate this is minus one half of the mean
    squared per-sample cross-entropy:
        h_k ~ - 0.5 * E[ell(x)^2]
    which keeps the objective in the h(x) + f(g(x)) form used by the paper.
    """
    sample_losses = nn.functional.cross_entropy(model(x), y, reduction="none")
    mean_squared_loss = sample_losses.pow(2).mean()
    return -0.5 * mean_squared_loss


def compute_phi_surrogate(model, x, y, y_scalar, lamda):
    """
    Practical surrogate for Eq. (7):
        grad Phi_k = grad h_k + grad g_k * grad f(y)

    Using chi-square DRO-style outer function:
        f(y) = y^2 / (2 * lamda)
        grad f(y) = y / lamda

    So a scalar surrogate whose autograd gives the desired chain-rule form is:
        h_k(x; xi) + (y_k / lamda) * g_k(x; zeta)

    with, on the same batch,
        g_k = mean cross-entropy
        h_k = -0.5 * mean(per-sample cross-entropy ** 2)

    The model is evaluated exactly once and both terms are derived from the
    same logits. This avoids a second forward pass over the batch, so any
    forward-pass side effects of the model happen once per call.

    Args:
        model: callable mapping inputs ``x`` to class logits.
        x: input batch.
        y: class labels for ``x``.
        y_scalar: current value of the auxiliary scalar y_k.
        lamda: DRO regularisation weight; values below ``1e-12`` are replaced
            by ``1e-12`` to avoid division by zero.

    Returns:
        Scalar tensor ``h_k + (y_scalar / lamda) * g_k`` attached to the
        autograd graph of ``model(x)``.
    """
    logits = model(x)

    per_sample_ce = nn.functional.cross_entropy(logits, y, reduction="none")
    hk = -0.5 * (per_sample_ce ** 2).mean()
    gk = nn.functional.cross_entropy(logits, y, reduction="mean")

    # The coefficient is a constant with respect to the model parameters, so
    # autograd yields grad h_k + (y / lamda) * grad g_k.
    coeff = torch.tensor(
        float(y_scalar) / max(float(lamda), 1e-12),
        dtype=gk.dtype,
        device=gk.device,
    )
    return hk + coeff * gk


def local_step_ds_feddro(
    model,
    trainloader,
    y_scalar,
    eta,
    lamda,
    local_steps=1,
    momentum=0.9,
    weight_decay=5e-4,
    grad_clip=5.0,
    beta_y=0.05,
):
    """
    One client's Algorithm 3 local updates:
      x_k^{t+1} = x_k^t - eta^t grad Phi_k(...)
      y_k^{t+1} = (1 - beta^t) y_k^t + beta^t g_k(x_k^{t+1}; zeta_k^{t+1})

    Each local step draws one batch from ``trainloader`` (restarting the
    loader when it is exhausted), takes one clipped SGD step on the surrogate
    from ``compute_phi_surrogate``, and then refreshes the scalar ``y`` with a
    moving average of ``compute_gk_value`` evaluated on the same batch with
    the updated weights. ``y`` is kept within ``[0, 10]`` throughout.

    A step whose surrogate value is not finite is skipped entirely (neither
    the weights nor ``y`` change); a non-finite ``g`` estimate only skips the
    ``y`` update.

    Args:
        model: module to train in place; it is left in train mode.
        trainloader: iterable yielding ``(inputs, labels)`` batches.
        y_scalar: initial value of the auxiliary scalar y.
        eta: SGD learning rate.
        lamda: DRO regularisation weight passed to the surrogate.
        local_steps: number of local SGD steps.
        momentum: SGD momentum. Nesterov momentum is used when this is
            positive; with ``momentum == 0`` plain SGD is used.
        weight_decay: SGD weight decay.
        grad_clip: maximum gradient norm applied before each step.
        beta_y: moving-average weight for the y update.

    Returns:
        ``(state_dict, y)``: a deep copy of the model's state dict after the
        local steps and the final value of y as a float.
    """
    device = next(model.parameters()).device
    model.train()

    momentum_value = float(momentum)
    opt = torch.optim.SGD(
        model.parameters(),
        lr=float(eta),
        momentum=momentum_value,
        weight_decay=float(weight_decay),
        # torch.optim.SGD rejects nesterov=True unless momentum > 0, so only
        # enable it when it is valid; this keeps momentum=0 usable.
        nesterov=momentum_value > 0.0,
    )

    y_local = float(np.clip(y_scalar, 0.0, 10.0))
    loader_iter = iter(trainloader)

    for _ in range(int(local_steps)):
        try:
            x, y = next(loader_iter)
        except StopIteration:
            # Loader exhausted: start another pass over the data.
            loader_iter = iter(trainloader)
            x, y = next(loader_iter)

        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)

        phi = compute_phi_surrogate(
            model=model,
            x=x,
            y=y,
            y_scalar=y_local,
            lamda=float(lamda),
        )

        # Do not let a NaN/inf loss reach the weights or the optimizer state.
        if not torch.isfinite(phi):
            continue

        phi.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float(grad_clip))
        opt.step()

        # Refresh y with the inner-function estimate at the updated weights.
        with torch.no_grad():
            g_new = compute_gk_value(model, x, y)
            if torch.isfinite(g_new):
                y_local = (1.0 - float(beta_y)) * y_local + float(beta_y) * float(g_new.item())
                y_local = float(np.clip(y_local, 0.0, 10.0))

    return copy.deepcopy(model.state_dict()), float(y_local)