from __future__ import annotations

import torch
from torch import Tensor
from typing import List, Tuple, Optional

from .utils import project_simplex


def compute_gram_matrix(gradients: List[Tensor]) -> Tensor:
    """Build the K x K matrix of pairwise inner products of ``gradients``.

    Entry ``(i, j)`` equals ``torch.sum(gradients[i] * gradients[j])``. Only
    the upper triangle is computed; each value is mirrored into the lower
    triangle, so the result is symmetric.

    Args:
        gradients: Tensors whose pairwise inner products are wanted. They are
            multiplied elementwise, so their shapes must be compatible.

    Returns:
        A ``(K, K)`` tensor in the default dtype, allocated on the device of
        the first gradient (CPU if the first element is not a tensor). An
        empty input yields an empty ``(0, 0)`` tensor rather than raising.
    """
    K = len(gradients)
    if K == 0:
        # Without this guard ``gradients[0]`` below raises IndexError.
        return torch.zeros((0, 0))
    first = gradients[0]
    device = first.device if isinstance(first, Tensor) else "cpu"
    G = torch.zeros((K, K), device=device)
    for i in range(K):
        for j in range(i, K):
            dot = torch.sum(gradients[i] * gradients[j])
            # The inner product is symmetric, so fill both triangles at once.
            G[i, j] = dot
            G[j, i] = dot
    return G


def compute_gram_from_params(
    grad_dicts: List[dict],
    param_names: Optional[List[str]] = None,
) -> Tensor:
    """Build a Gram matrix from per-parameter gradient dictionaries.

    Entry ``(i, j)`` is the sum, over the selected parameter names, of
    ``sum(grad_dicts[i][name] * grad_dicts[j][name])``. A name is skipped for
    a pair when either dictionary lacks it or maps it to ``None``.

    Args:
        grad_dicts: One mapping from parameter name to gradient per entry.
        param_names: Names to include; defaults to the keys of the first
            dictionary.

    Returns:
        A symmetric ``(K, K)`` tensor, placed on the device of the first
        non-``None`` gradient found in ``grad_dicts[0]`` (CPU if none is
        found). An empty ``grad_dicts`` gives an empty ``(0, 0)`` tensor.
    """
    num_entries = len(grad_dicts)
    if num_entries == 0:
        return torch.zeros((0, 0))

    names = list(grad_dicts[0].keys()) if param_names is None else param_names

    # Take the device from the first usable gradient of the first dictionary.
    head = grad_dicts[0]
    device = next(
        (head[n].device for n in names if n in head and head[n] is not None),
        "cpu",
    )

    gram = torch.zeros((num_entries, num_entries), device=device)
    with torch.no_grad():
        for row in range(num_entries):
            row_grads = grad_dicts[row]
            for col in range(row, num_entries):
                col_grads = grad_dicts[col]
                acc = torch.tensor(0.0, device=device)
                for n in names:
                    left = row_grads.get(n)
                    right = col_grads.get(n)
                    if left is not None and right is not None:
                        acc = acc + torch.sum(left * right)
                # Symmetric: write the value to both triangles.
                gram[row, col] = acc
                gram[col, row] = acc
    return gram


def solve_mgda_qp(
    G: Tensor,
    max_iters: int = 100,
    lr: float = 0.1,
    eps: float = 1e-8,
    tol: float = 1e-6,
) -> Tensor:
    """Approximately minimise ``lam @ G @ lam`` with projected gradient steps.

    Starting from uniform weights, each iteration takes a gradient step on
    the quadratic form of ``G + eps * I`` and then applies
    ``project_simplex(..., z=1.0)`` (presumably the projection onto the
    simplex with sum ``z``; confirm against ``.utils``). Iteration stops once
    the objective changes by less than ``tol`` or after ``max_iters`` steps.

    Args:
        G: Square ``(K, K)`` matrix, e.g. a Gram matrix of gradients.
        max_iters: Maximum number of projected gradient steps.
        lr: Fixed step size.
        eps: Diagonal regularisation added to ``G``.
        tol: Absolute objective-change threshold for early stopping.

    Returns:
        A length-``K`` weight vector on ``G``'s device. ``K == 0`` returns an
        empty vector and ``K == 1`` returns ``[1.0]`` without iterating.
    """
    K = G.shape[0]
    # Work in a dtype compatible with G: mixing a float64 G with
    # default-dtype weights would make the matrix-vector product fail.
    # For float32/half/integer G this resolves to the default dtype.
    dtype = torch.promote_types(G.dtype, torch.get_default_dtype())
    if K == 0:
        return torch.zeros(0, device=G.device, dtype=dtype)
    if K == 1:
        return torch.ones(1, device=G.device, dtype=dtype)
    with torch.no_grad():
        lam = torch.ones(K, device=G.device, dtype=dtype) / K
        G_reg = G + eps * torch.eye(K, device=G.device, dtype=dtype)
        prev_obj = float("inf")
        # NOTE(review): the step size is fixed; for a G with large entries
        # lr may need to be reduced for the iteration to settle. Confirm
        # with callers before changing the default.
        for _ in range(int(max_iters)):
            grad = 2 * (G_reg @ lam)
            lam_new = project_simplex(lam - lr * grad, z=1.0)
            obj = (lam_new @ G_reg @ lam_new).item()
            lam = lam_new
            if abs(prev_obj - obj) < tol:
                break
            prev_obj = obj
        return lam


def compute_pareto_interference(
    gradients: List[Tensor],
    lam: Tensor,
    eps: float = 1e-8,
) -> Tensor:
    """Score how strongly each gradient conflicts with the others.

    For every ordered pair ``k != j`` the cosine between ``gradients[k]`` and
    ``gradients[j]`` is computed (norms are padded with ``eps``); only the
    negative part counts as conflict. Entry ``k`` of the result is the sum of
    those conflicts weighted by ``lam[j]``.

    Args:
        gradients: Tensors to compare pairwise.
        lam: Per-gradient weights, indexed like ``gradients``.
        eps: Added to every norm to avoid division by zero.

    Returns:
        A length-``K`` tensor on the device of the first gradient; an empty
        CPU tensor when ``gradients`` is empty.
    """
    num_grads = len(gradients)
    if num_grads == 0:
        return torch.zeros(0)

    device = gradients[0].device
    padded_norms = [torch.norm(g) + eps for g in gradients]
    g_norms = torch.tensor(padded_norms, device=device)

    conf = torch.zeros(num_grads, device=device)
    for k, grad_k in enumerate(gradients):
        for j, grad_j in enumerate(gradients):
            if j == k:
                continue
            cosine = torch.sum(grad_k * grad_j) / (g_norms[k] * g_norms[j])
            # Only opposing directions (negative cosine) count as conflict.
            conf[k] += lam[j] * torch.clamp(-cosine, min=0.0)
    return conf


def compute_pareto_interference_from_gram(
    G: Tensor,
    lam: Tensor,
    eps: float = 1e-8,
) -> Tensor:
    """Compute weighted conflict scores directly from a Gram matrix.

    Norms are taken as ``sqrt(diag(G) + eps)``; cosines are ``G`` divided by
    the outer product of those norms (plus ``eps``), clamped to ``[-1, 1]``.
    The negative part of each cosine is the conflict, and the result is the
    conflict matrix multiplied by ``lam``.

    Args:
        G: Square ``(K, K)`` matrix of pairwise inner products.
        lam: Weights multiplied against the conflict matrix.
        eps: Stabiliser added under the square root and to the denominator.

    Returns:
        ``conflict_matrix @ lam``; an empty tensor on ``G``'s device when
        ``K == 0``.
    """
    if G.shape[0] == 0:
        return torch.zeros(0, device=G.device)

    norms = (torch.diag(G) + eps).sqrt()
    pairwise_norms = norms[:, None] * norms[None, :]
    cosines = (G / (pairwise_norms + eps)).clamp(-1.0, 1.0)
    # Keep only the magnitude of negative cosines; everything else is 0.
    conflicts = (-cosines).clamp(min=0.0)
    return conflicts @ lam


class MGDASolver:
    """Wrapper that solves for weights and conflict scores in one call.

    Holds the solver settings and remembers the most recent Gram matrix and
    weight vector in ``last_G`` / ``last_lam``.
    """

    def __init__(
        self,
        K: int,
        max_iters: int = 100,
        lr: float = 0.1,
        eps: float = 1e-8,
        device: str = "cpu",
    ):
        """Store the solver configuration.

        Args:
            K: Expected number of gradients; stored but not enforced here.
            max_iters: Forwarded to ``solve_mgda_qp``.
            lr: Step size forwarded to ``solve_mgda_qp``.
            eps: Regulariser forwarded to the solver and conflict helpers.
            device: Device the Gram matrix and results are moved to.
        """
        self.K = K
        self.max_iters = max_iters
        self.lr = lr
        self.eps = eps
        self.device = device
        # Results of the most recent solve; None until a solve has run.
        self.last_lam: Optional[Tensor] = None
        self.last_G: Optional[Tensor] = None

    def solve(self, gradients: List[Tensor]) -> Tuple[Tensor, Tensor]:
        """Solve for weights from raw gradients.

        Args:
            gradients: Tensors to build the Gram matrix from.

        Returns:
            ``(lam, conf)``, both on ``self.device``. An empty ``gradients``
            list yields two empty tensors instead of raising.
        """
        if len(gradients) == 0:
            # The Gram helper indexes gradients[0], so handle "no gradients"
            # here and still refresh the cached state.
            self.last_G = torch.zeros((0, 0), device=self.device)
            self.last_lam = torch.zeros(0, device=self.device)
            return self.last_lam, torch.zeros(0, device=self.device)

        G = compute_gram_matrix(gradients).to(self.device)
        self.last_G = G
        lam = solve_mgda_qp(G, max_iters=self.max_iters, lr=self.lr, eps=self.eps)
        self.last_lam = lam
        # The conflict helper works on the gradients' own device, so hand it
        # the weights there and bring the scores back next to ``lam``.
        lam_on_grads = lam.to(gradients[0].device)
        conf = compute_pareto_interference(gradients, lam_on_grads, eps=self.eps)
        return lam, conf.to(self.device)

    def solve_from_gram(self, G: Tensor) -> Tuple[Tensor, Tensor]:
        """Solve for weights from a precomputed Gram matrix.

        Args:
            G: Square matrix of pairwise inner products.

        Returns:
            ``(lam, conf)``, both computed on ``self.device``.
        """
        G = G.to(self.device)
        self.last_G = G
        lam = solve_mgda_qp(G, max_iters=self.max_iters, lr=self.lr, eps=self.eps)
        self.last_lam = lam
        conf = compute_pareto_interference_from_gram(G, lam, eps=self.eps)
        return lam, conf
