from tqdm import tqdm

import numpy as np
import torch

from main.utils.random import set_random_seed


def _yager_weights(n_val: int, beta: float) -> torch.Tensor:
    """
    Generate Yager weights for OWA operator.
    Args:
        n_val (int): Number of validation data.
        beta (float): Parameter that controls the shape of the weights. beta = 0 gives equal weights, beta < 0 gives higher weight to lower values, and beta > 0 gives higher weight to higher values.
    """
    exponents = -beta * torch.linspace(0, 1, n_val)
    raw_weights = torch.exp(exponents)
    weights = raw_weights / raw_weights.sum()
    return weights


@torch.no_grad()
def _find_topm_idx_owa_parallel(current_sum: torch.Tensor,
                                X: torch.Tensor,
                                remaining_mask: torch.Tensor,
                                w: torch.Tensor,
                                m: int,
                                chunk_size: int=1024):
    """
    Find the top-m data conditioned on the current sum using OWA operator in parallel.
    Args:
        current_sum (torch.Tensor): Current sum for all validation data.
        X (torch.Tensor): Matrix of per-sample Shapley values.
        remaining_mask (torch.Tensor): Mask indicating which training data are still available.
        w (torch.Tensor): Weights for the OWA operator.
        m (int): Number of top data to consider.
        chunk_size (int): Size of the chunks to process in parallel.
    """
    idxs = torch.nonzero(remaining_mask, as_tuple=False).view(-1)
    if idxs.numel() == 0:
        return torch.empty(0, dtype=torch.long, device=X.device), torch.empty(0, dtype=X.dtype, device=X.device)
    
    m = min(m, idxs.numel())
    sorted_current_sum, _ = torch.sort(current_sum, descending=True)
    base_owa = sorted_current_sum @ w

    top_idxs = torch.empty(0, dtype=torch.long, device=X.device)
    top_gains = torch.empty(0, dtype=X.dtype, device=X.device)

    for i in range(0, len(idxs), chunk_size):
        batch_idxs = idxs[i:i+chunk_size] # Process `chunk_size` indices at a time
        candidates = X[batch_idxs]
        sums = candidates + current_sum.unsqueeze(0)
        sorted_sums, _ = torch.sort(sums, dim=1, descending=True)
        owas = sorted_sums @ w
        gains = owas - base_owa

        b = min(m, gains.numel())
        batch_top_gains, batch_top_pos = torch.topk(gains, k=b, largest=True)
        batch_top_idxs = batch_idxs[batch_top_pos]

        if top_gains.numel() == 0:
            top_gains = batch_top_gains
            top_idxs = batch_top_idxs
        else:
            all_gains = torch.cat([top_gains, batch_top_gains], dim=0)
            all_idxs  = torch.cat([top_idxs,  batch_top_idxs ], dim=0)
            keep = min(m, all_gains.numel())
            new_gains, new_pos = torch.topk(all_gains, k=keep, largest=True)
            new_idxs = all_idxs[new_pos]
            top_gains, top_idxs = new_gains, new_idxs

    return top_idxs, top_gains


def greedy_owa_selection(X: torch.Tensor,
                         k: int,
                         beta: float=None,
                         device: torch.device=None,
                         chunk_size=1024,
                         seed=None,
                         weights=None,
                         thresholds: torch.Tensor=None) -> np.ndarray:
    """
    Select k data with the deterministic greedy algorithm on the OWA objective.

    This is ``random_greedy_owa_selection`` with ``top_m=1``. At every step the
    single remaining candidate with the largest OWA gain is taken, so the
    "random" choice among candidates is trivial. All other arguments are
    forwarded unchanged; see ``random_greedy_owa_selection`` for their meaning.

    Returns:
        np.ndarray: Indices of the selected data, in selection order.
    """
    return random_greedy_owa_selection(X, k,
                                       beta=beta,
                                       top_m=1,
                                       device=device,
                                       chunk_size=chunk_size,
                                       seed=seed,
                                       weights=weights,
                                       thresholds=thresholds)

 
def random_greedy_owa_selection(X: torch.Tensor,
                                k: int,
                                beta: float=None,
                                top_m: int=None,
                                device: torch.device=None,
                                chunk_size=1024,
                                seed=None,
                                weights=None,
                                thresholds: torch.Tensor=None) -> np.ndarray:
    """
    Select k data using the random greedy algorithm on an OWA objective.

    At each step the ``top_m`` remaining rows of X with the largest OWA gain
    (OWA of ``current_sum + X[i]`` minus OWA of ``current_sum``) are found. One
    of them is chosen uniformly at random, added to the running sum, and
    removed from the pool. With ``top_m=1`` this is plain greedy selection.
    Args:
        X (torch.Tensor): Matrix of per-sample Shapley values of shape (n, r): n training data, r validation data.
        k (int): Number of data to select. Must not exceed n.
        beta (float): Parameter that controls the shape of the weights for OWA operator. beta = 0 gives equal weights, beta < 0 gives higher weight to lower values, and beta > 0 gives higher weight to higher values. Ignored when `weights` is given.
        top_m (int, optional): Number of top data to sample from in each iteration. If None, it will be set to k. Must be >= 1 when k > 0.
        device (torch.device, optional): Device to perform computations on. Defaults to None, which uses the device of X.
        chunk_size (int): Number of candidate rows scored per batch.
        seed (int, optional): Random seed passed to `set_random_seed` for reproducibility.
        weights (np.ndarray | torch.Tensor | sequence, optional): Precomputed 1-D OWA weights of length r. They are applied to values sorted in descending order. If None, beta must be provided.
        thresholds (torch.Tensor | np.ndarray | float, optional): Per-validation-point thresholds. The running sum starts at ``-thresholds``. A scalar is broadcast to all r entries.
    Returns:
        np.ndarray: int64 indices of the selected rows of X, in selection order.
    Raises:
        ValueError: If k > n, if top_m < 1 while k > 0, if neither beta nor
            weights is given, or if the weights do not have length r.
    """
    if device is None:
        device = X.device

    X = X.to(device)
    n, r = X.shape # n: number of training data, r: number of validation data
    if k > n:
        raise ValueError("k cannot exceed the number of vectors.")
    if top_m is None:
        top_m = k
    if k > 0 and top_m < 1:
        # A non-positive top_m yields no candidates and would silently return
        # an empty selection.
        raise ValueError("top_m must be a positive integer.")
    if seed is not None:
        set_random_seed(seed)

    if weights is None:
        if beta is None:
            raise ValueError("Either beta or weights must be provided.")
        w = _yager_weights(r, beta).to(dtype=X.dtype, device=device)
    elif isinstance(weights, torch.Tensor):
        w = weights.detach().to(dtype=X.dtype, device=device)
    else:
        # np.array copies, so w never shares memory with the caller's buffer.
        w = torch.from_numpy(np.array(weights, dtype=np.float64)).to(dtype=X.dtype, device=device)
    if w.ndim != 1 or w.numel() != r:
        raise ValueError(f"weights must be a 1-D array of length {r}, got shape {tuple(w.shape)}.")

    selected = []
    if thresholds is None:
        current_sum = torch.zeros(r, device=device, dtype=X.dtype)
    else:
        # Accept tensors, numpy arrays and scalars; negation yields a fresh tensor.
        thr = torch.as_tensor(thresholds, dtype=X.dtype, device=device)
        current_sum = -thr.broadcast_to((r,))
    remaining = torch.ones(n, dtype=torch.bool, device=device)

    for _ in tqdm(range(k)):
        # The helper caps m at the number of remaining rows, so top_m needs no
        # per-iteration shrinking (which would force a device sync via .item()).
        cand_idxs, _cand_gains = _find_topm_idx_owa_parallel(
            current_sum=current_sum,
            X=X,
            remaining_mask=remaining,
            w=w,
            m=top_m,
            chunk_size=chunk_size
        )
        if cand_idxs.numel() == 0:
            break

        # Uniform choice among the top candidates (deterministic when only one).
        choice_pos = torch.randint(cand_idxs.numel(), (1,), device=cand_idxs.device).item()
        best_idx = cand_idxs[choice_pos].item()

        selected.append(best_idx)
        current_sum = current_sum + X[best_idx]
        remaining[best_idx] = False

    # Explicit dtype: np.array([]) would otherwise be float64 for an empty selection.
    return np.array(selected, dtype=np.int64)

