# =============================================================================
# Random
# =============================================================================

import torch
from torch import Tensor

from acquisitions.base import BiLevelAcquisition


# -----------------------------------------------------------------------------
# Random Sampling
# -----------------------------------------------------------------------------

class RandomSampling(BiLevelAcquisition):
    """Acquisition that chooses the output and the candidate uniformly at random.

    Only the NaN pattern of ``candidates`` and ``mask_evaluated`` influence the
    choice; no other information is consulted.
    """

    def optimize_pool(
        self,
        candidates: Tensor,  # shape: [*batch_shape, d_in]
        mask_evaluated: Tensor,  # shape: [*batch_shape, d_out]
        decoupled: bool = False,
    ) -> tuple[Tensor, Tensor]:  # shape: [2], shape: [d_out]
        """Select one eligible candidate (and the output(s) to query) at random.

        An output index ``k`` is drawn uniformly from ``range(self.d_out)``.
        A candidate is eligible when none of its features is NaN and it has
        not yet been evaluated on output ``k``. One eligible candidate is then
        drawn uniformly at random.

        Args:
            candidates: Candidate inputs, shape ``[*batch_shape, d_in]``.
                Any candidate containing a NaN feature is treated as invalid.
            mask_evaluated: Boolean mask, shape ``[*batch_shape, d_out]``.
                It is ``True`` where an output has already been evaluated
                for a candidate.
            decoupled: If ``True``, only entry ``k`` of ``self.Y_mask`` is
                set to ``True``. Otherwise every entry of ``self.Y_mask`` is
                set to ``True``.

        Returns:
            A tuple ``(indices, Y_mask)``.

            ``indices`` is the position of the chosen candidate within
            ``batch_shape``, with one entry per batch dimension. It is an
            empty tensor if no candidate is eligible.

            ``Y_mask`` is the ``self.Y_mask`` tensor itself, not a copy.
        """
        # Draw the output index first (kept before randperm so the RNG
        # consumption order, and therefore seeded results, are unchanged).
        out_idx = torch.randint(self.d_out, (1,))
        k = int(out_idx)

        # Invalid candidates: any NaN feature ...
        mask = candidates.isnan().any(dim=-1)
        # ... or already evaluated on output k. Indexing with a Python int
        # yields exactly ``batch_shape``. A length-1 tensor index followed by
        # a bare squeeze() would also drop size-1 batch dimensions and break
        # this in-place OR.
        mask |= mask_evaluated[..., k]

        # Coordinates of every eligible candidate: [num_valid, len(batch_shape)].
        valid_idx = (~mask).nonzero(as_tuple=False)
        # Pick one row uniformly at random. When num_valid == 0 this selects
        # nothing, and ``indices`` comes back empty.
        pick = torch.randperm(valid_idx.size(0))[:1]
        indices = valid_idx[pick].view(-1)

        if decoupled:
            # Only the sampled output is queried.
            self.Y_mask[k] = True
        else:
            # Coupled evaluation queries every output.
            self.Y_mask.fill_(True)
        # NOTE(review): entries of self.Y_mask set by earlier calls are not
        # cleared here; presumably the base class resets it between
        # iterations — TODO confirm.
        return indices, self.Y_mask


