import torch
from random import sample
import math

def choose_S(X, H, K):
    """Greedily pick up to ``K`` row indices of ``X`` with a high mean score.

    Every row is tried in turn as an anchor. For anchor ``a`` each row ``i``
    gets the score ``sqrt((x_i - x_a)^T H^{-1} (x_i - x_a))``; the anchor itself
    scores 0. Starting from ``[a]``, rows are appended in decreasing score order
    while the selection has fewer than ``K`` indices and each append strictly
    raises the mean score of the selection. The selection whose final mean is
    largest over all anchors is returned (ties keep the earliest anchor).

    Args:
        X: 2-D tensor with one candidate per row.
        H: square matrix; its inverse weights the row differences.
        K: maximum number of indices in the result.

    Returns:
        List of row indices starting with the winning anchor, or ``None`` when
        no anchor produced a comparable score (e.g. ``X`` has no rows).
    """
    num_rows, _ = X.shape
    inv_H = torch.inverse(H)
    if torch.isnan(inv_H).any():
        print("NaN detected in inv_H")

    winner, winner_score = None, -float("inf")
    for anchor in range(num_rows):
        offsets = X - X[anchor]
        # Mahalanobis-like distance of every row from the anchor row.
        scores = ((offsets @ inv_H) * offsets).sum(dim=1).sqrt()

        chosen = [anchor]
        score = scores[anchor].item()
        for cand in torch.argsort(scores, descending=True).tolist():
            if cand == anchor:
                continue
            if len(chosen) >= K:
                break
            trial = chosen + [cand]
            trial_score = scores[trial].mean().item()
            # Written as "not >" so a NaN mean also stops the growth.
            if not trial_score > score:
                break
            chosen, score = trial, trial_score

        if score > winner_score:
            winner, winner_score = chosen, score

    return winner

def choose_S_rand_ref(X, H, K):
    """Greedily grow a subset of row indices around one randomly drawn anchor.

    The anchor row index is drawn uniformly from ``range(n)`` with torch's
    global RNG. Each row ``i`` is scored by
    ``sqrt((x_i - x_a)^T H^{-1} (x_i - x_a))`` (the anchor scores 0). Rows are
    then appended in decreasing score order while the subset has fewer than
    ``K`` indices and each append strictly raises the subset's mean score.

    Args:
        X: 2-D tensor with one candidate per row.
        H: square matrix; its inverse weights the row differences.
        K: maximum number of indices in the result.

    Returns:
        List of row indices whose first entry is the random anchor.
    """
    n, _ = X.shape
    inv_H = torch.inverse(H)

    anchor = torch.randint(0, n, (1,)).item()
    offsets = X - X[anchor]
    scores = ((offsets @ inv_H) * offsets).sum(dim=1).sqrt()

    chosen = [anchor]
    current = scores[anchor].item()
    ranked = [i for i in torch.argsort(scores, descending=True).tolist() if i != anchor]

    pos = 0
    while pos < len(ranked) and len(chosen) < K:
        trial = chosen + [ranked[pos]]
        trial_mean = scores[trial].mean().item()
        # "not >" (rather than "<=") so a NaN mean also stops the growth.
        if not trial_mean > current:
            break
        chosen, current = trial, trial_mean
        pos += 1

    return chosen

def choose_S_dopewolfe(X, K, V_inv, z_dict, R=1000, alpha_tol=1e-3, device="cpu"):
    """Run one selection-and-update step over ``K``-subsets of the rows of ``X``.

    Candidate subsets are scored with ``PartialGrad``. The best-scoring subset
    ``S_t`` has its pair vectors stacked by ``construct_A_S``. A step size is
    chosen by ``GoldenSearch``, and ``V_inv`` is updated with ``UpdateInverse``.

    Candidates: when there are at most ``R`` distinct ``K``-subsets of
    ``range(N)``, every one of them is scored. Otherwise ``R`` subsets are drawn
    uniformly at random (independently, so duplicates are possible).

    Args:
        X: 2-D tensor; only its shape ``(N, d)`` is used.
        K: subset size.
        V_inv: matrix forwarded to the helpers (presumably the current inverse
            design matrix, ``d x d`` -- TODO confirm).
        z_dict: mapping from index pairs ``(j, k)`` to vectors of length ``d``.
        R: maximum number of candidate subsets to score.
        alpha_tol: tolerance forwarded to ``GoldenSearch``.
        device: device forwarded to ``PartialGrad`` and ``construct_A_S``.

    Returns:
        Tuple ``(V_inv_new, S)`` where ``S`` is the chosen subset as a list
        (sorted, because ``PartialGrad`` keys subsets by sorted tuples).

    Raises:
        ValueError: ``K > N``, so no ``K``-subset exists.
    """
    from itertools import combinations

    N, d = X.shape
    if K > N:
        raise ValueError(f"K must not exceed the number of rows N = {N}. Got K = {K}")

    if math.comb(N, K) <= R:
        # Few enough subsets to score all of them. Sampling with replacement
        # would waste draws on duplicates and could miss the best subset.
        R_list = [list(c) for c in combinations(range(N), K)]
    else:
        R_list = [sample(range(N), K) for _ in range(R)]

    grad_dict = PartialGrad(R_list, V_inv, z_dict, device)
    # max() keeps the first maximal entry, so ties go to the earliest-scored subset.
    S_t = max(grad_dict.items(), key=lambda item: item[1])[0]
    A_S = construct_A_S(S_t, z_dict, d, device)
    alpha_t = GoldenSearch(V_inv, A_S, alpha_tol)
    V_inv = UpdateInverse(V_inv, A_S, alpha_t)

    return V_inv, list(S_t)

def PartialGrad(R_list, V_inv, z_dict, device="cpu"):
    """Score candidate subsets by the summed quadratic forms of their pairs.

    For every entry ``(j, k) -> z`` of ``z_dict``, the scalar ``z^T V_inv z`` is
    computed once, with ``z`` flattened to a column and moved to ``device``.
    A subset's score is the sum of these scalars over all of its unordered
    index pairs. Each pair is looked up as ``(smaller, larger)`` first, then as
    ``(larger, smaller)``, and contributes 0.0 if neither key exists.

    Args:
        R_list: iterable of index sequences (candidate subsets).
        V_inv: square matrix used in the quadratic form.
        z_dict: mapping from 2-tuples of indices to tensors.
        device: device the vectors are moved to before the product.

    Returns:
        dict mapping each subset, as a sorted tuple, to its float score.
        Subsets that are equal up to order collapse onto a single key.
    """
    def quad_form(vec):
        col = vec.reshape(-1, 1).to(device)  # (d, 1)
        return float((col.T @ V_inv @ col).item())

    pair_score = {(a, b): quad_form(vec) for (a, b), vec in z_dict.items()}

    def lookup(a, b):
        return pair_score.get((a, b), pair_score.get((b, a), 0.0))

    totals = {}
    for subset in R_list:
        ordered = sorted(subset)
        totals[tuple(ordered)] = sum(
            (lookup(a, b) for i, a in enumerate(ordered) for b in ordered[i + 1:]),
            0.0,
        )
    return totals

def construct_A_S(S, z_dict, d, device="cpu"):
    """Stack the pair vectors of subset ``S`` as the columns of a ``(d, r)`` matrix.

    Columns follow the order ``(S[0], S[1]), (S[0], S[2]), ..., (S[-2], S[-1])``,
    so ``r = len(S) * (len(S) - 1) // 2``. Each pair vector is looked up as
    ``z_dict[(S[i], S[j])]`` and, if that key is absent, as
    ``z_dict[(S[j], S[i])]``. A vector may have any shape that flattens to
    ``d`` elements (e.g. ``(d,)`` or ``(d, 1)``).

    Args:
        S: sequence of indices.
        z_dict: mapping from index pairs to tensors.
        d: number of rows of the result.
        device: device for the result.

    Returns:
        Tensor of shape ``(d, r)`` with the dtype of the pair vectors, or an
        empty ``(d, 0)`` tensor of the default dtype when ``len(S) < 2``.

    Raises:
        KeyError: neither orientation of some pair is present in ``z_dict``.
        ValueError: the pair vectors do not have ``d`` elements.
    """
    keys = list(S)
    columns = []
    for i, first in enumerate(keys):
        for second in keys[i + 1:]:
            # Accept either orientation of the pair key. NOTE(review): if the
            # reversed entry is the negated vector, only the column's sign
            # changes, which A @ M @ A.T-style products do not see -- confirm.
            if (first, second) in z_dict:
                z = z_dict[(first, second)]
            elif (second, first) in z_dict:
                z = z_dict[(second, first)]
            else:
                raise KeyError((first, second))
            columns.append(z.reshape(-1).to(device))

    if not columns:
        return torch.empty((d, 0), device=device)

    # torch.stack keeps the vectors' dtype. A preallocated default-dtype buffer
    # would silently downcast float64 data and break A.T @ V_inv @ A later.
    A_S = torch.stack(columns, dim=1)
    if A_S.shape[0] != d:
        raise ValueError(f"pair vectors have {A_S.shape[0]} elements, expected d = {d}")
    return A_S

def UpdateInverse(V_inv, A, alpha, epsilon=1e-8):
    """Rank-``r`` Woodbury update of an inverse matrix.

    With ``V_inv = V^{-1}`` and ``A`` of shape ``(d, r)``, returns
    ``V_inv - V_inv A ((1/alpha + epsilon) I_r + A^T V_inv A)^{-1} A^T V_inv``.
    This equals ``(V + alpha A A^T)^{-1}`` when ``epsilon == 0``.

    NOTE(review): this is the inverse of ``V + alpha A A^T``, not of the convex
    combination ``(1 - alpha) V + alpha A A^T``; confirm which update the
    surrounding algorithm intends.

    Args:
        V_inv: current inverse matrix, ``(d, d)``.
        A: update directions, ``(d, r)``, same dtype as ``V_inv``.
        alpha: step size in ``[0, 1)`` (Python number or 0-d tensor).
        epsilon: ridge added to the ``r x r`` middle matrix for stability.

    Returns:
        New ``(d, d)`` tensor. For ``alpha == 0`` this is a copy of ``V_inv``.

    Raises:
        ValueError: ``alpha`` is outside ``[0, 1)``.
    """
    if not (0 <= alpha < 1):
        raise ValueError(f"Alpha must be in [0, 1). Got alpha = {alpha}")
    if alpha == 0:
        # A zero step leaves V unchanged; the formula below would divide by zero.
        return V_inv.clone()

    r = A.shape[1]
    I_r = torch.eye(r, dtype=A.dtype, device=A.device)
    middle = (1 / alpha) * I_r + A.T @ V_inv @ A
    middle += epsilon * I_r  # Regularization

    try:
        middle_inv = torch.inverse(middle)
    except RuntimeError:
        # Fall back to the pseudo-inverse if the r x r system is singular.
        middle_inv = torch.linalg.pinv(middle)

    correction = V_inv @ A @ middle_inv @ A.T @ V_inv
    return V_inv - correction

def UpdateLogDet(V_inv, A, alpha, epsilon=1e-8):
    """Line-search objective ``d*log(1-alpha) + logdet(I_r + alpha A^T V_inv A)``.

    The determinant is taken of ``I_r + alpha A^T V_inv A + epsilon I_r``
    using ``slogdet``.

    NOTE(review): if the intended update is ``V <- (1-alpha) V + alpha A A^T``,
    the exact change in log det is
    ``d*log(1-alpha) + logdet(I_r + alpha/(1-alpha) A^T V_inv A)``; confirm
    this objective against the reference derivation.

    Args:
        V_inv: ``(d, d)`` matrix, same dtype as ``A``.
        A: ``(d, r)`` matrix.
        alpha: step size below 1 (Python number or 0-d tensor).
        epsilon: ridge added inside the determinant.

    Returns:
        0-d tensor with the objective value, in ``A``'s dtype.
    """
    d, r = A.shape
    # as_tensor accepts Python floats and tensors alike; torch.log(float)
    # would raise, and torch.tensor(tensor) warns.
    alpha_t = torch.as_tensor(alpha, dtype=A.dtype, device=A.device)
    A_tilde = torch.sqrt(alpha_t) * A
    I_r = torch.eye(r, dtype=A.dtype, device=A.device)
    inner = I_r + A_tilde.T @ V_inv @ A_tilde

    # Stable logdet using slogdet (sign is discarded; inner is expected to be
    # positive definite when V_inv is PSD -- assumption, not checked).
    _, logdet = torch.linalg.slogdet(inner + epsilon * I_r)
    return d * torch.log(1 - alpha_t) + logdet

def GoldenSearch(V_inv, A, alpha_tol=1e-3):
    """Golden-section search for the step size that maximizes ``UpdateLogDet``.

    Searches ``alpha`` in ``[0, 1]`` for the maximum of
    ``UpdateLogDet(V_inv, A, alpha)``. This assumes the objective is unimodal on
    the interval; that is not checked. Each iteration shrinks the bracket
    ``[alpha_a, alpha_b]`` by a factor of ``phi`` and reuses one of the two
    interior objective values. The search stops once the bracket is narrower
    than ``alpha_tol``.

    Args:
        V_inv: matrix forwarded to ``UpdateLogDet``.
        A: matrix forwarded to ``UpdateLogDet``.
        alpha_tol: positive bracket width at which the search stops.

    Returns:
        Midpoint of the final bracket. Because ``phi`` is a tensor, this is a
        0-d tensor whenever at least one iteration runs; it is the Python float
        ``0.5`` if ``alpha_tol > 1``.

    Raises:
        ValueError: ``alpha_tol`` is not positive (including NaN).
    """
    if not alpha_tol > 0:
        # With alpha_tol <= 0 the loop condition never becomes false (infinite
        # loop); with NaN it is false immediately and no search happens.
        raise ValueError(f"alpha_tol must be positive. Got alpha_tol = {alpha_tol}")

    phi = (torch.sqrt(torch.tensor(5.0)) + 1) / 2
    alpha_a, alpha_h, alpha_b = 0.0, 1.0, 1.0

    # Interior probes at the golden-ratio points of [alpha_a, alpha_a + alpha_h].
    alpha_c = alpha_a + alpha_h / phi**2
    alpha_d = alpha_a + alpha_h / phi

    V_c = UpdateLogDet(V_inv, A, alpha_c)
    V_d = UpdateLogDet(V_inv, A, alpha_d)

    while abs(alpha_a - alpha_b) >= alpha_tol:
        alpha_h /= phi  # after this step alpha_b - alpha_a == alpha_h
        if V_c > V_d:
            # Maximum lies in [alpha_a, alpha_d]; old c becomes the new d.
            alpha_b = alpha_d
            alpha_d = alpha_c
            V_d = V_c
            alpha_c = alpha_a + alpha_h / phi**2
            V_c = UpdateLogDet(V_inv, A, alpha_c)
        else:
            # Maximum lies in [alpha_c, alpha_b]; old d becomes the new c.
            alpha_a = alpha_c
            alpha_c = alpha_d
            V_c = V_d
            alpha_d = alpha_a + alpha_h / phi
            V_d = UpdateLogDet(V_inv, A, alpha_d)

    return (alpha_a + alpha_b) / 2
