import math
from itertools import combinations
from random import sample

import numpy as np
import torch

# Module-level torch RNG, seeded once at import time so that any torch
# sampling drawing from it is reproducible. `manual_seed` returns the
# generator itself, so construction and seeding can be chained.
gen = torch.Generator().manual_seed(12345)

def choose_S_random(X, H, K):
    """Pick K distinct row indices of X uniformly at random.

    Draws from NumPy's global RNG without replacement. ``H`` is unused; the
    other ``choose_S*`` selectors take the same ``(X, H, K)`` arguments.

    Returns:
        ``(None, indices)`` where ``indices`` is a list of K row indices.
    """
    picked = np.random.choice(len(X), K, replace=False)
    return None, list(picked)

def choose_S_greedy(X, H, K):
    """Pick up to K rows of X that are far from a shared reference row.

    For every candidate reference row ``ref``, the distance of row i is
    ``u[i] = sqrt((X[i] - X[ref])^T pinv(H) (X[i] - X[ref]))``. The candidate
    set is ``ref`` followed by the ``K - 1`` other rows with the largest ``u``.
    The set with the highest mean ``u`` is returned (first one on ties).

    Args:
        X: 2-D array with one candidate per row.
        H: square matrix whose pseudo-inverse defines the metric.
        K: target set size. ``ref`` is always included, so ``K <= 1`` yields
            a one-element set.

    Returns:
        ``(None, best_S)`` with ``best_S`` a list of row indices, or
        ``(None, None)`` when X has no rows.
    """
    inv_H = np.linalg.pinv(H)
    best_S = None
    max_score = -np.inf
    for ref in range(len(X)):
        diff = X - X[ref]
        # Clamp rounding-level negative quadratic forms (common when H is
        # singular) before sqrt. A NaN would sort first in the descending
        # order below, get picked, and make this reference's score NaN,
        # so it could never win.
        quad = np.sum((diff @ inv_H) * diff, axis=1)
        u = np.sqrt(np.maximum(quad, 0.0))
        S = [ref]
        for i in np.argsort(u)[::-1]:
            if len(S) >= K:
                break  # set is full; no need to scan the remaining rows
            if i != ref:
                S.append(i)
        score = np.mean(u[S])
        if score > max_score:
            best_S = S
            max_score = score
    return None, best_S

def choose_S(X, H, K):
    """Choose a reference row and a high-uncertainty set of up to K rows.

    For every candidate reference row ``ref_idx`` the uncertainty of row i is
    ``sqrt((X[i] - X[ref_idx])^T pinv(H) (X[i] - X[ref_idx]))``. Starting from
    ``S = [ref_idx]`` (whose own uncertainty is 0), the other rows are visited
    in decreasing uncertainty. Each is appended while that strictly increases
    the mean uncertainty of ``S`` and ``len(S) < K``. The reference whose final
    set has the largest mean wins (first one on ties).

    Args:
        X: 2-D array of shape (n, d), one candidate per row.
        H: (d, d) matrix whose pseudo-inverse defines the metric.
        K: maximum set size (``ref_idx`` is always included).

    Returns:
        ``(best_ref, best_S)``; both are ``None`` when X has no rows.
    """
    n, _ = X.shape  # unpack enforces a 2-D X
    inv_H = np.linalg.pinv(H)

    best_ref = None
    best_S = None
    max_avg = -np.inf

    for ref_idx in range(n):
        diff = X - X[ref_idx]  # shape (n, d)
        # Clamp rounding-level negatives (common when H is singular) before
        # sqrt. A NaN would sort first below and immediately stop the greedy
        # growth of S for this reference.
        quad = np.sum((diff @ inv_H) * diff, axis=1)
        uncertainty = np.sqrt(np.maximum(quad, 0.0))

        S = [ref_idx]
        prev_avg = uncertainty[ref_idx]

        # Other rows in decreasing uncertainty.
        remaining = [i for i in np.argsort(uncertainty)[::-1] if i != ref_idx]

        for idx in remaining:
            if len(S) >= K:
                break
            new_S = S + [idx]
            new_avg = np.mean(uncertainty[new_S])
            if new_avg > prev_avg:
                S = new_S
                prev_avg = new_avg
            else:
                # Later candidates are no larger, so none of them can raise
                # the mean either.
                break

        if prev_avg > max_avg:
            best_ref = ref_idx
            best_S = S
            max_avg = prev_avg

    return best_ref, best_S

def choose_S_dopewolfe(X, K, V_inv, z_dict, R=100000, alpha_tol=1e-3):
    """Take one greedy step: pick a K-subset, then update ``V_inv`` with it.

    Candidate K-subsets of ``range(N)`` are scored with ``PartialGrad``. When
    at most ``R`` subsets exist, all of them are scored. Otherwise ``R``
    random subsets are drawn (with replacement) via Python's ``random``
    module. The best-scoring subset's pair vectors form ``A_S``. The step
    size comes from ``GoldenSearch``, and ``V_inv`` is updated by
    ``UpdateInverse``.

    Args:
        X: 2-D array of shape (N, d); only its shape is used here.
        K: subset size.
        V_inv: current (d, d) inverse matrix.
        z_dict: maps index pairs ``(j, k)`` to length-d vectors.
        R: maximum number of candidate subsets to score.
        alpha_tol: tolerance forwarded to ``GoldenSearch``.

    Returns:
        ``(updated V_inv, chosen subset as an ascending list of indices)``.

    Raises:
        ValueError: if K exceeds the number of rows of X.
    """
    N, d = X.shape
    if K > N:
        raise ValueError(f"Cannot choose K={K} rows from N={N} rows.")
    total = math.comb(N, K)
    if total <= R:
        # Few enough subsets to score them all exactly. Drawing `total` random
        # samples with replacement would miss a large fraction of them.
        R_list = list(combinations(range(N), K))
    else:
        R_list = [sample(range(N), K) for _ in range(R)]
    grad_dict = PartialGrad(R_list, V_inv, z_dict)
    S_t = max(grad_dict, key=grad_dict.get)  # first maximal subset on ties
    A_S = construct_A_S(S_t, z_dict, d)
    alpha_t = GoldenSearch(V_inv, A_S, alpha_tol)
    V_inv = UpdateInverse(V_inv, A_S, alpha_t)

    return V_inv, list(S_t)

def PartialGrad(R_list, V_inv, z_dict):
    """Score each candidate subset by summed pairwise quadratic forms.

    For a subset S, ``G_S = sum over pairs a < b in S of z_ab^T V_inv z_ab``.
    This equals ``trace(A_S^T V_inv A_S)`` for the matrix whose columns are
    those z vectors. Each ``z_ab`` is looked up under ``(a, b)``, falling back
    to ``(b, a)``.

    Args:
        R_list: iterable of index collections (lists or tuples).
        V_inv: (d, d) matrix.
        z_dict: maps index pairs to vectors with d entries; any shape that
            ravels to length d is accepted.

    Returns:
        dict mapping ``tuple(sorted(S))`` to ``G_S``. Duplicate subsets
        collapse to one entry.

    Raises:
        KeyError: if a pair of some subset is missing in both orientations.
    """
    # One quadratic form per pair, reused by every subset.
    D = {}
    for pair, z in z_dict.items():
        z = np.ravel(z)
        # With a 1-D z the product is already a scalar; float() on a (1, 1)
        # array is deprecated since NumPy 1.25.
        D[pair] = float(z @ V_inv @ z)

    grad_dict = {}
    for S in R_list:
        S = sorted(S)
        G_S = 0.0
        for pos, a in enumerate(S):
            for b in S[pos + 1:]:
                G_S += D[(a, b)] if (a, b) in D else D[(b, a)]
        grad_dict[tuple(S)] = G_S

    return grad_dict

def construct_A_S(S, z_dict, d):
    """Stack the pair vectors of subset S as columns of a (d, C(K, 2)) matrix.

    Columns follow the order of pairs ``(S[i], S[j])`` with ``i < j``. Each
    vector is looked up under ``(S[i], S[j])``, falling back to
    ``(S[j], S[i])``, the same lookup rule ``PartialGrad`` uses.

    Args:
        S: sequence of K indices.
        z_dict: maps index pairs to vectors with d entries; any shape that
            ravels to length d is accepted.
        d: row dimension of the result.

    Returns:
        float array of shape (d, K*(K-1)//2); shape (d, 0) when K < 2.

    Raises:
        KeyError: if a pair is missing in both orientations.
    """
    K = len(S)
    col_count = K * (K - 1) // 2
    A_S = np.empty((d, col_count))

    col = 0
    for i in range(K):
        for j in range(i + 1, K):
            a, b = S[i], S[j]
            try:
                z = z_dict[(a, b)]
            except KeyError:
                # If z_dict stores antisymmetric differences (z_ba = -z_ab),
                # this flips only the column's sign; A_S @ A_S.T is unchanged.
                z = z_dict[(b, a)]
            A_S[:, col] = np.ravel(z)  # (d, 1) vectors would not broadcast into a column
            col += 1

    return A_S

def UpdateInverse(V_inv, A, alpha, epsilon=1e-8):
    """Return ``(V + alpha * A @ A.T)^{-1}`` given ``V_inv = V^{-1}``.

    Uses the Woodbury identity
    ``(V + alpha A A^T)^{-1} = V_inv - V_inv A (I/alpha + A^T V_inv A)^{-1} A^T V_inv``,
    with ``epsilon * I`` added to the middle r x r matrix.

    NOTE(review): ``UpdateLogDet`` in this file evaluates the log-det change of
    ``(1 - alpha) * (V + alpha A A^T)``. That matrix's inverse is this result
    divided by ``(1 - alpha)``. Confirm which update is intended.

    Args:
        V_inv: (d, d) inverse matrix.
        A: (d, r) update matrix.
        alpha: step size in [0, 1). ``alpha == 0`` means no update.
        epsilon: ridge added to the middle matrix.

    Returns:
        New (d, d) array; ``V_inv`` is not modified.

    Raises:
        ValueError: if alpha is outside [0, 1) (including NaN).
    """
    if not (0 <= alpha < 1):
        raise ValueError(f"Alpha must be in [0, 1). Got alpha = {alpha}")
    if alpha == 0:
        # V + 0 * A A^T == V. The Woodbury form would divide by zero here.
        return np.copy(V_inv)

    r = A.shape[1]
    I_r = np.eye(r)

    # For PSD V_inv its eigenvalues are >= 1/alpha, so it is already
    # invertible; the ridge is an extra safeguard.
    middle = (1 / alpha) * I_r + A.T @ V_inv @ A
    middle += epsilon * I_r

    try:
        middle_inv = np.linalg.inv(middle)
    except np.linalg.LinAlgError:
        # Fall back to the pseudo-inverse if the solve still fails.
        middle_inv = np.linalg.pinv(middle)

    # Woodbury update
    correction = V_inv @ A @ middle_inv @ A.T @ V_inv
    return V_inv - correction

def UpdateLogDet(V_inv, A, alpha):
    """Return ``d*log(1 - alpha) + log|det(I_r + alpha * A^T V_inv A)|``.

    By the matrix determinant lemma, with ``V_inv = V^{-1}`` this equals
    ``logdet((1 - alpha) * (V + alpha A A^T)) - logdet(V)``.

    Args:
        V_inv: (d, d) matrix.
        A: (d, r) matrix; d and r are read from its shape.
        alpha: step size; ``alpha >= 1`` makes the first term -inf/NaN.
    """
    dim, n_cols = A.shape
    scaled = np.sqrt(alpha) * A
    shifted_gram = np.eye(n_cols) + scaled.T @ V_inv @ scaled
    # slogdet avoids overflow/underflow in det; the sign is ignored.
    _, log_abs_det = np.linalg.slogdet(shifted_gram)
    return dim * np.log(1 - alpha) + log_abs_det

def GoldenSearch(V_inv, A, alpha_tol=1e-3):
    """Golden-section search for the alpha in [0, 1] maximizing ``UpdateLogDet``.

    The bracket ``[alpha_a, alpha_b]`` shrinks by a factor of 1/phi each
    iteration, keeping the side with the larger objective value, until its
    width drops below ``alpha_tol``. When V_inv is positive semi-definite the
    objective is concave in alpha, so the bracket converges to the global
    maximizer.

    Args:
        V_inv: (d, d) matrix forwarded to ``UpdateLogDet``.
        A: (d, r) matrix forwarded to ``UpdateLogDet``.
        alpha_tol: final bracket width; must be positive.

    Returns:
        Midpoint of the final bracket.

    Raises:
        ValueError: if ``alpha_tol`` is not positive (or is NaN). Otherwise
            the loop would never terminate, or would silently return 0.5.
    """
    if not alpha_tol > 0:
        raise ValueError(f"alpha_tol must be positive. Got alpha_tol = {alpha_tol}")

    phi = (np.sqrt(5) + 1) / 2
    alpha_a, alpha_h, alpha_b = 0, 1, 1  # bracket start, width, end

    # Interior probes at the golden-ratio points of the bracket.
    alpha_c = alpha_a + alpha_h / phi**2
    alpha_d = alpha_a + alpha_h / phi

    V_c = UpdateLogDet(V_inv, A, alpha_c)
    V_d = UpdateLogDet(V_inv, A, alpha_d)

    while abs(alpha_a - alpha_b) >= alpha_tol:
        alpha_h /= phi
        if V_c > V_d:
            # Maximum lies in [a, d]: the old c becomes the new d.
            alpha_b = alpha_d
            alpha_d = alpha_c
            V_d = V_c
            alpha_c = alpha_a + alpha_h / phi**2
            V_c = UpdateLogDet(V_inv, A, alpha_c)
        else:
            # Maximum lies in [c, b]: the old d becomes the new c.
            alpha_a = alpha_c
            alpha_c = alpha_d
            V_c = V_d
            alpha_d = alpha_a + alpha_h / phi
            V_d = UpdateLogDet(V_inv, A, alpha_d)

    return (alpha_a + alpha_b) / 2
