from typing import Any
import signal
from utils.function import *

def _persistent_iters(train_loaders):
    return [iter(ld) for ld in train_loaders]

def lr_at(it, T, base=0.06, switch_iter=2360, final_lr=0.01):
    """Return the step-schedule learning rate for iteration ``it``.

    Args:
        it: Current iteration index.
        T: Unused by the schedule; kept so existing call sites keep working.
        base: Learning rate returned while ``it < switch_iter``.
        switch_iter: First iteration at which ``final_lr`` is returned.
        final_lr: Learning rate returned once ``it >= switch_iter``.

    Returns:
        ``base`` before the switch point, ``final_lr`` from it onwards.
    """
    return base if it < switch_iter else final_lr


def _next_batch(iters, train_loaders, u, device):
    """Fetch the next ``(inputs, targets)`` batch of user ``u`` on ``device``.

    When the user's iterator is exhausted it is replaced in ``iters`` by a
    fresh iterator over ``train_loaders[u]`` and the first batch of that new
    pass is returned.
    """
    try:
        inputs, targets = next(iters[u])
    except StopIteration:
        iters[u] = iter(train_loaders[u])
        inputs, targets = next(iters[u])
    return (inputs.to(device, non_blocking=True),
            targets.to(device, non_blocking=True))


def train_core(models: List[torch.nn.Module],
               loaders: Dict[str, Any],
               W: List[torch.Tensor],
               cfg,
               approach: str = "I") -> Dict[str, Any]:
    """Run the decentralized training loop over ``cfg.num_users`` local models.

    Each iteration performs, per named parameter of ``models[0]``:

    1. ``x <- W @ (x - cfg.lr * y)``: every model takes a step along its
       tracker ``y`` and the results are mixed across users with ``W``.
    2. A fresh mini-batch gradient ``g_new`` is computed on every user.
    3. ``y <- W @ y + (g_new - g_prev)``: the trackers are mixed and
       corrected by the change in local gradient.

    Args:
        models: One model per user; all are assumed to expose the same
            parameter names as ``models[0]``.
        loaders: Mapping with keys ``"train_loaders"`` (one loader per user)
            and ``"test_loader"``.
        W: Mixing matrix, converted to a float32 tensor of shape
            ``(num_users, num_users)`` (shape implied by the einsum pattern).
        cfg: Configuration object. Reads ``device``, ``num_users``,
            ``split_ratio``, ``lr``, ``weight_decay``, ``iterations``,
            ``eval_every`` and, optionally, ``use_wandb`` (with ``entity``,
            ``project``, ``run_name``, ``topology``) and ``smooth_k``.
        approach: ``"I"`` scales each user's gradient by its normalized split
            ratio and averages users uniformly; any other value uses unit
            gradient scaling and averages users by normalized split ratio.

    Returns:
        ``{"loss": [...], "acc": [...]}`` with one entry for the initial
        state plus one entry per evaluation point.
    """
    from collections import deque

    print(f"Training by Approach {approach}")
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    num_users = cfg.num_users
    train_loaders, test_loader = loaders["train_loaders"], loaders["test_loader"]
    iters = _persistent_iters(train_loaders)
    criterion = nn.CrossEntropyLoss()

    total = sum(cfg.split_ratio)
    ratios_norm = [float(r) * num_users / total for r in cfg.split_ratio]
    print(f"Ratios norm: {ratios_norm}")
    if approach == "I":
        ratios = ratios_norm
        stead_vec = [1 / num_users] * num_users
    else:
        ratios = [1.0] * num_users
        stead_vec = [x / num_users for x in ratios_norm]

    run = None
    if getattr(cfg, "use_wandb", False):
        run = wandb.init(entity=cfg.entity, project=cfg.project,
                         name=cfg.run_name, reinit=True,
                         config={"lr": cfg.lr, "num_users": cfg.num_users,
                                 "topology": cfg.topology,
                                 "eval_every": cfg.eval_every})

    def evaluate():
        # Sum of per-user results of `test`, each called with the weight
        # ratios[u] * stead_vec[u] (presumably `test` applies that weight to
        # the accuracy it returns -- confirm against utils.function).
        acc = 0.0
        for u in range(num_users):
            acc += test(models[u], test_loader, ratios[u] * stead_vec[u])
        return acc

    history = {"loss": [], "acc": []}
    W_t = torch.as_tensor(W, dtype=torch.float32, device=device)
    acc_0 = evaluate()

    # Initialise every tracker y_u and the previous gradient with the first
    # local mini-batch gradient.
    y_list, g_prev = [], []
    loss_mean_0 = 0.0
    for u in range(num_users):
        x, y = _next_batch(iters, train_loaders, u, device)
        g0, loss_0 = compute_grad_dict(models[u], x, y, criterion,
                                       ratios[u], cfg.weight_decay)
        y_list.append({k: v.clone() for k, v in g0.items()})
        g_prev.append({k: v.clone() for k, v in g0.items()})
        loss_mean_0 += loss_0 * stead_vec[u]
    if run is not None:
        wandb.log({"train_loss": loss_mean_0, "interval_loss": loss_mean_0,
                   "test_acc": acc_0}, step=0)

    dt = next(models[0].parameters()).dtype

    # The smoothing window must outlive a single iteration, otherwise it can
    # never hold more than the most recent accuracy.
    acc_win = deque(maxlen=getattr(cfg, "smooth_k", 5))
    running_loss_sum = 0.0
    running_count = 0  # iterations accumulated in running_loss_sum

    for it in range(cfg.iterations):
        # check_bn_divergence(models)
        with torch.no_grad():
            # state_dict() tensors alias the parameters, so one snapshot per
            # iteration serves both the reads and the in-place writes.
            states = [m.state_dict() for m in models]
            for name, _ in models[0].named_parameters():
                theta = torch.stack(
                    [sd[name].data.to(device, dtype=dt) for sd in states], dim=0)
                y = torch.stack(
                    [y_list[u][name].to(device, dtype=dt) for u in range(num_users)],
                    dim=0)
                x_new = theta - cfg.lr * y
                x_consensus = torch.einsum('uv,v...->u...', W_t, x_new)
                for u in range(num_users):
                    states[u][name].copy_(x_consensus[u])

        loss_mean_iter = 0.0
        g_new = []
        for u in range(num_users):
            images, labels = _next_batch(iters, train_loaders, u, device)
            grads_u, loss_u = compute_grad_dict(models[u], images, labels, criterion,
                                                ratios[u], cfg.weight_decay)
            g_new.append(grads_u)
            loss_mean_iter += loss_u * stead_vec[u]

        running_loss_sum += loss_mean_iter
        running_count += 1

        with torch.no_grad():
            for name, _ in models[0].named_parameters():
                y_matrix = torch.stack(
                    [y_list[v][name].to(device, dtype=dt) for v in range(num_users)])
                mixed = torch.einsum('uv,v...->u...', W_t, y_matrix)
                for u in range(num_users):
                    y_list[u][name] = mixed[u] + (g_new[u][name] - g_prev[u][name])

        g_prev = [{k: v.detach().clone() for k, v in grads.items()} for grads in g_new]

        if it == 0:
            history["loss"].append(loss_mean_0)
            history["acc"].append(acc_0)
            print(f"[Iter {it}] Interval Average loss: {loss_mean_0:.6f}")
        elif (it + 1) % cfg.eval_every == 0 or it == cfg.iterations - 1:
            acc_win.append(evaluate())
            acc_mean = sum(acc_win) / len(acc_win)
            # Divide by the number of iterations actually accumulated: the
            # last interval may be shorter than cfg.eval_every.
            interval_avg_loss = running_loss_sum / running_count
            print(f"[Iter {it}] Acc_mean: {acc_mean:.2f}%")
            print(f"[Iter {it}] Interval Average loss: {interval_avg_loss:.6f}")

            history["loss"].append(interval_avg_loss)
            history["acc"].append(acc_mean)

            if run is not None:
                wandb.log({
                    "test_acc": acc_mean, "interval_loss": interval_avg_loss}, step=it)
            running_loss_sum = 0.0
            running_count = 0
        if run is not None:
            wandb.log({"train_loss": loss_mean_iter}, step=it+1)

    # Only a started run can be finished. The alarm bounds how long
    # run.finish() may block (SIGALRM is available on Unix only); `handler`
    # is presumably what raises TimeoutException -- confirm in utils.function.
    if run is not None:
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(10)
        try:
            run.finish()
        except TimeoutException:
            print("wandb.finish error")
        finally:
            signal.alarm(0)
    return history

def train_I(models, loaders, W, cfg, approach="I"):
    """Set an ``I_``-prefixed ``cfg.run_name`` and delegate to ``train_core``.

    The run name encodes ``cfg.lr``, ``cfg.batch_size_local``,
    ``cfg.topology`` and ``cfg.seed``. ``cfg`` is mutated in place. Returns
    whatever ``train_core`` returns.
    """
    lr, batch_size = cfg.lr, cfg.batch_size_local
    topology, seed = cfg.topology, cfg.seed
    cfg.run_name = f"I_lr{lr}_bs{batch_size}_{topology}_{seed}"
    return train_core(models, loaders, W, cfg, approach)

def train_II(models, loaders, W, cfg, approach="II"):
    """Set an ``II_``-prefixed ``cfg.run_name`` and delegate to ``train_core``.

    The run name encodes ``cfg.lr``, ``cfg.batch_size_local``,
    ``cfg.topology`` and ``cfg.seed``. ``cfg`` is mutated in place. Returns
    whatever ``train_core`` returns.
    """
    lr, batch_size = cfg.lr, cfg.batch_size_local
    topology, seed = cfg.topology, cfg.seed
    cfg.run_name = f"II_lr{lr}_bs{batch_size}_{topology}_{seed}"
    return train_core(models, loaders, W, cfg, approach)