import signal

import torch, numpy as np, gc
from typing import List, Dict, Any
import torch.nn as nn
import wandb

from least_square.train_ls_grad import compute_grad_dict_quadratic, weighted_optimum_closed_form
from utils.function import handler, TimeoutException


def zero_like_param_dict(model: nn.Module):
    """Return a dict mapping each parameter name of ``model`` to a zero tensor.

    Each zero tensor has the same shape, dtype and device as its parameter.
    The tensors are fresh and do not require grad.
    """
    zeros = {}
    for param_name, param in model.named_parameters():
        # zeros_like keeps the input's device/dtype when none is given.
        zeros[param_name] = torch.zeros_like(param)
    return zeros

@torch.no_grad()
def get_mean_model(models: List[nn.Module], stead_vec: List[float]) -> nn.Module:
    """Build a new model whose ``x`` is the ``stead_vec``-weighted sum of the models' ``x``.

    Args:
        models: non-empty list of models of the same type, each exposing ``.x``.
        stead_vec: one weight per model. In this file the caller passes weights
            that sum to 1, so the result is a weighted mean.

    Returns:
        ``type(models[0])(d, x_init=x_mean, device=str(device))`` moved to
        ``device``. Here ``d = models[0].x.numel()``, and ``device``/dtype are
        those of ``models[0]``'s first parameter.
    """
    ref_param = next(models[0].parameters())
    device, dtype = ref_param.device, ref_param.dtype
    d = models[0].x.numel()
    stack = torch.stack([m.x.detach().to(device=device, dtype=dtype) for m in models], dim=0)
    # Shape the weights as (n, 1, ..., 1) so they broadcast along the model axis
    # for any rank of x. A fixed (n, 1) view is only correct for 1-D x.
    w = torch.tensor(stead_vec, dtype=dtype, device=device).view(-1, *([1] * (stack.dim() - 1)))
    x_mean = (w * stack).sum(dim=0)
    return type(models[0])(d, x_init=x_mean, device=str(device)).to(device)


@torch.no_grad()
def _agg_grad_from_models(models, stead_vec, Q_list, c_list, ratios, rho):
    """Return sum_i stead_vec[i] * ratios[i] * g_i, evaluated at each node's own ``x``.

    ``g_i = Q_list[i] @ (x_i - c_list[i])``, plus ``rho * x_i`` when ``rho > 0``,
    where ``x_i = models[i].x``.
    """
    g_bar = torch.zeros_like(models[0].x.detach())
    for i in range(len(models)):
        x_i = models[i].x.detach()
        g_i = Q_list[i] @ (x_i - c_list[i])
        if rho > 0.0:
            g_i = g_i + rho * x_i
        g_bar = g_bar + float(stead_vec[i]) * g_i * ratios[i]
    return g_bar


def _record_eval(history, run, it, loss, dist, g_norm):
    """Append one evaluation point to ``history`` and, if ``run`` is set, log it at ``step=it``."""
    history["iters"].append(it)
    history["loss"].append(loss)
    # Key name kept for callers; the value is the distance to the optimum.
    history["grad_norm"].append(dist)
    history["global_grad_norm"].append(g_norm)
    if run is not None:
        run.log({
            "loss/avg": loss,
            "dist_to_opt": dist,
            "global_grad_norm": g_norm,
        }, step=it)


def _finish_wandb_run(run, timeout_s: int = 5) -> None:
    """Call ``run.finish()`` under a best-effort SIGALRM timeout; no-op when ``run`` is None.

    The alarm is armed only where SIGALRM exists and ``signal.signal`` succeeds.
    It raises ValueError outside the main thread; in that case ``finish`` runs
    without a timeout. The previous SIGALRM handler is restored afterwards.
    """
    if run is None:
        return
    previous = None
    armed = False
    if hasattr(signal, "SIGALRM"):
        try:
            previous = signal.signal(signal.SIGALRM, handler)
            armed = True
        except ValueError:
            # Not in the main thread: fall back to finishing without a timeout.
            armed = False
    try:
        if armed:
            signal.alarm(timeout_s)
        run.finish()
    except TimeoutException:
        # presumably raised by `handler` when the alarm fires — verify in utils.function
        print("wandb.finish error")
    finally:
        if armed:
            signal.alarm(0)
            signal.signal(signal.SIGALRM, previous if previous is not None else signal.SIG_DFL)


def train_core_quadratic(models: List[nn.Module],
                         problem: Dict[str, Any],
                         W: List[List[float]],
                         cfg,
                         approach: str = "I") -> Dict[str, Any]:
    """Run decentralized gradient tracking over per-node quadratic objectives.

    Each iteration applies, per parameter tensor:
      x_u <- sum_v W[u, v] * (x_v - lr * y_v)             (local step + mixing)
      y_u <- sum_v W[u, v] * y_v + g_u(new) - g_u(prev)   (gradient tracking)
    where ``g_u`` comes from ``compute_grad_dict_quadratic`` on node ``u``.

    Args:
        models: one model per node; each exposes its iterate as ``.x``.
        problem: dict with per-node ``"Q_list"`` and ``"c_list"``.
        W: n x n mixing matrix; row ``u`` holds node ``u``'s mixing weights.
        cfg: config object. Reads ``device``, ``num_users``, ``split_ratio``,
            ``lr``, ``rho``, ``noise_sigma``, ``seed_noise``, ``iterations`` and
            ``eval_every``. Optionally reads ``use_wandb``; when it is true, also
            reads ``entity``, ``project``, ``run_name`` and ``topology``.
        approach: ``"I"`` weights each node's gradient by ``cfg.split_ratio``
            and averages nodes uniformly. Any other value uses unit gradient
            weights and averages nodes by the normalized ``cfg.split_ratio``.

    Returns:
        Dict of lists ``"iters"``, ``"loss"``, ``"grad_norm"`` and
        ``"global_grad_norm"``. ``"grad_norm"`` actually holds the distance of
        the weighted-mean iterate to ``weighted_optimum_closed_form``'s result.
        The iteration-0 entry records the state before any update. Later
        entries record the mean loss over the iterations since the previous
        entry.
    """
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    n = cfg.num_users
    Q_list, c_list = problem["Q_list"], problem["c_list"]

    if approach == "I":
        ratios = cfg.split_ratio
        stead_vec = [1.0 / n] * n
    else:
        ratios = [1.0] * n
        stead_vec = np.array(cfg.split_ratio, dtype=float)
        stead_vec = (stead_vec / stead_vec.sum()).tolist()

    run = None
    if getattr(cfg, "use_wandb", False):
        run = wandb.init(entity=cfg.entity, project=cfg.project,
                         name=cfg.run_name, reinit=True,
                         config={"lr": cfg.lr, "num_users": cfg.num_users,
                                 "topology": cfg.topology,
                                 "eval_every": cfg.eval_every})

    dt = next(models[0].parameters()).dtype
    # Build W in the iterates' dtype: torch.einsum does not type-promote its
    # operands, so a float32 W against float64 iterates would fail.
    W_t = torch.as_tensor(W, dtype=dt, device=device)

    history = {"iters": [], "loss": [], "grad_norm": [], "global_grad_norm": []}

    x_star_alpha = weighted_optimum_closed_form(Q_list, c_list, cfg.split_ratio, rho=cfg.rho)

    # Initialise trackers with the first local gradients: y_u = g_u(x_u^0).
    y_list, g_prev = [], []
    loss_mean_0 = 0.0
    for u in range(n):
        g0, lu = compute_grad_dict_quadratic(models[u], Q_list[u], c_list[u],
                                             rho=cfg.rho, weight=ratios[u],
                                             noise_sigma=cfg.noise_sigma, node=u,
                                             seed=cfg.seed_noise, it=0)
        y_list.append({k: v.clone() for k, v in g0.items()})
        g_prev.append({k: v.clone() for k, v in g0.items()})
        loss_mean_0 += lu * stead_vec[u]

    mean_model_0 = get_mean_model(models, stead_vec)
    dist_0 = float(torch.linalg.vector_norm(mean_model_0.x.detach() - x_star_alpha).item())
    g_global_0 = _agg_grad_from_models(models, stead_vec, Q_list, c_list, ratios, cfg.rho)
    g_global_norm_0 = float(torch.linalg.vector_norm(g_global_0).item())

    running_loss_sum = 0.0
    running_count = 0  # iterations accumulated into running_loss_sum since the last log

    for it in range(cfg.iterations):
        # Local step along the tracked direction, then consensus mixing with W.
        with torch.no_grad():
            param_maps = [dict(m.named_parameters()) for m in models]
            for name in param_maps[0]:
                theta = torch.stack([pm[name].detach().to(device, dtype=dt) for pm in param_maps], dim=0)
                y = torch.stack([y_list[u][name].to(device, dtype=dt) for u in range(n)], dim=0)
                x_cons = torch.einsum('uv,v...->u...', W_t, theta - cfg.lr * y)
                for u in range(n):
                    param_maps[u][name].copy_(x_cons[u])

        loss_mean_iter = 0.0
        g_new = [None] * n
        for u in range(n):
            g_new[u], lu = compute_grad_dict_quadratic(models[u], Q_list[u], c_list[u],
                                                       rho=cfg.rho, weight=ratios[u],
                                                       noise_sigma=cfg.noise_sigma, node=u,
                                                       seed=cfg.seed_noise, it=it + 1)
            loss_mean_iter += lu * stead_vec[u]
        running_loss_sum += loss_mean_iter
        running_count += 1

        # Gradient tracking: mix the trackers, then add each node's gradient change.
        with torch.no_grad():
            for name, _ in models[0].named_parameters():
                Y = torch.stack([y_list[v][name].to(device, dtype=dt) for v in range(n)], dim=0)
                mixed = torch.einsum('uv,v...->u...', W_t, Y)
                for u in range(n):
                    y_list[u][name] = mixed[u] + (g_new[u][name] - g_prev[u][name])
        g_prev = [{k: v.detach().clone() for k, v in g_new[u].items()} for u in range(n)]

        if it == 0:
            print(f"[Iter 0] dist_to_opt: {dist_0:.6f} | global_grad_norm: {g_global_norm_0:.6f}")
            _record_eval(history, run, it, loss_mean_0, dist_0, g_global_norm_0)
        elif (it + 1) % cfg.eval_every == 0:
            mean_model = get_mean_model(models, stead_vec)
            dist = float(torch.linalg.vector_norm(mean_model.x.detach() - x_star_alpha).item())
            # Divide by the iterations actually accumulated. Iteration 0 is never
            # reset, so with eval_every == 1 the first interval spans two iterations.
            interval_avg_loss = running_loss_sum / running_count
            g_global = _agg_grad_from_models(models, stead_vec, Q_list, c_list, ratios, cfg.rho)
            g_global_norm = float(torch.linalg.vector_norm(g_global).item())
            print(f"[Iter {it}] loss_avg: {interval_avg_loss:.6f} | dist:{dist:.6f} | "
                  f"||nabla F||:{g_global_norm:.6f}")
            _record_eval(history, run, it, interval_avg_loss, dist, g_global_norm)
            running_loss_sum = 0.0
            running_count = 0

    _finish_wandb_run(run)
    return history
