import math
import numpy as np
import torch
from scipy.special import gammaincc
from scipy.stats import gamma

from config import COV_TOL
from utils import get_gpu_count, mahalanobis_torch, safe_cov_torch


def pm_tail_gamma(d_out_sq, sq_dists):
    """
    Computes the PM measure based on a method-of-moments Gamma fit.

    The squared distances of the distortions are fitted with a Gamma
    distribution (shape ``k = mu**2 / var``, scale ``theta = var / mu``), and
    the upper-tail probability of ``d_out_sq`` under that fit is returned.

    :param d_out_sq: squared mahalanobis distance from the output to its cluster on the manifold.
    :param sq_dists: tensor of squared mahalanobis distances of all distortions in the cluster
        to their cluster on the manifold.
    :return: PM score, i.e. the Gamma survival function at ``d_out_sq``; 1.0 when the
        distances have zero variance.
    """
    mu = sq_dists.mean().item()
    var = sq_dists.var(unbiased=True).item()
    if var == 0.0:
        # Degenerate fit: all distortions are equally distant.
        return 1.0
    k = (mu**2) / var
    theta = var / mu
    # Evaluate the survival function directly: ``1 - cdf`` rounds to exactly 0
    # once the tail probability drops below double-precision resolution.
    return float(gamma.sf(d_out_sq, a=k, scale=theta))


def pm_tail_rank(d_out_sq, sq_dists):
    """
    Deprecated method to compute the PM measure from the empirical rank of the output distance.

    :param d_out_sq: squared distance of the output to its cluster.
    :param sq_dists: tensor of squared distances of the distortions to the cluster.
    :return: ``1 - (r + 0.5) / (n + 1)``, where ``r`` counts the distortion distances strictly
        below ``d_out_sq`` and ``n`` is the number of distortions.
    """
    total = sq_dists.numel()
    below = int(torch.count_nonzero(sq_dists < d_out_sq).item())
    return 1.0 - (below + 0.5) / (total + 1.0)


def diffusion_map_torch(
    X_np,
    labels_by_mix,
    *,
    cutoff=0.99,
    tol=1e-3,
    diffusion_time=1,
    alpha=0.0,
    eig_solver="lobpcg",
    k=None,
    device=None,
    return_eigs=False,
    return_complement=False,
    return_cval=False,
):
    """
    Compute diffusion maps from a high dimensional set of points.

    A Gaussian kernel whose bandwidth is the median pairwise squared distance
    is built over the rows of ``X_np``, optionally alpha-normalized, then
    symmetrically normalized as ``D^-1/2 K D^-1/2`` and eigendecomposed. The
    leading (trivial) eigenpair is dropped and every remaining eigenvector is
    scaled by ``lambda ** diffusion_time``. The number of kept coordinates is
    one more than the number of cumulative eigenvalue ratios below ``cutoff``.

    :param X_np: high dimensional input, one point per row.
    :param labels_by_mix: label per row; only read when ``return_cval`` is set, in which
        case rows whose label contains "out" are excluded from the ``pi_min`` computation.
    :param cutoff: the desired ratio between sum of kept and sum of all eigenvalues.
    :param tol: convergence tolerance of the "lobpcg" solver (ignored by "full").
    :param diffusion_time: number of steps taken on the probability transition matrix.
    :param alpha: normalization factor in [0, 1].
    :param eig_solver: "lobpcg" or "full".
    :param k: pre-defined truncation dimension, i.e. the number of non-trivial eigenpairs
        computed by either solver. When None, "full" keeps all of them and "lobpcg"
        computes ``min(N - 1, 50)`` eigenpairs including the trivial one.
    :param device: "cpu" or "cuda"; defaults to "cuda:0" when available.
    :param return_eigs: also return the non-trivial eigenvalues.
    :param return_complement: also return the complementary (not kept) coordinates.
    :param return_cval: calculate the psi_2 norm constant of the coordinates; it is only
        returned when ``return_complement`` and ``return_eigs`` are also set.
    :return: ``Psi`` alone, or a tuple ``(Psi[, Psi_rest][, eigenvalues][, c_val])`` selected
        by the flags above; arrays are NumPy, ``c_val`` is a float.
    :raises ValueError: if ``eig_solver`` is unknown.
    """
    device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
    X = torch.as_tensor(X_np, dtype=torch.float32, device=device)
    N = X.shape[0]

    if device != "cpu" and torch.cuda.is_available():
        stream = torch.cuda.Stream(device=device)
        # X was produced on the current stream; without this wait, kernels
        # queued on the side stream may read X before the copy has finished.
        stream.wait_stream(torch.cuda.current_stream(device))
        ctx_dev = torch.cuda.device(device)
        ctx_stream = torch.cuda.stream(stream)
    else:
        from contextlib import nullcontext

        stream = None
        ctx_dev = nullcontext()
        ctx_stream = nullcontext()

    with ctx_dev, ctx_stream:
        if N > 1000:
            # Fill the squared-distance matrix block by block to bound cdist workspace.
            chunk = min(500, N)
            D2 = torch.zeros(N, N, device=device)
            for r0 in range(0, N, chunk):
                r1 = min(r0 + chunk, N)
                for c0 in range(0, N, chunk):
                    c1 = min(c0 + chunk, N)
                    D2[r0:r1, c0:c1] = torch.cdist(X[r0:r1], X[c0:c1]).pow_(2)
        else:
            D2 = torch.cdist(X, X).pow_(2)

        # Kernel bandwidth: median squared distance over distinct pairs.
        iu, ju = torch.triu_indices(N, N, offset=1, device=X.device)
        eps = torch.median(D2[iu, ju])
        K = torch.exp(-D2 / (2 * eps))
        d = K.sum(dim=1)

        if alpha != 0.0:
            d_alpha_inv = d.pow(-alpha)
            K *= d_alpha_inv[:, None] * d_alpha_inv[None, :]
            d = K.sum(dim=1)

        # D^-1/2 K D^-1/2 by broadcasting instead of two dense N x N matmuls.
        d_rsqrt = torch.rsqrt(d)
        K_sym = K * d_rsqrt[:, None] * d_rsqrt[None, :]

        if eig_solver == "lobpcg":
            # One extra eigenpair accounts for the trivial eigenvector dropped
            # below, so an explicit k yields k coordinates as with "full".
            m = k + 1 if k is not None else min(N - 1, 50)
            init = torch.randn(N, m, device=device)
            vals, vecs = torch.lobpcg(
                K_sym, k=m, X=init, niter=200, tol=tol, largest=True
            )
        elif eig_solver == "full":
            vals, vecs = torch.linalg.eigh(K_sym)
            # eigh returns ascending order; switch to descending.
            vals, vecs = vals.flip(0), vecs.flip(1)
            if k is not None:
                vecs = vecs[:, : k + 1]
                vals = vals[: k + 1]
        else:
            raise ValueError(f"Unknown eig_solver '{eig_solver}'")

        # Drop the trivial leading eigenpair.
        psi = vecs[:, 1:]
        lam = vals[1:]
        cum = torch.cumsum(lam, dim=0)
        L = int((cum / cum[-1] < cutoff).sum().item()) + 1
        lam_pow = lam.pow(diffusion_time)
        psi_all = psi * lam_pow
        Psi = psi_all[:, :L]
        Psi_rest = psi_all[:, L:]

        if return_cval:
            out_rows = {ii for ii, name in enumerate(labels_by_mix) if "out" in name}
            valid_idx = torch.tensor(
                [ii for ii in range(N) if ii not in out_rows], device=device
            )
            pi_min = d[valid_idx].min() / d[valid_idx].sum()
            c_val = lam_pow[0] * pi_min.rsqrt() / math.log(2.0)

        if stream is not None:
            stream.synchronize()

    if return_complement and return_eigs and return_cval:
        return (
            Psi.cpu().numpy(),
            Psi_rest.cpu().numpy(),
            lam.cpu().numpy(),
            float(c_val),
        )
    if return_complement and return_eigs:
        return Psi.cpu().numpy(), Psi_rest.cpu().numpy(), lam.cpu().numpy()
    if return_complement:
        return Psi.cpu().numpy(), Psi_rest.cpu().numpy()
    if return_eigs:
        return Psi.cpu().numpy(), lam.cpu().numpy()
    return Psi.cpu().numpy()


def _ps_scores(coords_t, labels):
    """
    Compute the per-source PS scores from coordinates already placed on their device.

    :param coords_t: tensor of manifold coordinates, one row per label.
    :param labels: "<source>-<tag>" label per row; every source needs a "<source>-out" row.
    :return: dict mapping source id to PS score.
    """
    # Group rows by their exact source id. A plain ``startswith`` prefix test
    # would let source "1" absorb the rows of source "10".
    src_of = [l.split("-")[0] for l in labels]
    sources = sorted(set(src_of))
    non_out_stats = {}

    def _stats_without_out(o):
        # Mean and inverse covariance of source ``o`` without its "-out" rows.
        # They do not depend on the source being scored, so compute them once.
        if o not in non_out_stats:
            o_idxs = [
                i
                for i, (src, l) in enumerate(zip(src_of, labels))
                if src == o and not l.endswith("-out")
            ]
            non_out_stats[o] = (
                coords_t[o_idxs].mean(0),
                torch.linalg.inv(safe_cov_torch(coords_t[o_idxs])),
            )
        return non_out_stats[o]

    out = {}
    for s in sources:
        out_i = labels.index(f"{s}-out")
        ref_is = [i for i, src in enumerate(src_of) if src == s and i != out_i]
        mu = coords_t[ref_is].mean(0)
        inv = torch.linalg.inv(safe_cov_torch(coords_t[ref_is]))
        # A: distance of the output to its own source cluster.
        A = mahalanobis_torch(coords_t[out_i], mu, inv)
        # B: distances of the output to every competing source cluster.
        B_list = []
        for o in sources:
            if o == s:
                continue
            mu_o, inv_o = _stats_without_out(o)
            B_list.append(mahalanobis_torch(coords_t[out_i], mu_o, inv_o))
        B_min = (
            torch.min(torch.stack(B_list))
            if B_list
            else torch.tensor(0.0, device=coords_t.device)
        )
        out[s] = (1 - A / (A + B_min + 1e-6)).item()
    return out


def compute_ps(coords, labels, max_gpus=None):
    """
    Computes the PS measure.

    For each source, ``PS = 1 - A / (A + B_min + 1e-6)``. Here ``A`` is the
    Mahalanobis distance of the source's "-out" row to its remaining rows.
    ``B_min`` is the smallest such distance to any other source's non-"-out" rows.

    :param coords: coordinates on the manifold.
    :param labels: assign source index per coordinate ("<source>-<tag>").
    :param max_gpus: maximal number of GPUs to use.
    :return: the PS measure, as a dict mapping source id to score.
    """
    ngpu = get_gpu_count(max_gpus)

    if ngpu == 0:
        return _ps_scores(torch.tensor(coords), labels)

    device = min(ngpu - 1, 1)
    device_str = f"cuda:{device}"
    coords_t = torch.tensor(coords, device=device_str)

    stream = torch.cuda.Stream(device=device_str)
    # coords_t was produced on the current stream; order the side stream after it.
    stream.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.device(device), torch.cuda.stream(stream):
        out = _ps_scores(coords_t, labels)
        stream.synchronize()
    return out


def _pm_scores(coords_t, labels, pm_method):
    """
    Compute the per-source PM scores from coordinates already placed on their device.

    :param coords_t: tensor of manifold coordinates, one row per label.
    :param labels: "<source>-<tag>" label per row; every source needs "<source>-ref"
        and "<source>-out" rows.
    :param pm_method: "rank" selects ``pm_tail_rank``; anything else ``pm_tail_gamma``.
    :return: dict mapping source id to PM score clipped to [0, 1].
    """
    # Group rows by their exact source id. A plain ``startswith`` prefix test
    # would let source "1" absorb the rows of source "10".
    src_of = [l.split("-")[0] for l in labels]
    tail_fn = pm_tail_rank if pm_method == "rank" else pm_tail_gamma
    out = {}
    for s in sorted(set(src_of)):
        ref_i = labels.index(f"{s}-ref")
        out_i = labels.index(f"{s}-out")
        d_idx = [
            i for i, src in enumerate(src_of) if src == s and i not in (ref_i, out_i)
        ]
        if len(d_idx) < 2:
            # Too few distortions to estimate a covariance.
            out[s] = 0.0
            continue
        ref_v = coords_t[ref_i]
        dist = coords_t[d_idx] - ref_v
        n_dist, dim = dist.shape
        # Second-moment matrix of the distortions around the reference point.
        cov = dist.T @ dist / (n_dist - 1)
        if torch.linalg.matrix_rank(cov) < dim:
            # Ridge a rank-deficient matrix so it can be inverted.
            cov = cov + torch.eye(dim, dtype=cov.dtype, device=cov.device) * COV_TOL
        inv = torch.linalg.inv(cov)
        sq_dists = torch.stack(
            [mahalanobis_torch(coords_t[i], ref_v, inv) ** 2 for i in d_idx]
        )
        d_out_sq = float(mahalanobis_torch(coords_t[out_i], ref_v, inv) ** 2)
        out[s] = float(np.clip(tail_fn(d_out_sq, sq_dists), 0.0, 1.0))
    return out


def compute_pm(coords, labels, pm_method, max_gpus=None):
    """
    Computes the PM measure.

    For each source, the squared Mahalanobis distances of its distortions to
    its "-ref" row define a reference distribution. The PM score is the tail
    probability of the "-out" row's squared distance under it.

    :param coords: coordinates on the manifold.
    :param labels: assign source index per coordinate ("<source>-<tag>").
    :param pm_method: "rank" or "gamma".
    :param max_gpus: maximal number of GPUs to use.
    :return: the PM measure, as a dict mapping source id to score in [0, 1].
    """
    ngpu = get_gpu_count(max_gpus)

    if ngpu == 0:
        return _pm_scores(torch.tensor(coords), labels, pm_method)

    device = min(ngpu - 1, 1)
    device_str = f"cuda:{device}"
    coords_t = torch.tensor(coords, device=device_str)

    stream = torch.cuda.Stream(device=device_str)
    # coords_t was produced on the current stream; order the side stream after it.
    stream.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.device(device), torch.cuda.stream(stream):
        out = _pm_scores(coords_t, labels, pm_method)
        stream.synchronize()
    return out


def pm_ci_components_full(
    coords_d, coords_rest, eigvals, labels, *, delta=0.05, K=1.0, C1=1.0, C2=1.0
):
    """
    Computes the error radius and tail bounds for the PM measure.

    For every source, the complement coordinates are reduced to the part not
    linearly explained by the retained coordinates (Schur-complement residuals).
    Their energy perturbs the Gamma fit (k, theta) and the output distance a.
    Concentration bounds on the fitted moments give a second, probabilistic
    perturbation. Each radius is the largest change of the Gamma upper-tail
    score over the 8 corners of the perturbed (k, theta, a) box.

    :param coords_d: Retained diffusion maps coordinates.
    :param coords_rest: Complement diffusion maps coordinates.
    :param eigvals: Eigenvalues of the diffusion maps (currently unused).
    :param labels: Assign source index per coordinate ("<source>-<tag>"); each source
        needs "<source>-ref" and "<source>-out" rows.
    :param delta: ``1 - delta`` is the confidence level.
    :param K: Absolute constant (currently unused).
    :param C1: Absolute constant (currently unused).
    :param C2: Absolute constant (currently unused).
    :return: ``(bias_ci, prob_ci)`` dicts mapping source id to radius. Both are all
        zeros when ``coords_rest`` has no columns, and 0.0 for sources with fewer
        than two distortions.
    """
    _EPS = 1e-12

    def _safe_x(a, theta):
        return a / max(theta, _EPS)

    m = coords_rest.shape[1]
    if m == 0:
        z = {s: 0.0 for s in {l.split("-")[0] for l in labels}}
        return z.copy(), z.copy()

    dev = "cuda:0" if torch.cuda.is_available() else "cpu"
    X_d = torch.tensor(coords_d, device=dev)
    X_c = torch.tensor(coords_rest, device=dev)
    # Group rows by their exact source id. A plain ``startswith`` prefix test
    # would let source "1" absorb the rows of source "10".
    src_of = [l.split("-")[0] for l in labels]
    spk_ids = sorted(set(src_of))
    bias_ci = {}
    prob_ci = {}

    for s in spk_ids:
        ref_i = labels.index(f"{s}-ref")
        out_i = labels.index(f"{s}-out")
        dist_is = [
            i for i, src in enumerate(src_of) if src == s and i not in (ref_i, out_i)
        ]
        n_p = len(dist_is)

        if n_p < 2:
            bias_ci[s] = 0.0
            prob_ci[s] = 0.0
            continue

        ref_d = X_d[ref_i]
        ref_c = X_c[ref_i]
        D_mat = X_d[dist_is] - ref_d
        C_mat = X_c[dist_is] - ref_c
        Sigma_d = safe_cov_torch(D_mat)
        Sigma_c = safe_cov_torch(C_mat)
        C_dc = D_mat.T @ C_mat / (n_p - 1)
        inv_Sigma_d = torch.linalg.inv(Sigma_d)

        # Schur complement: covariance of the complement coordinates given the retained ones.
        S_i = (
            Sigma_c
            - C_dc.T @ inv_Sigma_d @ C_dc
            + torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
        )
        S_inv = torch.linalg.inv(S_i)

        diff_out_d = X_d[out_i] - ref_d
        diff_out_c = X_c[out_i] - ref_c
        r_out = diff_out_c - C_dc.T @ inv_Sigma_d @ diff_out_d
        delta_Gi_a = float(r_out @ S_inv @ r_out)

        r_list = []
        for p in dist_is:
            d_p = X_d[p] - ref_d
            c_p = X_c[p] - ref_c
            r_list.append(c_p - C_dc.T @ inv_Sigma_d @ d_p)
        R_p = torch.stack(r_list, dim=0)
        delta_Gi_p = torch.sum(R_p @ S_inv * R_p, dim=1)
        delta_Gi_mu_max = float(delta_Gi_p.max())

        mah_sq = torch.stack(
            [(X_d[i] - ref_d) @ inv_Sigma_d @ (X_d[i] - ref_d) for i in dist_is]
        )
        mu_g = float(mah_sq.mean())
        sigma2_g = float(mah_sq.var(unbiased=True) + 1e-12)
        sigma_g = math.sqrt(sigma2_g)

        full_sq = mah_sq + delta_Gi_p
        mu_full = float(full_sq.mean())
        sigma2_full = float(full_sq.var(unbiased=True) + 1e-12)

        if sigma2_g == 0.0:
            delta_Gi_k = delta_Gi_theta = 0.0
        else:
            factor = delta_Gi_mu_max * n_p / (n_p - 1)
            delta_Gi_k = 1.0 * factor * (mu_full + mu_g) / sigma2_g
            delta_Gi_theta = 1.0 * factor * (sigma2_full + sigma2_g) / (mu_g**2 + 1e-9)

        # Method-of-moments Gamma fit on the retained-coordinate distances.
        k_d = (mu_g**2) / max(sigma2_g, 1e-12)
        theta_d = sigma2_g / max(mu_g, 1e-12)
        a_d = float(diff_out_d @ inv_Sigma_d @ diff_out_d)

        pm_center = gammaincc(k_d, _safe_x(a_d, theta_d))

        # Bias radius: worst corner of the truncation-induced perturbation box.
        corner_vals = []
        for s_k in (-1, 1):
            for s_theta in (-1, 1):
                for s_a in (-1, 1):
                    k_c = max(k_d + s_k * delta_Gi_k, 1e-6)
                    theta_c = max(theta_d + s_theta * delta_Gi_theta, 1e-6)
                    a_c = max(a_d + s_a * delta_Gi_a, 1e-8)
                    corner_vals.append(gammaincc(k_c, _safe_x(a_c, theta_c)))

        bias_ci[s] = max(abs(v - pm_center) for v in corner_vals)

        # Finite-sample deviations of the fitted moments (bounded by R_sq).
        R_sq = float(mah_sq.max()) + 1e-12
        log_term = math.log(6.0 / delta)
        eps_mu = math.sqrt(2 * sigma2_g * log_term / n_p) + 3 * R_sq * log_term / n_p
        eps_sigma = (
            math.sqrt(2 * R_sq**2 * log_term / n_p) + 3 * R_sq**2 * log_term / n_p
        )

        # Partial derivatives of k = mu^2/sigma^2 and theta = sigma^2/mu.
        g1_x = 2.0 * mu_g / (sigma2_g + 1e-9)
        g1_y = -2.0 * mu_g**2 / (sigma_g**3 + 1e-9)
        g2_x = -sigma2_g / (mu_g**2 + 1e-9)
        g2_y = 2.0 * sigma_g / (mu_g + 1e-9)

        delta_k = min(abs(g1_x) * eps_mu + abs(g1_y) * eps_sigma, 0.5 * k_d)
        delta_theta = min(abs(g2_x) * eps_mu + abs(g2_y) * eps_sigma, 0.5 * theta_d)
        delta_a = min(R_sq * math.sqrt(2 * log_term / n_p), 0.5 * a_d + 1e-12)

        pm_corners = []
        for s_k in (-1, 1):
            for s_theta in (-1, 1):
                for s_a in (-1, 1):
                    k_c = k_d + s_k * delta_k
                    theta_c = theta_d + s_theta * delta_theta
                    a_c = max(a_d + s_a * delta_a, 1e-8)
                    pm_corners.append(gammaincc(k_c, _safe_x(a_c, theta_c)))

        prob_ci[s] = max(abs(pm - pm_center) for pm in pm_corners)

    return bias_ci, prob_ci


def ps_ci_components_full(coords_d, coords_rest, eigvals, labels, *, delta=0.05):
    """
    Computes the error radius and tail bounds for the PS measure.

    ``bias`` propagates the complement-coordinate residuals of the output, for
    its own source and for the closest competing source, through the gradient
    of ``A / (A + B)``. ``prob`` propagates sub-Gaussian deviation bounds on the
    two Mahalanobis distances through the L2 norm of that gradient, capped at 1.

    :param coords_d: Retained diffusion maps coordinates.
    :param coords_rest: Complement diffusion maps coordinates.
    :param eigvals: Eigenvalues of the diffusion maps (currently unused).
    :param labels: Assign source index per coordinate ("<source>-<tag>"); each source
        needs a "<source>-out" row.
    :param delta: ``1 - delta`` is the confidence level.
    :return: ``(bias, prob)`` dicts mapping source id to radius. Both are all zeros
        when ``coords_rest`` has no columns, and 0.0 for a source with no competitor.
    """

    def _mean_dev(lam_max, delta, n_eff):
        return math.sqrt(2 * lam_max * math.log(2 / delta) / n_eff)

    def _rel_cov_dev(lam_max, trace, delta, n_eff, C=1.0):
        # Relative covariance deviation, using effective rank r = trace / lam_max.
        r = trace / lam_max
        abs_dev = (
            C * lam_max * (math.sqrt(r / n_eff) + (r + math.log(2 / delta)) / n_eff)
        )
        return abs_dev / lam_max

    def _maha_eps_m(a_hat, lam_min, lam_max, mean_dev, rel_cov_dev):
        term1 = 2 * math.sqrt(a_hat) * mean_dev * math.sqrt(lam_max / lam_min)
        term2 = a_hat * rel_cov_dev
        return term1 + term2

    m = coords_rest.shape[1]
    if m == 0:
        z = {s: 0.0 for s in set(l.split("-")[0] for l in labels)}
        return z.copy(), z.copy()

    dev = "cuda:0" if torch.cuda.is_available() else "cpu"
    X_d = torch.tensor(coords_d, device=dev)
    X_c = torch.tensor(coords_rest, device=dev)
    # Group rows by their exact source id. A plain ``startswith`` prefix test
    # would let source "1" absorb the rows of source "10".
    src_of = [l.split("-")[0] for l in labels]
    spk_ids = sorted(set(src_of))
    bias = {}
    prob = {}

    for s in spk_ids:
        out_i = labels.index(f"{s}-out")
        ref_is = [i for i, src in enumerate(src_of) if src == s and i != out_i]

        mu_d = X_d[ref_is].mean(0)
        mu_c = X_c[ref_is].mean(0)
        Sigma_d = safe_cov_torch(X_d[ref_is])
        Sigma_c = safe_cov_torch(X_c[ref_is])
        C_dc = (X_d[ref_is] - mu_d).T @ (X_c[ref_is] - mu_c) / (len(ref_is) - 1)
        inv_Sd = torch.linalg.inv(Sigma_d)

        eig_d = torch.linalg.eigvalsh(Sigma_d)
        lam_min = eig_d.min().clamp_min(1e-9).item()
        lam_max = eig_d.max().item()
        trace = torch.trace(Sigma_d).item()

        diff_d = X_d[out_i] - mu_d
        diff_c = X_c[out_i] - mu_c
        A_d = float(mahalanobis_torch(X_d[out_i], mu_d, inv_Sd))

        # Residual of the complement coordinates given the retained ones (own source).
        r_i = diff_c - C_dc.T @ inv_Sd @ diff_d
        S_i = (
            Sigma_c
            - C_dc.T @ inv_Sd @ C_dc
            + torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
        )
        term_i = math.sqrt(float(r_i @ torch.linalg.solve(S_i, r_i)))

        # Closest competing source in retained coordinates, and its residual term.
        B_d, term_j = float("inf"), 0.0
        Sig_o = None
        for o in spk_ids:
            if o == s:
                continue
            o_idxs = [
                i
                for i, (src, l) in enumerate(zip(src_of, labels))
                if src == o and not l.endswith("-out")
            ]
            muo_d = X_d[o_idxs].mean(0)
            muo_c = X_c[o_idxs].mean(0)
            Sig_o_tmp = safe_cov_torch(X_d[o_idxs])
            inv_So = torch.linalg.inv(Sig_o_tmp)
            this_B = float(mahalanobis_torch(X_d[out_i], muo_d, inv_So))

            if this_B < B_d:
                B_d = this_B
                Sig_o = Sig_o_tmp
                diff_do = X_d[out_i] - muo_d
                diff_co = X_c[out_i] - muo_c
                C_oc = (
                    (X_d[o_idxs] - muo_d).T @ (X_c[o_idxs] - muo_c) / (len(o_idxs) - 1)
                )
                r_j = diff_co - C_oc.T @ inv_So @ diff_do
                S_j = (
                    safe_cov_torch(X_c[o_idxs])
                    - C_oc.T @ inv_So @ C_oc
                    + torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
                )
                term_j = math.sqrt(float(r_j @ torch.linalg.solve(S_j, r_j)))

        if Sig_o is None:
            # No competitor was found (B_d is still inf): the bias formula would
            # evaluate to inf/inf = NaN, so report zero radii instead.
            bias[s] = 0.0
            prob[s] = 0.0
            continue

        denom = A_d + B_d
        bias[s] = (B_d * term_i + A_d * term_j) / (denom**2)

        eig_o = torch.linalg.eigvalsh(Sig_o)
        lam_min_o = eig_o.min().clamp_min(1e-9).item()
        lam_max_o = eig_o.max().item()
        trace_o = torch.trace(Sig_o).item()

        n_eff = max(int(0.7 * len(ref_is)), 3)
        # Floor the smallest eigenvalues relative to the largest to limit conditioning.
        RIDGE = 0.05
        lam_min_eff = max(lam_min, RIDGE * lam_max)
        lam_min_o_eff = max(lam_min_o, RIDGE * lam_max_o)

        eps_i_sg = _maha_eps_m(
            A_d,
            lam_min_eff,
            lam_max,
            _mean_dev(lam_max, delta / 2, n_eff),
            _rel_cov_dev(lam_max, trace, delta / 2, n_eff),
        )
        eps_j_sg = _maha_eps_m(
            B_d,
            lam_min_o_eff,
            lam_max_o,
            _mean_dev(lam_max_o, delta / 2, n_eff),
            _rel_cov_dev(lam_max_o, trace_o, delta / 2, n_eff),
        )

        grad_l2 = math.hypot(A_d, B_d) / denom**2
        ps_radius = grad_l2 * math.hypot(eps_i_sg, eps_j_sg)
        prob[s] = min(1.0, ps_radius)

    return bias, prob
