# =============================================================================
# Bi-Level Real-World Problems
# =============================================================================

from pathlib import Path
from typing import Literal

import torch
from torch import Tensor

from problems.base import BiLevelBaseProblem


# Utility functions -----------------------------------------------------------

def log1p(input: Tensor) -> Tensor:
    """Apply a sign-preserving ``log1p`` elementwise.

    Computes ``sign(x) * log(1 + |x|)``: identical to ``torch.log1p`` for
    non-negative inputs and mirrored for negative ones, so the sign of every
    element is kept and zero maps to zero.
    """
    magnitude = torch.log1p(torch.abs(input))
    return torch.sign(input) * magnitude


# -----------------------------------------------------------------------------
# High Entropy Alloy (Small)
# -----------------------------------------------------------------------------

class HighEntropyAlloySmall(BiLevelBaseProblem):
    """Bi-level high-entropy-alloy benchmark (small) served from a lookup table.

    Inputs are ``num_dims[0]`` upper-level followed by ``num_dims[1]``
    lower-level dimensions. ``evaluate_true`` only accepts points that are
    present in the candidate set and returns the matching entry of the
    precomputed ``outputs`` tensor loaded from ``outputs_path``.
    """

    num_dims = [3, 162]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    # One (low, high) pair per input dimension, upper-level dimensions first.
    bounds = [(15.0, 54.0)] * 3 + [(0.0, 1.0)] * 162
    candidates_path = Path("data/hea_small/candidates.pt")
    outputs_path = Path("data/hea_small/outputs.pt")

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = True,
        num_discretize: int | list[int] | None = None,
        descriptor: Literal["RDF"] | None = None,
    ) -> None:
        """Configure the problem and load the precomputed outputs.

        Args:
            noise_std: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            has_candidates: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            num_discretize: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            descriptor: ``None`` for the default 162-dimensional lower level, or
                ``"RDF"`` to use the 30-dimensional RDF lower level with its own
                bounds and candidate file.

        Raises:
            ValueError: If ``descriptor`` is neither ``None`` nor ``"RDF"``.
        """
        if descriptor not in (None, "RDF"):
            raise ValueError(
                f"Unknown descriptor {descriptor!r}; expected None or 'RDF'."
            )

        if descriptor == "RDF":
            # Assigned on the instance (before the base initializer runs) so the
            # class-level defaults stay untouched.
            self.num_dims = [3, 30]
            # Upper bound of each of the 30 lower-level RDF dimensions (lower
            # bound is always 0.0), listed five per row.
            rdf_upper = (
                7.8006, 4.3988, 4.1625, 2.8639, 2.0945,
                8.4259, 4.2831, 3.7125, 3.0685, 1.9239,
                8.3885, 4.2136, 3.6750, 3.0429, 1.9548,
                8.4789, 5.0934, 4.1625, 2.9534, 2.1750,
                8.1080, 4.3409, 3.5859, 3.0697, 1.9258,
                7.8006, 4.6303, 3.7924, 2.9534, 1.9736,
            )
            # NOTE(review): the upper-level range (15, 24) differs from the
            # default (15, 54) — confirm this is intended.
            self.bounds = [(15.0, 24.0)] * 3 + [(0.0, ub) for ub in rdf_upper]
            self.candidates_path = Path("data/hea_small/candidates_RDF.pt")
        super().__init__(noise_std, has_candidates, num_discretize)
        # NOTE: weights_only=False unpickles arbitrary objects (and can run
        # code); only point outputs_path at trusted files.
        self.outputs = torch.load(self.outputs_path, weights_only=False)

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Return the precomputed outputs for points taken from the candidate set.

        Each point is matched (via ``torch.isclose``) first against the
        upper-level parameters of every candidate group, then against the
        lower-level candidates of the matched group; the entry of
        ``self.outputs`` at the matched ``(upper, lower)`` index pair is
        returned. If a point matches several candidates, the first one is used.

        Raises:
            ValueError: If any point does not match a stored candidate.
        """
        d0 = self.num_dims[0]
        candidates = self.candidates  # indexed as [n_upper, n_lower, d_in]
        # `reshape` (unlike `view`) also accepts non-contiguous inputs; casting
        # to the candidates' dtype/device is needed because `torch.isclose`
        # requires both operands to share a dtype.
        X_flat = X.reshape(-1, self.d_in).to(candidates)

        # Rows containing NaN are skipped (presumably padding): the first fully
        # defined row of each group supplies that group's upper-level values.
        first_valid = (~candidates.isnan().any(dim=-1)).long().argmax(dim=1)
        group = torch.arange(candidates.size(0), device=candidates.device)
        X_upper = candidates[group, first_valid, :d0]  # [n_upper, d0]

        match_upper = torch.isclose(
            input=X_flat[:, :d0].unsqueeze(1),  # [N, 1, d0]
            other=X_upper.unsqueeze(0),         # [1, n_upper, d0]
        ).all(dim=-1)                           # [N, n_upper]
        has_upper = match_upper.any(dim=-1)
        if not has_upper.all():
            raise ValueError(
                f"{int((~has_upper).sum())} of {has_upper.numel()} point(s) do "
                "not match any upper-level candidate."
            )
        # argmax returns the index of the first maximum, i.e. the first match.
        idx0 = match_upper.long().argmax(dim=-1)  # [N]

        match_lower = torch.isclose(
            input=X_flat[:, d0:].unsqueeze(1),  # [N, 1, d1]
            other=candidates[idx0, :, d0:],     # [N, n_lower, d1]
        ).all(dim=-1)                           # [N, n_lower]
        has_lower = match_lower.any(dim=-1)
        if not has_lower.all():
            raise ValueError(
                f"{int((~has_lower).sum())} of {has_lower.numel()} point(s) do "
                "not match any lower-level candidate of their upper-level group."
            )
        idx1 = match_lower.long().argmax(dim=-1)  # [N]

        return self.outputs[idx0, idx1].reshape(*X.shape[:-1], self.d_out)



# -----------------------------------------------------------------------------
# High Entropy Alloy (Medium)
# -----------------------------------------------------------------------------

class HighEntropyAlloyMedium(BiLevelBaseProblem):
    """Bi-level high-entropy-alloy benchmark (medium) served from a lookup table.

    Inputs are ``num_dims[0]`` upper-level followed by ``num_dims[1]``
    lower-level dimensions. ``evaluate_true`` only accepts points that are
    present in the candidate set and returns the matching entry of the
    precomputed ``outputs`` tensor loaded from ``outputs_path``.
    """

    num_dims = [3, 162]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    # One (low, high) pair per input dimension, upper-level dimensions first.
    bounds = [(10.0, 54.0)] * 3 + [(0.0, 1.0)] * 162
    candidates_path = Path("data/hea_medium/candidates.pt")
    outputs_path = Path("data/hea_medium/outputs.pt")

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = True,
        num_discretize: int | list[int] | None = None,
        descriptor: Literal["RDF"] | None = None,
    ) -> None:
        """Configure the problem and load the precomputed outputs.

        Args:
            noise_std: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            has_candidates: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            num_discretize: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            descriptor: ``None`` for the default 162-dimensional lower level, or
                ``"RDF"`` to use the 30-dimensional RDF lower level with its own
                bounds and candidate file.

        Raises:
            ValueError: If ``descriptor`` is neither ``None`` nor ``"RDF"``.
        """
        if descriptor not in (None, "RDF"):
            raise ValueError(
                f"Unknown descriptor {descriptor!r}; expected None or 'RDF'."
            )

        if descriptor == "RDF":
            # Assigned on the instance (before the base initializer runs) so the
            # class-level defaults stay untouched.
            self.num_dims = [3, 30]
            # NOTE(review): the upper-level lower bound (15) differs from the
            # default (10) — confirm this is intended.
            self.bounds = [(15.0, 54.0)] * 3 + [(0.0, 34.0)] * 30
            self.candidates_path = Path("data/hea_medium/candidates_RDF.pt")
        super().__init__(noise_std, has_candidates, num_discretize)
        # NOTE: weights_only=False unpickles arbitrary objects (and can run
        # code); only point outputs_path at trusted files.
        self.outputs = torch.load(self.outputs_path, weights_only=False)

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Return the precomputed outputs for points taken from the candidate set.

        Each point is matched (via ``torch.isclose``) first against the
        upper-level parameters of every candidate group, then against the
        lower-level candidates of the matched group; the entry of
        ``self.outputs`` at the matched ``(upper, lower)`` index pair is
        returned. If a point matches several candidates, the first one is used.

        Raises:
            ValueError: If any point does not match a stored candidate.
        """
        d0 = self.num_dims[0]
        candidates = self.candidates  # indexed as [n_upper, n_lower, d_in]
        # `reshape` (unlike `view`) also accepts non-contiguous inputs; casting
        # to the candidates' dtype/device is needed because `torch.isclose`
        # requires both operands to share a dtype.
        X_flat = X.reshape(-1, self.d_in).to(candidates)

        # Rows containing NaN are skipped (presumably padding): the first fully
        # defined row of each group supplies that group's upper-level values.
        first_valid = (~candidates.isnan().any(dim=-1)).long().argmax(dim=1)
        group = torch.arange(candidates.size(0), device=candidates.device)
        X_upper = candidates[group, first_valid, :d0]  # [n_upper, d0]

        match_upper = torch.isclose(
            input=X_flat[:, :d0].unsqueeze(1),  # [N, 1, d0]
            other=X_upper.unsqueeze(0),         # [1, n_upper, d0]
        ).all(dim=-1)                           # [N, n_upper]
        has_upper = match_upper.any(dim=-1)
        if not has_upper.all():
            raise ValueError(
                f"{int((~has_upper).sum())} of {has_upper.numel()} point(s) do "
                "not match any upper-level candidate."
            )
        # argmax returns the index of the first maximum, i.e. the first match.
        idx0 = match_upper.long().argmax(dim=-1)  # [N]

        match_lower = torch.isclose(
            input=X_flat[:, d0:].unsqueeze(1),  # [N, 1, d1]
            other=candidates[idx0, :, d0:],     # [N, n_lower, d1]
        ).all(dim=-1)                           # [N, n_lower]
        has_lower = match_lower.any(dim=-1)
        if not has_lower.all():
            raise ValueError(
                f"{int((~has_lower).sum())} of {has_lower.numel()} point(s) do "
                "not match any lower-level candidate of their upper-level group."
            )
        idx1 = match_lower.long().argmax(dim=-1)  # [N]

        return self.outputs[idx0, idx1].reshape(*X.shape[:-1], self.d_out)



# -----------------------------------------------------------------------------
# High Entropy Alloy (Large)
# -----------------------------------------------------------------------------

class HighEntropyAlloyLarge(BiLevelBaseProblem):
    """Bi-level high-entropy-alloy benchmark (large) served from a lookup table.

    Inputs are ``num_dims[0]`` upper-level followed by ``num_dims[1]``
    lower-level dimensions. ``evaluate_true`` only accepts points that are
    present in the candidate set and returns the matching entry of the
    precomputed ``outputs`` tensor loaded from ``outputs_path``.
    """

    num_dims = [3, 162]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    # One (low, high) pair per input dimension, upper-level dimensions first.
    bounds = [(10.0, 54.0)] * 3 + [(0.0, 1.0)] * 162
    candidates_path = Path("data/hea_large/candidates.pt")
    outputs_path = Path("data/hea_large/outputs.pt")

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = True,
        num_discretize: int | list[int] | None = None,
        descriptor: Literal["RDF"] | None = None,
    ) -> None:
        """Configure the problem and load the precomputed outputs.

        Args:
            noise_std: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            has_candidates: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            num_discretize: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            descriptor: ``None`` for the default 162-dimensional lower level, or
                ``"RDF"`` to use the 30-dimensional RDF lower level with its own
                bounds and candidate file.

        Raises:
            ValueError: If ``descriptor`` is neither ``None`` nor ``"RDF"``.
        """
        if descriptor not in (None, "RDF"):
            raise ValueError(
                f"Unknown descriptor {descriptor!r}; expected None or 'RDF'."
            )

        if descriptor == "RDF":
            # Assigned on the instance (before the base initializer runs) so the
            # class-level defaults stay untouched.
            self.num_dims = [3, 30]
            # NOTE(review): the upper-level lower bound (15) differs from the
            # default (10) — confirm this is intended.
            self.bounds = [(15.0, 54.0)] * 3 + [(0.0, 72.0)] * 30
            self.candidates_path = Path("data/hea_large/candidates_RDF.pt")
        super().__init__(noise_std, has_candidates, num_discretize)
        # NOTE: weights_only=False unpickles arbitrary objects (and can run
        # code); only point outputs_path at trusted files.
        self.outputs = torch.load(self.outputs_path, weights_only=False)

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Return the precomputed outputs for points taken from the candidate set.

        Each point is matched (via ``torch.isclose``) first against the
        upper-level parameters of every candidate group, then against the
        lower-level candidates of the matched group; the entry of
        ``self.outputs`` at the matched ``(upper, lower)`` index pair is
        returned. If a point matches several candidates, the first one is used.

        Raises:
            ValueError: If any point does not match a stored candidate.
        """
        d0 = self.num_dims[0]
        candidates = self.candidates  # indexed as [n_upper, n_lower, d_in]
        # `reshape` (unlike `view`) also accepts non-contiguous inputs; casting
        # to the candidates' dtype/device is needed because `torch.isclose`
        # requires both operands to share a dtype.
        X_flat = X.reshape(-1, self.d_in).to(candidates)

        # Rows containing NaN are skipped (presumably padding): the first fully
        # defined row of each group supplies that group's upper-level values.
        first_valid = (~candidates.isnan().any(dim=-1)).long().argmax(dim=1)
        group = torch.arange(candidates.size(0), device=candidates.device)
        X_upper = candidates[group, first_valid, :d0]  # [n_upper, d0]

        match_upper = torch.isclose(
            input=X_flat[:, :d0].unsqueeze(1),  # [N, 1, d0]
            other=X_upper.unsqueeze(0),         # [1, n_upper, d0]
        ).all(dim=-1)                           # [N, n_upper]
        has_upper = match_upper.any(dim=-1)
        if not has_upper.all():
            raise ValueError(
                f"{int((~has_upper).sum())} of {has_upper.numel()} point(s) do "
                "not match any upper-level candidate."
            )
        # argmax returns the index of the first maximum, i.e. the first match.
        idx0 = match_upper.long().argmax(dim=-1)  # [N]

        match_lower = torch.isclose(
            input=X_flat[:, d0:].unsqueeze(1),  # [N, 1, d1]
            other=candidates[idx0, :, d0:],     # [N, n_lower, d1]
        ).all(dim=-1)                           # [N, n_lower]
        has_lower = match_lower.any(dim=-1)
        if not has_lower.all():
            raise ValueError(
                f"{int((~has_lower).sum())} of {has_lower.numel()} point(s) do "
                "not match any lower-level candidate of their upper-level group."
            )
        idx1 = match_lower.long().argmax(dim=-1)  # [N]

        return self.outputs[idx0, idx1].reshape(*X.shape[:-1], self.d_out)


# -----------------------------------------------------------------------------
# Energy market
# -----------------------------------------------------------------------------

class EnergyMarket(BiLevelBaseProblem):
    """Bi-level energy-market benchmark served from a lookup table.

    Inputs are ``num_dims[0]`` upper-level followed by ``num_dims[1]``
    lower-level dimensions. ``evaluate_true`` only accepts points that are
    present in the candidate set and returns the matching entry of the
    precomputed ``outputs`` tensor loaded from ``outputs_path``.
    """

    num_dims = [2, 2]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    # One (low, high) pair per input dimension, upper-level dimensions first.
    bounds = [(0.01, 0.5), (200, 500), (0.0, 0.2), (0.5, 1.5)]
    candidates_path = Path("data/energy/candidates.pt")
    outputs_path = Path("data/energy/outputs.pt")

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = True,
        num_discretize: int | list[int] | None = None,
        log_transform: bool = False,
    ) -> None:
        """Configure the problem and load the precomputed outputs.

        Args:
            noise_std: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            has_candidates: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            num_discretize: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            log_transform: If True, the loaded outputs are replaced by
                ``sign(y) * log(1 + |y|)`` via the module-level ``log1p``.
        """
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        # NOTE: weights_only=False unpickles arbitrary objects (and can run
        # code); only point outputs_path at trusted files.
        self.outputs = torch.load(self.outputs_path, weights_only=False)
        if self.log_transform:
            self.outputs = log1p(self.outputs)

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Return the precomputed outputs for points taken from the candidate set.

        Each point is matched (via ``torch.isclose``) first against the
        upper-level parameters of every candidate group, then against the
        lower-level candidates of the matched group; the entry of
        ``self.outputs`` at the matched ``(upper, lower)`` index pair is
        returned. If a point matches several candidates, the first one is used.

        Raises:
            ValueError: If any point does not match a stored candidate.
        """
        d0 = self.num_dims[0]
        candidates = self.candidates  # indexed as [n_upper, n_lower, d_in]
        # `reshape` (unlike `view`) also accepts non-contiguous inputs; casting
        # to the candidates' dtype/device is needed because `torch.isclose`
        # requires both operands to share a dtype.
        X_flat = X.reshape(-1, self.d_in).to(candidates)

        # Rows containing NaN are skipped (presumably padding): the first fully
        # defined row of each group supplies that group's upper-level values.
        first_valid = (~candidates.isnan().any(dim=-1)).long().argmax(dim=1)
        group = torch.arange(candidates.size(0), device=candidates.device)
        X_upper = candidates[group, first_valid, :d0]  # [n_upper, d0]

        match_upper = torch.isclose(
            input=X_flat[:, :d0].unsqueeze(1),  # [N, 1, d0]
            other=X_upper.unsqueeze(0),         # [1, n_upper, d0]
        ).all(dim=-1)                           # [N, n_upper]
        has_upper = match_upper.any(dim=-1)
        if not has_upper.all():
            raise ValueError(
                f"{int((~has_upper).sum())} of {has_upper.numel()} point(s) do "
                "not match any upper-level candidate."
            )
        # argmax returns the index of the first maximum, i.e. the first match.
        idx0 = match_upper.long().argmax(dim=-1)  # [N]

        match_lower = torch.isclose(
            input=X_flat[:, d0:].unsqueeze(1),  # [N, 1, d1]
            other=candidates[idx0, :, d0:],     # [N, n_lower, d1]
        ).all(dim=-1)                           # [N, n_lower]
        has_lower = match_lower.any(dim=-1)
        if not has_lower.all():
            raise ValueError(
                f"{int((~has_lower).sum())} of {has_lower.numel()} point(s) do "
                "not match any lower-level candidate of their upper-level group."
            )
        idx1 = match_lower.long().argmax(dim=-1)  # [N]

        return self.outputs[idx0, idx1].reshape(*X.shape[:-1], self.d_out)


# -----------------------------------------------------------------------------
# Chemistry
# -----------------------------------------------------------------------------

class Chemistry(BiLevelBaseProblem):
    """Bi-level chemistry benchmark served from a lookup table.

    Inputs are ``num_dims[0]`` upper-level followed by ``num_dims[1]``
    lower-level dimensions. ``evaluate_true`` only accepts points that are
    present in the candidate set and returns the matching entry of the
    precomputed ``outputs`` tensor loaded from ``outputs_path``.
    """

    num_dims = [1, 3]
    num_objectives = [1, 1]
    num_constraints = [1, 0]
    # One (low, high) pair per input dimension, upper-level dimensions first.
    bounds = [(0.0, 0.1), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)]
    candidates_path = Path("data/chem/candidates.pt")
    outputs_path = Path("data/chem/outputs.pt")

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = True,
        num_discretize: int | list[int] | None = None,
        log_transform: bool = False,
    ) -> None:
        """Configure the problem and load the precomputed outputs.

        Args:
            noise_std: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            has_candidates: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            num_discretize: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            log_transform: If True, the loaded outputs are replaced by
                ``sign(y) * log(1 + |y|)`` via the module-level ``log1p``.
        """
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        # NOTE: weights_only=False unpickles arbitrary objects (and can run
        # code); only point outputs_path at trusted files.
        self.outputs = torch.load(self.outputs_path, weights_only=False)
        if self.log_transform:
            self.outputs = log1p(self.outputs)

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Return the precomputed outputs for points taken from the candidate set.

        Each point is matched (via ``torch.isclose``) first against the
        upper-level parameters of every candidate group, then against the
        lower-level candidates of the matched group; the entry of
        ``self.outputs`` at the matched ``(upper, lower)`` index pair is
        returned. If a point matches several candidates, the first one is used.

        Raises:
            ValueError: If any point does not match a stored candidate.
        """
        d0 = self.num_dims[0]
        candidates = self.candidates  # indexed as [n_upper, n_lower, d_in]
        # `reshape` (unlike `view`) also accepts non-contiguous inputs; casting
        # to the candidates' dtype/device is needed because `torch.isclose`
        # requires both operands to share a dtype.
        X_flat = X.reshape(-1, self.d_in).to(candidates)

        # Rows containing NaN are skipped (presumably padding): the first fully
        # defined row of each group supplies that group's upper-level values.
        first_valid = (~candidates.isnan().any(dim=-1)).long().argmax(dim=1)
        group = torch.arange(candidates.size(0), device=candidates.device)
        X_upper = candidates[group, first_valid, :d0]  # [n_upper, d0]

        match_upper = torch.isclose(
            input=X_flat[:, :d0].unsqueeze(1),  # [N, 1, d0]
            other=X_upper.unsqueeze(0),         # [1, n_upper, d0]
        ).all(dim=-1)                           # [N, n_upper]
        has_upper = match_upper.any(dim=-1)
        if not has_upper.all():
            raise ValueError(
                f"{int((~has_upper).sum())} of {has_upper.numel()} point(s) do "
                "not match any upper-level candidate."
            )
        # argmax returns the index of the first maximum, i.e. the first match.
        idx0 = match_upper.long().argmax(dim=-1)  # [N]

        match_lower = torch.isclose(
            input=X_flat[:, d0:].unsqueeze(1),  # [N, 1, d1]
            other=candidates[idx0, :, d0:],     # [N, n_lower, d1]
        ).all(dim=-1)                           # [N, n_lower]
        has_lower = match_lower.any(dim=-1)
        if not has_lower.all():
            raise ValueError(
                f"{int((~has_lower).sum())} of {has_lower.numel()} point(s) do "
                "not match any lower-level candidate of their upper-level group."
            )
        idx1 = match_lower.long().argmax(dim=-1)  # [N]

        return self.outputs[idx0, idx1].reshape(*X.shape[:-1], self.d_out)


