# =============================================================================
# Base Problems
# =============================================================================

import math
from pathlib import Path
from abc import ABC, abstractmethod

import torch
from torch import Tensor



# -----------------------------------------------------------------------------
# BiLevelBaseProblem
# -----------------------------------------------------------------------------

class BiLevelBaseProblem(ABC):
    """Abstract base class for bilevel optimization problems."""

    num_dims: list[int]
    num_objectives: list[int]
    num_constraints: list[int]
    bounds: list[tuple[float, float]]
    candidates_path: Path

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
    ) -> None:

        self.noise_std = noise_std
        self.has_candidates = has_candidates
        self.num_discretize = num_discretize

        self.d_in = sum(self.num_dims)
        self.d_out = sum(self.num_objectives) + sum(self.num_constraints)
        if self.has_candidates:
            if self.candidates_path is not None:
                self.candidates: Tensor = torch.load(
                    self.candidates_path, weights_only=False,
                )  # shape: [*batch_shape, d_in]
            elif isinstance(num_discretize, int):
                batch_shape = (num_discretize ** self.num_dims[0],
                               num_discretize ** self.num_dims[1])
                self.candidates = torch.stack(torch.meshgrid(
                    [torch.linspace(*bounds, num_discretize)
                     for bounds in self.bounds],
                    indexing="ij",
                ), dim=-1).view(*batch_shape, self.d_in)
            elif isinstance(num_discretize, list):
                batch_shape = (math.prod(num_discretize[:self.num_dims[0]]),
                               math.prod(num_discretize[self.num_dims[0]:]))
                self.candidates = torch.stack(torch.meshgrid(
                    [torch.linspace(*bounds, disc)
                     for bounds, disc in zip(self.bounds, num_discretize)],
                    indexing="ij",
                ), dim=-1).view(*batch_shape, self.d_in)
            self.mask_evaluated: Tensor = torch.zeros(
                *self.candidates.shape[:-1], self.d_out, dtype=torch.bool,
            )  # shape: [*batch_shape, d_out]


    def __call__(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
        noise: bool = True,
    ) -> Tensor:  # shape: [*batch_shape, d_out]

        F = self.evaluate_true(X=X)
        C = self.evaluate_slack(X=X)
        outputs = torch.cat([F, C], dim=-1)
        if noise and self.noise_std is not None:
            _noise = torch.tensor(self.noise_std, dtype=X.dtype)
            outputs += _noise * torch.randn_like(outputs)
        return outputs
            

    @abstractmethod
    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]

        raise NotImplementedError


    def evaluate_slack(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_constraints)]

        return torch.empty(*X.shape[:-1], sum(self.num_constraints))

