import math
from tqdm import tqdm

import numpy as np
import torch

from main.utils.random import set_random_seed


def _budget_additive(x, budgets):
    n = x.numel()
    return (1/budgets) @ torch.min(x, budgets)


def _find_best_idx_parallel(current_sum: torch.Tensor,
                            X: torch.Tensor,
                            remaining_mask: torch.Tensor,
                            budgets,
                            base: float):
    idxs = torch.nonzero(remaining_mask, as_tuple=False).view(-1)
    candidates = X[idxs]
    sums = candidates + current_sum.unsqueeze(0)
    capped = torch.minimum(sums, budgets.unsqueeze(0) if budgets.ndim == 1 else budgets)
    scores = (capped / budgets).sum(dim=1)
    gains    = scores - base
    loc      = gains.argmax()
    best_idx = idxs[loc].item()
    best_gain= gains[loc].item()
    return best_idx, best_gain


def greedy_ba_selection(X: torch.Tensor,
                        k: int,
                        device: torch.device = None,
                        seed: int = None,
                        thresholds=None, eps=1e-3) -> np.ndarray:
    """Greedily pick ``k`` rows of ``X`` under a budget-additive objective.

    At every step the row that most increases
    ``sum_j min(s_j, budget_j) / budget_j`` is selected, where ``s`` is the
    running sum of the rows chosen so far. Each row is selected at most
    once.

    Args:
        X: 2-D tensor of shape ``(n, r)``; each row is a candidate.
        k: Number of rows to select. Must not exceed ``n``.
        device: Device on which to run. Defaults to ``X.device``.
        seed: If given, passed to ``set_random_seed`` before selection.
        thresholds: Per-column budgets (tensor or array-like). Required;
            the ``None`` default exists only for signature compatibility.
        eps: Lower bound applied to every budget so that the division by
            the budgets stays finite.

    Returns:
        Integer NumPy array with the selected row indices, in the order
        they were chosen.

    Raises:
        ValueError: If ``thresholds`` is ``None`` or ``k`` exceeds the
            number of rows in ``X``.
    """
    if thresholds is None:
        # Fail early with a clear message instead of an opaque TypeError
        # from torch.clamp further down.
        raise ValueError("thresholds must be provided (one budget per column of X).")
    if seed is not None:
        set_random_seed(seed)
    if device is None:
        device = X.device
    X = X.to(device)
    n, r = X.shape
    if k > n:
        raise ValueError("k cannot exceed the number of vectors.")

    # Convert first, then clamp, so the eps floor is applied in X's dtype
    # and array-like thresholds are accepted as well as tensors.
    budgets = torch.as_tensor(thresholds, dtype=X.dtype, device=device).clamp(min=eps)

    selected = []
    current_sum = torch.zeros(r, device=device, dtype=X.dtype)
    remaining = torch.ones(n, dtype=torch.bool, device=device)

    for _ in tqdm(range(k)):
        base = _budget_additive(current_sum, budgets).item()
        best_idx, _ = _find_best_idx_parallel(current_sum, X, remaining, budgets, base)
        selected.append(best_idx)
        current_sum = current_sum + X[best_idx]
        # Mask the chosen row so it cannot be picked again.
        remaining[best_idx] = False

    # Explicit integer dtype: an empty selection would otherwise come back
    # as a float64 array.
    return np.array(selected, dtype=int)
