import math
import numpy as np
import torch
from typing import List, Tuple, Optional, Sequence
from sklearn.mixture import GaussianMixture
from scipy.optimize import least_squares, minimize

# One observation: (tuple of integer quantities, observed value). The tuple is
# used as one row of the quantity matrix Q and the float as the matching V.
MSPair = Tuple[Tuple[int, ...], float]

def _safe_pow(X, b, eps=1e-12):
    return np.power(np.maximum(X, eps), b)

def _pl_add_resid_scaled(theta, Qn, V):
    """Return residuals ``V - Vhat`` of the additive power-law model.

    Model: ``Vhat[i] = sum_k a_k * Qn[i, k] ** b_k + bias``, with the
    parameters packed as ``theta = [a_1..a_S, b_1..b_S, bias]`` where S is the
    number of columns of the 2-D array ``Qn``.

    Coefficients are clipped to [1e-6, 1e12]. Exponents within 1e-6 of 1.0 are
    pushed out to 1.0 +/- 1e-3 on their own side; an exponent of exactly 1.0
    has sign 0 and is therefore left at 1.0. Bases go through ``_safe_pow``,
    which floors them at a tiny positive value before exponentiation.
    """
    _, n_cols = Qn.shape
    coef = np.clip(theta[:n_cols], 1e-6, 1e12)
    expo = theta[n_cols:2 * n_cols]
    dist_from_one = expo - 1.0
    expo = np.where(np.abs(dist_from_one) < 1e-6,
                    1.0 + np.sign(dist_from_one) * 1e-3,
                    expo)
    offset = theta[-1]
    per_term = coef * _safe_pow(Qn, expo)
    return V - (per_term.sum(axis=1) + offset)

def _fit_powerlaw_additive_nls(Q_boot, V_boot):
    """Fit ``V ~ sum_k a_k * Q_k ** b_k + bias`` by bounded robust least squares.

    Each column of Q is divided by the median of its strictly positive entries
    (1.0 when a column has none) before fitting. The fitted coefficients are
    then mapped back to raw-Q units.

    Parameters
    ----------
    Q_boot : array-like, 2-D, shape (N, S)
        Quantities, one column per component.
    V_boot : array-like
        Observed values; flattened to length N.

    Returns
    -------
    a_raw : np.ndarray, shape (S,)
        Coefficients in raw-Q units, ``a_norm / Q_scale ** b``.
    b : np.ndarray, shape (S,)
        Exponents, nudged away from 1.0 exactly as ``_pl_add_resid_scaled``
        does, so they match the model that was actually fitted.
    bias : float
        Additive offset.
    """
    Q = np.asarray(Q_boot, float); V = np.asarray(V_boot, float).ravel()
    _, S = Q.shape
    # Per-column scale: median of positive entries, falling back to 1.0.
    Q_scale = np.nanmedian(np.where(Q > 0, Q, np.nan), axis=0)
    Q_scale = np.where(np.isfinite(Q_scale) & (Q_scale > 0), Q_scale, 1.0)
    Qn = Q / Q_scale

    Vrng = max(1.0, float(np.nanmax(V) - np.nanmin(V)))
    a0 = np.full(S, Vrng / max(S, 1))
    b0 = np.full(S, 0.7)
    bias0 = float(np.percentile(V, 5))

    lb = np.concatenate([np.full(S, 1e-6),  np.full(S, 1e-3), [-1e3]])
    ub = np.concatenate([np.full(S, 1e3),   np.full(S, 4.0),   [ 1e3]])
    # least_squares rejects a start outside the bounds ("`x0` is infeasible").
    # That happens for wide value ranges (a0 > 1e3) or |bias0| > 1e3, so clip
    # the heuristic start into the box. This is a no-op when it is already inside.
    theta0 = np.clip(np.concatenate([a0, b0, [bias0]]), lb, ub)

    res = least_squares(
        _pl_add_resid_scaled, theta0, args=(Qn, V),
        bounds=(lb, ub), loss='soft_l1', f_scale=1.0, max_nfev=5000
    )
    theta = res.x
    a_norm = np.clip(theta[:S], lb[0], ub[0])
    b = theta[S:2*S]
    # Same near-1 nudge as the residual function, so (a, b) match the fit.
    b = np.where(np.abs(b - 1.0) < 1e-6, 1.0 + np.sign(b - 1.0)*1e-3, b)
    bias = theta[-1]
    # Undo the column scaling: a_norm * (Q / s) ** b == (a_norm / s ** b) * Q ** b.
    # Q_scale is already guaranteed finite and > 0 above, so fractional scales
    # must be honored rather than floored at 1.
    a_raw = a_norm / np.power(Q_scale, b)
    return a_raw, b, bias

def _min_cost_q_Vstar(a, b, bias, Vstar, c_vec, q_cap_vec=None):
    a = np.asarray(a, float); b = np.asarray(b, float); c = np.asarray(c_vec, float)
    S = len(a)
    if Vstar <= bias + 1e-12:
        return np.zeros(S, float)

    bounds = [(0.0, None) for _ in range(S)] if q_cap_vec is None else [(0.0, float(q_cap_vec[k])) for k in range(S)]

    def g(q):
        return float(np.sum(a * np.power(np.maximum(q, 0.0), b)) + bias - Vstar)

    def obj(q):
        return float(np.dot(c, q))

    q0 = np.full(S, 1.0, float)
    cons = [{'type':'ineq', 'fun': lambda q: np.array([g(q)])}]
    res = minimize(obj, q0, method='SLSQP', bounds=bounds, constraints=cons,
                   options=dict(maxiter=500, ftol=1e-12, disp=False))
    q_hat = np.clip(res.x, 0.0, [bnd[1] if bnd[1] is not None else np.inf for bnd in bounds])
    return q_hat

def _bootstrap_Qhats_multi_cost(points_ms: List[MSPair], Vstar: float, c_vec: Sequence[float],
                                B: int = 400, seed: Optional[int] = None,
                                q_cap_vec: Optional[Sequence[int]] = None) -> np.ndarray:
    """Bootstrap the minimum-cost allocation that reaches ``Vstar``.

    For each of ``B`` replicates, rows of ``points_ms`` are resampled with
    replacement. The additive power law is refitted on the resample and the
    cheapest allocation reaching ``Vstar`` is solved for. If every replicate
    yields the identical allocation, Gaussian jitter (sd 1e-3, clipped at 0)
    is added so the returned sample is not a single repeated point.

    Returns
    -------
    np.ndarray, shape (B, S)
        One allocation per bootstrap replicate.
    """
    rng = np.random.default_rng(seed)
    alloc = np.array([list(q_tuple) for q_tuple, _ in points_ms], dtype=float)
    values = np.array([val for _, val in points_ms], dtype=float)
    n_obs, n_src = alloc.shape
    caps = np.asarray(q_cap_vec, dtype=float) if q_cap_vec is not None else None

    draws = np.empty((B, n_src), dtype=float)
    for rep in range(B):
        pick = rng.integers(0, n_obs, size=n_obs)
        coef, expo, offset = _fit_powerlaw_additive_nls(alloc[pick], values[pick])
        draws[rep] = _min_cost_q_Vstar(coef, expo, offset, Vstar, c_vec=c_vec, q_cap_vec=caps)

    n_distinct = np.unique(draws, axis=0).shape[0]
    if n_distinct == 1:
        # Fully degenerate sample; presumably jittered for the downstream
        # mixture fit. Note this consumes extra draws from the same rng.
        jitter = rng.normal(0.0, 1e-3, size=draws.shape)
        draws = np.maximum(draws + jitter, 0.0)
    return draws

def _build_F_from_Qhats_gmm_diag(Qhats: np.ndarray, device=None, dtype=torch.float32,
                                n_components: int = 3, reg_covar: float = 1e-6,
                                random_state=None):
    """Fit a diagonal-covariance Gaussian mixture to ``Qhats`` and return its CDF.

    Parameters
    ----------
    Qhats : np.ndarray, shape (n_samples, S)
        Samples to model, one row per sample.
    device : torch.device, optional
        Device holding the mixture parameters; defaults to CUDA when available,
        otherwise CPU.
    dtype : torch.dtype
        Floating dtype of the mixture parameters and of evaluation.
    n_components : int
        Requested number of components. It is capped at ``n_samples`` because
        GaussianMixture refuses to fit more components than samples.
    reg_covar : float
        Regularization added to each diagonal covariance.
    random_state : int, np.random.RandomState or None
        Forwarded to GaussianMixture so its initialization is reproducible.
        None keeps unseeded initialization.

    Returns
    -------
    Callable
        ``F(x)`` for ``x`` of shape (..., S) (tensor or array-like) returns a
        tensor of shape (...) equal to ``sum_k pi_k * prod_d Phi((x_d - mu_kd) /
        sigma_kd)``. That is the mixture's joint CDF P(Q <= x componentwise),
        clamped to [0, 1].

    Raises
    ------
    ValueError
        If ``Qhats`` has no rows.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    Qhats = np.asarray(Qhats, dtype=float)
    n_samples = Qhats.shape[0]
    if n_samples == 0:
        raise ValueError("Qhats must contain at least one sample to fit a GMM")
    n_components = max(1, min(int(n_components), n_samples))
    gmm = GaussianMixture(n_components=n_components, covariance_type="diag",
                          reg_covar=reg_covar, max_iter=200, random_state=random_state)
    gmm.fit(Qhats)

    pis  = torch.tensor(gmm.weights_, dtype=dtype, device=device)
    mus  = torch.tensor(gmm.means_, dtype=dtype, device=device)
    # For covariance_type="diag", covariances_ holds per-dimension variances.
    sigs = torch.tensor(np.sqrt(gmm.covariances_), dtype=dtype, device=device)

    inv_sqrt2 = 1.0 / math.sqrt(2.0)

    def F_qvec(x) -> torch.Tensor:
        # as_tensor returns x itself when dtype/device already match, and
        # otherwise behaves like x.to(...); it also accepts arrays/lists.
        q = torch.as_tensor(x, dtype=dtype, device=device)
        # Broadcast (..., 1, S) against (K, S): one standardized score per component and dim.
        z = (q.unsqueeze(-2) - mus) / (sigs + 1e-9)
        Phi = 0.5 * (1.0 + torch.erf(z * inv_sqrt2))
        return torch.clamp((Phi.prod(dim=-1) * pis).sum(dim=-1), 0.0, 1.0)
    return F_qvec

def fit_gmm_success_prob(points_ms: List[MSPair], Vstar: float, S: int, K: int = 3,
                         device=None, dtype=torch.float32,
                         c_vec: Optional[Sequence[float]] = None,
                         q_cap_vec: Optional[Sequence[int]] = None,
                         B: int = 400, seed: Optional[int] = None):
    """Bootstrap min-cost allocations for ``Vstar`` and return a GMM CDF over them.

    ``F(q)`` estimates the fraction of bootstrap minimum-cost allocations that
    are componentwise <= q. Going by the function name, this is presumably
    read as the probability that allocation q suffices to reach ``Vstar``.

    Parameters
    ----------
    points_ms : list of (tuple of int, float)
        Observations (quantities, value).
    Vstar : float
        Target value.
    S : int
        Number of components; used only to build the default unit costs.
    K : int
        Requested number of mixture components.
    device, dtype
        Forwarded to the CDF builder.
    c_vec : sequence of float, optional
        Unit costs; defaults to all ones of length ``S``.
    q_cap_vec : sequence of int, optional
        Per-component caps on the allocation.
    B : int
        Number of bootstrap replicates.
    seed : int, optional
        Seeds both the bootstrap resampling and the GMM initialization, so
        equal seeds give equal results.
    """
    if c_vec is None:
        c_vec = [1.0]*S
    Qhats = _bootstrap_Qhats_multi_cost(points_ms, Vstar, c_vec=c_vec, B=B, seed=seed, q_cap_vec=q_cap_vec)
    return _build_F_from_Qhats_gmm_diag(Qhats, device=device, dtype=dtype, n_components=K,
                                        reg_covar=1e-6, random_state=seed)
