# =============================================================================
# Base Acquisitions
# =============================================================================

from abc import ABC, abstractmethod

import torch
from torch import Tensor
from botorch.utils.transforms import unnormalize

from utils import RFFModelList


# -----------------------------------------------------------------------------
# Bi-Level Acquisition
# -----------------------------------------------------------------------------

class BiLevelAcquisition(ABC):
    """Abstract base for acquisitions over paired "upper" and "lower" models.

    The constructor stores the supplied model lists and derives the
    bookkeeping sizes used by subclasses:

    * ``num_objectives``  -- ``[len(upper Y models), len(lower Y models)]``
    * ``num_constraints`` -- ``[len(upper C models), len(lower C models)]``,
      with ``0`` for any constraint model list that is ``None``
    * ``d_in``  -- ``sum(num_dims)``
    * ``d_out`` -- total number of objective plus constraint models
    * ``Y_mask`` -- boolean tensor of length ``d_out``, initially all False

    Subclasses must implement :meth:`optimize_pool`.  The default
    :meth:`optimize_query` is a uniformly random baseline.
    """

    def __init__(
        self,
        num_dims: list[int],
        model_Y_upper: RFFModelList,
        model_Y_lower: RFFModelList,
        model_C_upper: RFFModelList | None = None,
        model_C_lower: RFFModelList | None = None,
        num_restarts: int | None = None,
        raw_samples: int | None = None,
    ) -> None:
        """Store the models and derive input/output dimensionalities.

        Args:
            num_dims: Per-group input dimensions; their sum becomes ``d_in``.
                NOTE(review): presumably ``[upper_dim, lower_dim]`` -- confirm
                against callers.
            model_Y_upper: Objective model list; must expose ``rff_models``.
            model_Y_lower: Objective model list; must expose ``rff_models``.
            model_C_upper: Optional constraint model list (``rff_models``).
            model_C_lower: Optional constraint model list (``rff_models``).
            num_restarts: Stored as-is; not used by this base class.
            raw_samples: Stored as-is; not used by this base class.
        """

        self.num_dims = num_dims
        self.model_Y_upper = model_Y_upper
        self.model_Y_lower = model_Y_lower
        self.model_C_upper = model_C_upper
        self.model_C_lower = model_C_lower
        self.num_restarts = num_restarts
        self.raw_samples = raw_samples

        # Output counts are ordered [upper, lower].
        self.num_objectives = [
            len(model_Y_upper.rff_models),
            len(model_Y_lower.rff_models),
        ]
        # Constraint models are optional; a missing list contributes zero.
        self.num_constraints = [0, 0]
        if model_C_upper is not None:
            self.num_constraints[0] = len(model_C_upper.rff_models)
        if model_C_lower is not None:
            self.num_constraints[1] = len(model_C_lower.rff_models)

        self.d_in = sum(num_dims)
        self.d_out = sum(self.num_objectives + self.num_constraints)
        # One flag per output; True marks an output selected for evaluation.
        self.Y_mask = torch.zeros(self.d_out, dtype=torch.bool)


    @abstractmethod
    def optimize_pool(
        self,
        candidates: Tensor,  # shape: [*batch_shape, d_in]
        mask_evaluated: Tensor,  # shape: [*batch_shape, d_out]
        decoupled: bool = False,
    ) -> tuple[Tensor, Tensor]:  # shape: [2], shape: [d_out]
        """Select from a pool of candidates; must be provided by subclasses."""

        raise NotImplementedError


    def optimize_query(
        self,
        bounds: Tensor,  # shape: [2, d_in]
        decoupled: bool = False,
    ) -> tuple[Tensor, Tensor]:  # shape: [1, d_in], shape: [d_out]
        """Draw a uniformly random query point and an output mask.

        Args:
            bounds: Lower bounds in row 0 and upper bounds in row 1.
            decoupled: If True, exactly one randomly chosen output is
                selected; otherwise every output is selected.

        Returns:
            ``(next_X, Y_mask)``: the sampled point with shape
            ``[1, bounds.size(-1)]`` and a boolean mask of length ``d_out``.
            ``self.Y_mask`` is updated in place to the same values, while the
            returned mask is an independent copy.
        """

        # Sample on the same device as ``bounds`` so the rescaling below does
        # not mix devices.
        next_X = torch.rand(1, bounds.size(-1), device=bounds.device)
        next_X = unnormalize(next_X, bounds=bounds)

        if decoupled:
            # Clear selections left over from earlier calls; otherwise flags
            # accumulate and the mask stops describing a single output.
            self.Y_mask.fill_(False)
            self.Y_mask[torch.randint(self.d_out, (1,))] = True
        else:
            self.Y_mask.fill_(True)

        # Hand back a copy so masks the caller keeps are not overwritten by
        # the next call mutating ``self.Y_mask``.
        return next_X, self.Y_mask.clone()

