# import numpy as np
# import scipy.linalg as la          # eigen-solvers

# # ------------------------------------------------------------------
# # 1.  Helper losses for an n×n similarity matrix A
# # ------------------------------------------------------------------
# def entropy_loss(A: np.ndarray) -> float:
#     """Row-softmax entropy   L_ent = −(1/n) Σ_i Σ_j Â_ij log Â_ij."""
#     expA   = np.exp(A - A.max(axis=1, keepdims=True))           # stability
#     soft_A = expA / expA.sum(axis=1, keepdims=True)
#     ent    = -np.sum(soft_A * np.log(np.clip(soft_A, 1e-12, None)))
#     return ent / A.shape[0]

# def algebraic_connectivity_loss(A: np.ndarray) -> float:
#     """
#     λ₂  (Fiedler value) of the *unnormalised* Laplacian  L = D − A.
#     • Smaller  ⇒ graph can split easily  ⇒ better community separability.
#     We treat λ₂ itself as the loss to minimise.
#     """
#     d = A.sum(axis=1)
#     L = np.diag(d) - A
#     # two smallest eigen-values (ascending).  NOTE: the `eigvals=` keyword
#     # was deprecated in SciPy 1.5 and removed in 1.9 — use `subset_by_index`.
#     lam = la.eigvalsh(L, subset_by_index=[0, 1])
#     lam2 = lam[1]                    # eigvalsh already returns real values
#     return lam2                      # minimise for stronger separation

# # ------------------------------------------------------------------
# # 2.  Grid-search over w ∈ [0,1]  for fusion  A_w = w·G + (1−w)·P
# #     evaluates *both* entropy   and   algebraic-connectivity losses
# # ------------------------------------------------------------------
# def grid_search_fusion(G: np.ndarray,
#                        P: np.ndarray,
#                        w_grid=np.linspace(0.0, 1.0, 51)):  # default built once at def-time; read-only here, so safe
#     """
#     Returns list of dicts:
#         [{w, L_entropy, L_lambda2}, …]
#     """
#     results = []
#     for w in w_grid:
#         A   = w * G + (1 - w) * P
#         L_e = entropy_loss(A)
#         L_l2 = algebraic_connectivity_loss(A)
#         print(f"w={w:.2f}  Ent={L_e:.4f}  λ2={L_l2:.4f}")
#         results.append(dict(w=w, L_entropy=L_e, L_lambda2=L_l2))
#     return results

# # ------------------------------------------------------------------
# # 3.  Example call  (assumes  G  and  v  are defined & symmetric)
# # ------------------------------------------------------------------
# search = grid_search_fusion(G, v)          # 51-point grid, 0.00 → 1.00

# best_ent = min(search, key=lambda d: d["L_entropy"])    # minimise entropy
# best_l2  = min(search, key=lambda d: d["L_lambda2"])    # minimise λ₂

# print("\n=== Best by entropy ===")
# print(f"w = {best_ent['w']:.3f} ;  L_entropy = {best_ent['L_entropy']:.4f}")

# print("\n=== Best by algebraic connectivity (λ₂) ===")
# print(f"w = {best_l2['w']:.3f} ;  λ₂ = {best_l2['L_lambda2']:.4f}")

# # ======================================================================
# # Global-scalar gate fusion  (laplacian term removed, everything else unchanged)
# # ======================================================================

# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# # ----------------------------------------------------------------------
# #  MLP that outputs ONE sigmoid-gated scalar w ∈ (0,1)
# # ----------------------------------------------------------------------
# class GlobalGateMLP(nn.Module):
#     def __init__(self, n: int, hidden: int = 128):
#         super().__init__()
#         d_in = 2 * n * n                # vec(G) ‖ vec(P)
#         self.net = nn.Sequential(
#             nn.Linear(d_in, hidden),
#             nn.ReLU(inplace=True),
#             nn.Linear(hidden, 1),
#             nn.Sigmoid()                # w ∈ (0,1)
#         )

#     def forward(self, G: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
#         flat = torch.cat([G.flatten(), P.flatten()], dim=0)
#         return self.net(flat)           # shape (1,)

# # ----------------------------------------------------------------------
# #  Fusion  A = w·G + (1−w)·P
# # ----------------------------------------------------------------------
# def fuse_similarity(G: torch.Tensor,
#                     P: torch.Tensor,
#                     gate_net: GlobalGateMLP):
#     w = gate_net(G, P)                 # scalar (1,)
#     A = w * G + (1 - w) * P
#     return A, w

# # ----------------------------------------------------------------------
#  Row-softmax entropy loss   (laplacian term removed)
#  NOTE: this redefines the NumPy `entropy_loss` above — the torch version
#  shadows it if both sections are uncommented into the same module.
# # ----------------------------------------------------------------------
# def entropy_loss(A: torch.Tensor) -> torch.Tensor:
#     n = A.size(0)
#     soft = F.softmax(A, dim=1)
#     logt = torch.clamp(soft, 1e-12, 1).log()
#     return -(soft * logt).sum() / n    # minimise → sharper rows

# # ----------------------------------------------------------------------
# #  Convenience wrapper (only entropy now)
# # ----------------------------------------------------------------------
# def total_loss(A: torch.Tensor) -> torch.Tensor:
#     return entropy_loss(A)

# # ----------------------------------------------------------------------
# #  Example usage with provided 100×100  G  and  v  (left untouched)
# # ----------------------------------------------------------------------
# device = "cuda" if torch.cuda.is_available() else "cpu"

# def to_tensor(mat):
#     if torch.is_tensor(mat):
#         return mat.clone().to(device)            # no in-place changes
#     return torch.as_tensor(mat, dtype=torch.float32, device=device)

# G_t = to_tensor(G)            # gradient similarity  (100×100)
# V_t = to_tensor(v)            # data     similarity  (100×100)

# n = G_t.size(0)
# gate_net = GlobalGateMLP(n).to(device)
# optimizer = torch.optim.Adam(gate_net.parameters(), lr=1e-3)

# for step in range(200):
#     optimizer.zero_grad()
#     A, w = fuse_similarity(G_t, V_t, gate_net)
#     if step % 20 == 0:                     # periodic progress, not 200 raw-tensor dumps
#         print(f"step {step:3d}  w = {w.item():.4f}")
#     loss = total_loss(A)                      # entropy only

#     loss.backward()
#     optimizer.step()

# print(f"Learned scalar gate w ≈ {w.item():.4f}")
# ======================================================================
# Row-gate fusion (one weight per client) – entropy-only objective
# ======================================================================

