# quad_grad.py
from typing import Dict, List, Tuple

import torch
import torch.nn as nn

@torch.no_grad()
def compute_grad_dict_quadratic(model: nn.Module,
                                Q_i: torch.Tensor,  # typically gamma_i * I; any symmetric (d, d) matrix works
                                c_i: torch.Tensor,
                                rho: float = 0.0,
                                weight: float = 1.0,
                                noise_sigma: float = 0.0, node: int = 0, seed: int = 0, it: int = 0,
                                ) -> Tuple[Dict[str, torch.Tensor], float]:
    """Return the (optionally noisy) gradient dict and loss of one node's quadratic.

    The per-node objective is

        f_i(x) = weight * [0.5 * (x - c_i)^T Q_i (x - c_i) + 0.5 * rho * x^T x]

    and the returned gradient is ``weight * (Q_i (x - c_i) + rho * x + noise_sigma * eps)``.
    Here ``eps`` is standard normal noise drawn from a local generator seeded from
    ``(seed, node, it)``, so the result is reproducible and never touches the global RNG.

    Args:
        model: module exposing the decision vector as ``model.x``. Presumably ``x`` is
            also a registered parameter named ``"x"``; otherwise the gradient lands in
            no dict entry. Verify against the caller.
        Q_i: curvature matrix. The gradient formula assumes it is symmetric.
        c_i: centre of the quadratic. Assumed to be a 1-D vector shaped like ``x``.
        rho: ridge coefficient. It enters both the loss and the gradient for any non-zero value.
        weight: multiplicative scale applied to both the gradient and the loss.
        noise_sigma: std of the additive Gaussian gradient noise (0 disables it).
        node, seed, it: combined into the noise generator's seed.

    Returns:
        ``(grad, loss)``. ``grad`` maps every parameter name of ``model`` to a tensor:
        the gradient for ``"x"`` (cast to that parameter's dtype/device) and zeros
        otherwise. ``loss`` is the noiseless weighted objective as a Python float.
    """
    x = model.x.detach()
    diff = x - c_i
    q_diff = Q_i @ diff  # reused by both the gradient and the loss

    g = q_diff
    if rho != 0.0:
        # Applied for any non-zero rho so the gradient always matches the loss below.
        g = g + rho * x
    if noise_sigma > 0.0:
        # NOTE(review): this seed scheme can collide across (node, it) pairs,
        # e.g. (node=1, it=0) vs (node=0, it=100). It is kept unchanged so existing
        # runs stay reproducible.
        combined_seed = seed + node * 1000 + it * 10
        gen = torch.Generator().manual_seed(combined_seed)
        # Drawn on CPU with a private generator, then moved, for device-independent noise.
        noise = torch.randn(q_diff.shape, dtype=q_diff.dtype, generator=gen).to(q_diff.device)
        g = g + noise_sigma * noise
    if weight != 1.0:
        g = g * float(weight)

    loss = 0.5 * (diff @ q_diff) + 0.5 * rho * (x @ x)
    loss = float(loss.item()) * float(weight)

    grad = {
        name: g.to(device=p.device, dtype=p.dtype) if name == "x" else torch.zeros_like(p)
        for name, p in model.named_parameters()
    }
    return grad, loss

def _is_isotropic(Q: torch.Tensor) -> bool:
    """Return True when ``Q`` is a non-empty square matrix exactly equal to ``Q[0, 0] * I``."""
    if Q.dim() != 2 or Q.shape[0] != Q.shape[1] or Q.shape[0] == 0:
        return False
    eye = torch.eye(Q.shape[0], dtype=Q.dtype, device=Q.device)
    return torch.equal(Q, Q[0, 0] * eye)


@torch.no_grad()
def weighted_optimum_closed_form(Q_list: List[torch.Tensor], c_list: List[torch.Tensor],
                                 weights: List[float], rho: float = 0.0) -> torch.Tensor:
    """Return the exact minimiser of the weighted sum of per-node quadratics.

    With normalised weights ``w_i = weights[i] / sum(weights)`` this solves

        argmin_x  sum_i w_i * [0.5 (x - c_i)^T Q_i (x - c_i) + 0.5 * rho * x^T x]

    i.e. ``(sum_i w_i Q_i + rho I) x = sum_i w_i Q_i c_i``. Because the normalised
    weights sum to 1, the ridge term contributes exactly ``rho`` to the system.
    Rescaling all weights leaves the minimiser unchanged.

    When every ``Q_i`` is exactly ``gamma_i * I``, the scalar formula
    ``sum_i w_i gamma_i c_i / (rho + sum_i w_i gamma_i)`` is used. Otherwise the
    linear system is solved. This fallback assumes each ``Q_i`` is symmetric
    (d, d) and each ``c_i`` holds d entries.

    Raises:
        ZeroDivisionError: if ``sum(weights) == 0``.
        IndexError: if ``Q_list`` is empty.
    """
    total = sum(weights)  # hoisted: previously recomputed for every weight
    weights_nor = [w / total for w in weights]

    if all(_is_isotropic(Q) for Q in Q_list):
        device = Q_list[0].device
        # Match c's shape so the accumulation never broadcasts to a larger tensor.
        num = torch.zeros(c_list[0].shape, device=device)
        den = rho
        for w, Q, c in zip(weights_nor, Q_list, c_list):
            gamma_i = Q[0, 0]
            num = num + float(w) * gamma_i * c
            den = den + float(w) * gamma_i
        return num / den

    # General (non-isotropic) curvature: reading only Q[0, 0] would be wrong here.
    ref = Q_list[0]
    d = ref.shape[0]
    A = rho * torch.eye(d, dtype=ref.dtype, device=ref.device)
    b = torch.zeros(d, dtype=ref.dtype, device=ref.device)
    for w, Q, c in zip(weights_nor, Q_list, c_list):
        A = A + float(w) * Q
        b = b + float(w) * (Q @ c.reshape(-1))
    return torch.linalg.solve(A, b).reshape(c_list[0].shape)