from dataclasses import dataclass
from typing import List, Dict, Any, Optional
import torch, numpy as np, random
from torch import nn


@dataclass
class RunCfg:
    """Settings for one run: problem construction, optimisation and logging.

    In this module, ``build_quadratic_problem`` reads ``device``, ``seed``,
    ``d``, ``num_users``, ``gamma_low``/``gamma_high``, ``bias_nodes`` and
    ``mu_norm``, and ``build_model`` reads ``device`` and ``d``. The remaining
    fields are presumably consumed by the training/evaluation code elsewhere
    (TODO confirm).
    """

    device: str = "cuda"  # torch device string, e.g. "cuda" or "cpu"
    seed: int = 42  # seeds torch, NumPy, `random` and the problem-generation RNG
    d: int = 20  # dimension of each user's variable vector
    num_users: int = 16  # number of users, i.e. number of per-user problems
    lr: float = 0.1  # presumably the optimiser step size (TODO confirm)
    iterations: int = 3000  # presumably the number of optimisation steps (TODO confirm)
    eval_every: int = 20  # presumably the evaluation period in iterations (TODO confirm)
    rho: float = 0.0  # not used in this module; meaning TODO confirm
    topology: Optional[str] = None  # presumably a communication-graph name; None = unset
    seed_noise: int = 42  # presumably seeds the noise RNG (TODO confirm)
    split_ratio: Optional[List[float]] = None  # not used in this module; TODO confirm

    # Per-user curvature: gamma_u is drawn uniformly from [gamma_low, gamma_high).
    gamma_low: float = 0.5
    gamma_high: float = 2.5

    # Users whose linear term receives a random bias; None means [0].
    bias_nodes: Optional[List[int]] = None
    mu_norm: float = 5.0  # Euclidean norm of each such bias vector

    noise_sigma: float = 0.0  # presumably the std of injected gradient noise (TODO confirm)
    use_wandb: bool = False  # presumably toggles Weights & Biases logging
    entity: str = ""  # presumably the W&B entity
    project: str = ""  # presumably the W&B project

def build_quadratic_problem(cfg: "RunCfg") -> Dict[str, Any]:
    """Build a seeded set of per-user quadratic-problem data.

    For each user ``u`` in ``range(cfg.num_users)``:

    * ``Q_u = gammas[u] * I_d``, with ``gammas[u]`` drawn uniformly from
      ``[cfg.gamma_low, cfg.gamma_high)``;
    * ``c_u = c_base + mu_u``, where ``c_base`` is shared by all users and
      ``mu_u`` is a random direction of norm ``cfg.mu_norm`` for users listed
      in ``cfg.bias_nodes`` (``[0]`` when it is ``None``) and zero otherwise.

    How ``Q_u``/``c_u`` enter the objective is defined by the caller (TODO
    confirm the exact form, e.g. ``0.5 x^T Q x - c^T x``).

    All draws come from a dedicated CPU ``torch.Generator`` seeded with
    ``cfg.seed`` and are then moved to the target device. The sampled values
    therefore do not depend on the device. As a side effect, the global torch,
    NumPy and ``random`` seeds are reset to ``cfg.seed``.

    The target device is ``cfg.device`` when CUDA is available and CPU
    otherwise. This is the same rule ``build_model`` uses.

    Args:
        cfg: Run configuration (reads ``seed``, ``num_users``, ``d``,
            ``device``, ``gamma_low``, ``gamma_high``, ``bias_nodes``,
            ``mu_norm``).

    Returns:
        A dict with
        ``"Q_list"``: ``num_users`` tensors of shape ``(d, d)``;
        ``"c_list"``: ``num_users`` tensors of shape ``(d,)``;
        ``"x_init_all"``: a standard-normal tensor of shape ``(num_users, d)``
        (presumably one initial iterate per user);
        ``"gammas"``: shape ``(num_users,)``;
        ``"c_base"``: shape ``(d,)``.
    """
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)
    # Dedicated CPU generator: every draw below is sampled on CPU and only then
    # moved, so the generated problem is identical on any target device.
    g = torch.Generator().manual_seed(cfg.seed)

    n, d = cfg.num_users, cfg.d
    # Same placement rule as build_model. This keeps the problem tensors on the
    # model's device and stops the default device="cuda" from crashing on
    # CPU-only machines.
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")

    # NOTE: the order of the random draws (c_base, gammas, one direction per
    # bias node, x_init_all) defines the problem for a given seed -- keep it.
    c_base = torch.randn(d, generator=g).to(device)
    u01 = torch.rand(n, generator=g).to(device)
    gammas = cfg.gamma_low + (cfg.gamma_high - cfg.gamma_low) * u01

    bias_nodes = cfg.bias_nodes if cfg.bias_nodes is not None else [0]
    # Entries outside range(n) still consume a draw but bias no user. A
    # repeated entry draws again, and its last draw wins.
    mu_bank: Dict[int, torch.Tensor] = {}
    for u in bias_nodes:
        dir_u = torch.randn(d, generator=g).to(device)
        # The epsilon avoids division by zero for a degenerate all-zero draw.
        mu_bank[u] = cfg.mu_norm * (dir_u / (dir_u.norm() + 1e-12))

    Q_list: List[torch.Tensor] = []
    c_list: List[torch.Tensor] = []
    zero = torch.zeros(d, device=device)
    for u in range(n):
        Q_list.append(torch.eye(d, device=device) * gammas[u])
        # Adding (rather than aliasing c_base) gives each user its own tensor.
        c_list.append(c_base + mu_bank.get(u, zero))

    x_init_all = torch.randn(n, d, generator=g).to(device)

    return {
        "Q_list": Q_list,
        "c_list": c_list,
        "x_init_all": x_init_all,
        "gammas": gammas,
        "c_base": c_base,
    }


class LSParam(nn.Module):
    """A single learnable vector ``x`` applied linearly: ``forward(A) = A @ x``.

    Args:
        d: Length of ``x`` when ``x_init`` is not given (``x`` then starts at
            zeros). Ignored when ``x_init`` is supplied.
        x_init: Optional initial value. It is cloned, so the parameter never
            shares storage with the caller's tensor.
        device: Device on which the parameter is created.
    """

    def __init__(self, d: int, x_init: Optional[torch.Tensor] = None, device: str = "cpu"):
        super().__init__()
        start = torch.zeros(d) if x_init is None else x_init
        # Clone so that updating the parameter cannot mutate the caller's tensor.
        self.x = nn.Parameter(start.clone().to(device))

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        """Return the matrix-vector product ``A @ self.x``."""
        return A @ self.x


def build_model(cfg, x_init: torch.Tensor):
    """Create an ``LSParam`` of dimension ``cfg.d`` initialised from ``x_init``.

    The model is placed on ``cfg.device`` only when CUDA is available and on
    CPU otherwise. Without CUDA, any requested device, including non-CUDA
    ones such as ``"mps"``, is therefore replaced by CPU.
    """
    target = "cpu"
    if torch.cuda.is_available():
        target = cfg.device
    dev = torch.device(target)
    return LSParam(cfg.d, x_init=x_init, device=str(dev)).to(dev)