import math
from tqdm import tqdm

import numpy as np
import torch

from main.utils.random import set_random_seed


def _lower_tail_cvar(x: torch.Tensor, alpha: float) -> torch.Tensor:
    """Return the lower-tail CVaR of the entries of ``x``.

    The tail consists of every entry that is less than or equal to the
    ``ceil(alpha * n)``-th smallest entry, where ``n`` is the total number
    of entries. Entries tied with that threshold are all included, so the
    tail may hold more than ``ceil(alpha * n)`` values.

    Args:
        x: Tensor of values. Any shape is accepted; it is treated as a flat
            collection of ``x.numel()`` values.
        alpha: Tail fraction. ``ceil(alpha * n)`` must land in ``[1, n]``,
            otherwise ``kthvalue`` rejects the rank.

    Returns:
        A 0-dim tensor holding the mean of the tail entries.
    """
    # The rank ``k`` is derived from the total element count, so the order
    # statistic has to be taken over all elements too. Without flattening,
    # ``kthvalue`` would work along the last dimension only and the
    # threshold would be a per-row vector instead of a single value.
    flat = x.reshape(-1)
    n = flat.numel()
    k = math.ceil(alpha * n)
    thresh = flat.kthvalue(k).values
    tail = flat[flat <= thresh]
    return tail.mean()


def _find_best_idx_parallel(current_sum: torch.Tensor,
                            X: torch.Tensor,
                            remaining_mask: torch.Tensor,
                            alpha: float,
                            base_cvar: float):
    """Pick the remaining row of ``X`` that maximizes the lower-tail CVaR gain.

    For every row still flagged in ``remaining_mask`` the row is added to
    ``current_sum``; the lower-tail CVaR of each resulting vector is computed
    in one batched pass, and the row with the largest value is returned.

    Args:
        current_sum: ``(r,)`` running sum of the rows selected so far.
        X: ``(n, r)`` matrix of candidate rows.
        remaining_mask: ``(n,)`` boolean mask, ``True`` for rows that are
            still selectable.
        alpha: Tail fraction; ``ceil(alpha * r)`` must land in ``[1, r]``.
        base_cvar: CVaR before adding a row. It is only subtracted to report
            the gain; being a constant it does not affect which row wins.

    Returns:
        ``(best_idx, best_gain)``: the index into ``X`` of the winning row
        (Python int) and its CVaR gain over ``base_cvar`` (Python float).

    Raises:
        ValueError: If ``remaining_mask`` has no ``True`` entry.
    """
    idxs = torch.nonzero(remaining_mask, as_tuple=False).view(-1)
    if idxs.numel() == 0:
        raise ValueError("No remaining candidates to choose from.")

    sums = X[idxs] + current_sum.unsqueeze(0)
    k = math.ceil(alpha * sums.size(1))
    thresh, _ = sums.kthvalue(k, dim=1)
    # Ties with the k-th smallest value are all counted as part of the tail.
    mask = sums <= thresh.unsqueeze(1)
    # Zero out the non-tail entries with masked_fill rather than multiplying
    # by the mask: ``inf * 0`` is ``nan``, so a single infinite entry outside
    # the tail would otherwise turn the whole row's CVaR into ``nan``.
    tail_sum = sums.masked_fill(~mask, 0).sum(dim=1)
    tail_cnt = mask.sum(dim=1).to(sums.dtype)
    cvars = tail_sum / tail_cnt
    gains = cvars - base_cvar
    loc = gains.argmax()
    best_idx = idxs[loc].item()
    best_gain = gains[loc].item()
    return best_idx, best_gain


def greedy_cvar_selection(X: torch.Tensor,
                          k: int,
                          alpha: float,
                          device: torch.device = None,
                          seed: int = None,
                          thresholds=None) -> np.ndarray:
    """
    Greedy CVaR-based subset selection parallelized on GPU/CPU.

    Starting from an initial running sum, each of the ``k`` rounds evaluates
    every not-yet-selected row of ``X`` in one batched pass and picks the row
    whose addition to the running sum gives the largest lower-tail CVaR.

    X: (n, r) tensor
    k: number of vectors to select
    alpha: CVaR tail fraction, must satisfy 0 < alpha <= 1
    device: torch device (e.g. 'cuda' or 'cpu'); defaults to X.device
    seed: if given, passed to set_random_seed before selection starts
    thresholds: optional (r,) tensor or array-like; when given, the running
        sum starts at ``-thresholds`` instead of zeros
    Returns: (k,) NumPy int64 array of selected row indices, in the order
        they were picked
    Raises: ValueError if k exceeds n or alpha is outside (0, 1]
    """
    if seed is not None:
        set_random_seed(seed)
    if device is None:
        device = X.device
    X = X.to(device)
    n, r = X.shape
    if k > n:
        raise ValueError("k cannot exceed the number of vectors.")
    # Written as a negated chain so that NaN is rejected as well. Outside
    # this range the tail rank ceil(alpha * r) falls outside [1, r] and the
    # helpers' kthvalue calls cannot be satisfied.
    if not 0 < alpha <= 1:
        raise ValueError("alpha must satisfy 0 < alpha <= 1.")

    selected = []
    if thresholds is None:
        current_sum = torch.zeros(r, device=device, dtype=X.dtype)
    else:
        # as_tensor lets callers pass NumPy arrays or lists, not only tensors.
        current_sum = (-torch.as_tensor(thresholds)).to(device=device, dtype=X.dtype)
    remaining = torch.ones(n, dtype=torch.bool, device=device)

    for _ in tqdm(range(k)):
        base_cvar = _lower_tail_cvar(current_sum, alpha).item()
        best_idx, _ = _find_best_idx_parallel(current_sum, X, remaining, alpha, base_cvar)
        selected.append(best_idx)
        current_sum = current_sum + X[best_idx]
        remaining[best_idx] = False

    # Pin the dtype: an empty selection would otherwise come back as a float
    # array, which cannot be used for indexing.
    return np.array(selected, dtype=np.int64)
