# =============================================================================
# Joint Entropy Search
# =============================================================================

import torch
from torch import Tensor
from torch.optim import SGD, LBFGS
from botorch.optim.core import scipy_minimize
from botorch.utils.transforms import unnormalize

from acquisitions.base import BiLevelAcquisition
from utils import RFFModelList



# Utility functions -----------------------------------------------------------

def atleast_3d(input: Tensor, dim: int = 0) -> Tensor:
    """Pad ``input`` with singleton axes until it has at least three dimensions.

    One axis is inserted at position ``dim`` for every missing dimension.
    A tensor that already has three or more dimensions is returned as is.
    """
    missing = max(3 - input.ndim, 0)
    for _ in range(missing):
        input = input.unsqueeze(dim)
    return input

def atleast_4d(input: Tensor, dim: int = 0) -> Tensor:
    """Pad ``input`` with singleton axes until it has at least four dimensions.

    One axis is inserted at position ``dim`` for every missing dimension.
    A tensor that already has four or more dimensions is returned as is.
    """
    missing = max(4 - input.ndim, 0)
    for _ in range(missing):
        input = input.unsqueeze(dim)
    return input

def inv(input: Tensor) -> Tensor:
    """Invert a batch of symmetric matrices through a Cholesky factorization.

    Matrices that cannot be factorized safely are substituted before the
    factorization instead of raising:

    * any matrix containing a NaN is replaced by the identity;
    * any matrix whose smallest eigenvalue is ``<= 1e-12`` is replaced by the
      identity;
    * any matrix whose entries are all (numerically) equal receives a
      ``1e-6`` diagonal jitter.

    The input is never modified; the helper tensors follow the dtype and
    device of ``input``.

    Args:
        input: Tensor whose last two dimensions are square, symmetric matrices.

    Returns:
        Tensor of the same shape holding the inverse of each (possibly
        substituted) matrix.
    """
    mat = input.clone()
    eye = torch.eye(mat.size(-1), dtype=mat.dtype, device=mat.device)

    # Screen NaN matrices first so the eigen-solver below never receives NaNs.
    has_nan = mat.isnan().flatten(-2).any(dim=-1)
    mat[has_nan] = eye

    # Matrices that are not safely positive definite fall back to the identity.
    eigvals = torch.linalg.eigvalsh(mat)
    not_pd = eigvals.min(dim=-1).values <= 1e-12
    mat[not_pd] = eye

    # Constant matrices (min entry == max entry) get a small diagonal jitter.
    lowest, highest = mat.amin(dim=(-2, -1)), mat.amax(dim=(-2, -1))
    constant = lowest.isclose(highest)
    mat = mat + constant.view(*constant.shape, 1, 1) * (1e-6 * eye)

    L = torch.linalg.cholesky(mat)
    return torch.cholesky_inverse(L)

def argmin(input: Tensor, dim: int | None = None) -> Tensor:
    _input = input.clone()
    idx = _input.nan_to_num(float("inf")).argmin(dim=dim, keepdim=True)
    return idx

def cdf(value: Tensor) -> Tensor:
    """Standard normal CDF that passes NaN entries through as NaN.

    NaN entries are replaced by ``0.0`` before the distribution is evaluated
    and restored to NaN in the result. The input is not modified.
    """
    nan_mask = value.isnan()
    safe_value = value.masked_fill(nan_mask, 0.0)
    std_normal = torch.distributions.Normal(0.0, 1.0)
    return std_normal.cdf(safe_value).masked_fill(nan_mask, float("nan"))

def pdf(value: Tensor) -> Tensor:
    """Standard normal density that passes NaN entries through as NaN.

    NaN entries are replaced by ``0.0`` before the log-density is evaluated
    and restored to NaN in the result. The input is not modified.
    """
    nan_mask = value.isnan()
    safe_value = value.masked_fill(nan_mask, 0.0)
    std_normal = torch.distributions.Normal(0.0, 1.0)
    density = std_normal.log_prob(safe_value).exp()
    return density.masked_fill(nan_mask, float("nan"))


# -----------------------------------------------------------------------------
# Bi-Level Joint Entropy Search (BLJES)
# -----------------------------------------------------------------------------

class BiLevelJointEntropySearch(BiLevelAcquisition):

    def __init__(
        self,
        num_dims: list[int],
        model_Y_upper: RFFModelList,
        model_Y_lower: RFFModelList,
        model_C_upper: RFFModelList | None = None,
        model_C_lower: RFFModelList | None = None,
        num_samples: int = 64,
        noisy_obs: bool = True,
        joint: bool = True,
        num_restarts: int | None = None,  # for query settings
        raw_samples: int | None = None,  # for query settings
    ) -> None:

        super().__init__(
            num_dims=num_dims,
            model_Y_upper=model_Y_upper,
            model_Y_lower=model_Y_lower,
            model_C_upper=model_C_upper,
            model_C_lower=model_C_lower,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
        )
        self.num_samples = num_samples
        self.noisy_obs = noisy_obs
        self.joint = joint


    def _upper_sample(
        self,
        X_upper: Tensor,  # shape: [(b1, b2), num_dims[0]]
        X_lower: Tensor,  # shape: [(b1, b2), num_dims[1]]
    ) -> Tensor:

        X_upper = atleast_3d(X_upper, dim=1)
        X_lower = atleast_3d(X_lower, dim=1)
        X = torch.cat([X_upper, X_lower], dim=-1)
        Y = self.model_Y_upper.rsample(X, self.num_samples)
        return -Y.squeeze(-1)



    def _lower_sample(
        self,
        X_upper: Tensor,  # shape: [(b1, b2), num_dims[0]]
        X_lower: Tensor,  # shape: [(b1, b2), num_dims[1]]
    ) -> Tensor:

        X_upper = atleast_3d(X_upper, dim=1)
        X_lower = atleast_3d(X_lower, dim=1)
        X = torch.cat([X_upper, X_lower], dim=-1)
        Y = self.model_Y_lower.rsample(X, self.num_samples)
        return -Y.squeeze(-1)



    def _upper_mi_lower_bound(
        self,
        X: Tensor,  # shape: [(b1, b2), d_in]
        trunc_X: Tensor,  # shape: [num_samples, (b1, 1), d_in]
        opt_X: Tensor,  # shape: [num_samples, (1, 1), d_in]
    ) -> Tensor:  # shape: [(b1), (b2)]

        X = atleast_3d(X, dim=1)  # shape: [(b1), (b2), d_in]
        trunc_X = atleast_4d(trunc_X, dim=1)  # shape: [num_samples, (b1), 1, d_in]
        opt_X = atleast_4d(opt_X, dim=1)  # shape: [num_samples, 1, 1, d_in]
        Y = self.model_Y_upper.rsample(X, self.num_samples).squeeze(-1)  # shape: [num_samples, (b1), (b2)]
        trunc_Y = self.model_Y_upper.rsample(trunc_X, self.num_samples)  # shape: [num_samples, num_samples, (b1), 1]
        trunc_Y = trunc_Y.squeeze(-1).diagonal(0, 0, 1).permute(2, 0, 1)  # shape: [num_samples, (b1), 1]
        opt_Y = self.model_Y_upper.rsample(opt_X, self.num_samples).squeeze(-1)  # shape: [num_samples, num_samples, 1, 1]
        opt_Y = opt_Y.diagonal(0, 0, 1).permute(2, 0, 1)  # shape: [num_samples, 1, 1]
        Y_mean = self.model_Y_upper.mean(X).squeeze(-1).unsqueeze(0)
        Y_mean = Y_mean.expand(self.num_samples, -1, -1)  # shape: [num_samples, (b1), (b2)]
        Y_var = self.model_Y_upper.var(X).squeeze(-1).unsqueeze(0)
        Y_var = Y_var.expand(self.num_samples, -1, -1)  # shape: [num_samples, (b1), (b2)]
        trunc_Y_mean = self.model_Y_upper.mean(trunc_X).squeeze(-1)  # shape: [num_samples, (b1), 1]
        trunc_Y_var = self.model_Y_upper.var(trunc_X).squeeze(-1)  # shape: [num_samples, (b1), 1]
        opt_Y_mean = self.model_Y_upper.mean(opt_X).squeeze(-1)  # shape: [num_samples, 1, 1]
        opt_Y_var = self.model_Y_upper.var(opt_X).squeeze(-1)  # shape: [num_samples, 1, 1]
        cov1 = self.model_Y_upper.cov(trunc_X, X).squeeze(-1)  # shape: [num_samples, (b1), 1, (b1), (b2)]
        cov1 = cov1.squeeze(2).diagonal(0, 1, 2).permute(0, 2, 1)  # shape: [num_samples, (b1), (b2)]
        cov2 = self.model_Y_upper.cov(trunc_X, opt_X).squeeze(-1)
        cov2 = cov2.diagonal(0, 0, 3).permute(4, 0, 1, 2, 3)
        cov2 = cov2.view(self.num_samples, -1, 1).expand_as(cov1) # shape: [num_samples, (b1), (b2)]
        cov3 = self.model_Y_upper.cov(opt_X, X).squeeze(-1)
        cov3 = cov3.view(self.num_samples, *X.shape[:-1])  # shape: [num_samples, (b1), (b2)]

        if self.model_C_upper is not None:
            C = self.model_C_upper.rsample(X, self.num_samples)  # shape: [num_samples, (b1), (b2), num_constraints[0]]
            C_mean = self.model_C_upper.mean(X).unsqueeze(0)
            C_mean = C_mean.expand(self.num_samples, -1, -1, -1)  # shape: [num_samples, (b1), (b2), num_constraints[0]]
            C_var = self.model_C_upper.var(X).unsqueeze(0)
            C_var = C_var.expand(self.num_samples, -1, -1, -1)  # shape: [num_samples, (b1), (b2), num_constraints[0]]
            trunc_C_mean = self.model_C_upper.mean(trunc_X)  # shape: [num_samples, (b1), 1, num_constraints[0]]
            trunc_C_var = self.model_C_upper.var(trunc_X)  # shape: [num_samples, (b1), 1, num_constraints[0]]
            covC = self.model_C_upper.cov(trunc_X, X)  # shape: [num_samples, (b1), 1, (b1), (b2), num_constraints[0]]
            covC = covC.squeeze(2).diagonal(0, 1, 2).permute(0, 3, 1, 2)  # shape: [num_samples, (b1), (b2), num_constraints[0]]
            if self.noisy_obs:
                for c in range(self.num_constraints[0]):
                    noise = self.model_C_upper.rff_models[c].model.likelihood.noise
                    C[..., c] = C[..., c] + noise.sqrt() * torch.randn_like(C[..., c])
                    C_var = C_var.clone()
                    C_var[..., c] = C_var[..., c] + noise

        if self.noisy_obs:
            noise = self.model_Y_upper.rff_models[0].model.likelihood.noise
            Y = Y + noise.sqrt() * torch.randn_like(Y)
            Y_var = Y_var + noise

        train_X = atleast_4d(self.model_Y_upper.rff_models[0].train_X, dim=1)  # shape: [n_train, 1, 1, d_in]
        train_match = train_X.isclose(X.unsqueeze(0)).all(dim=-1).any(dim=0)  # shape: [(b1), (b2)]
        train_match = train_match.unsqueeze(0).expand(self.num_samples, -1, -1)  # shape: [num_samples, (b1), (b2)]
        trunc_match = trunc_X.isclose(X.unsqueeze(0)).all(dim=-1)  # shape: [num_samples, (b1), (b2)]
        trunc_train_match = trunc_X.unsqueeze(0).isclose(train_X.unsqueeze(1))
        trunc_train_match = trunc_train_match.all(dim=-1).any(dim=0)  # shape: [num_samples, (b1), 1]
        trunc_train_match = trunc_train_match.expand_as(trunc_match)  # shape: [num_samples, (b1), (b2)]
        opt_X0, opt_X1 = opt_X.split(self.num_dims, dim=-1)
        X0, X1 = X.split(self.num_dims, dim=-1)
        opt0_match = opt_X0.isclose(X0.unsqueeze(0)).all(dim=-1)  # shape: [num_samples, (b1), (b2)]
        opt1_match = opt_X1.isclose(X1.unsqueeze(0)).all(dim=-1)  # shape: [num_samples, (b1), (b2)]

        if self.joint:
            covvec = torch.stack([cov1, cov2], dim=-1).unsqueeze(-1)  # shape: [num_samples, (b1), (b2), 2, 1]
            covmat = torch.stack([
                torch.stack([Y_var, cov3], dim=-1),
                torch.stack([cov3, opt_Y_var.expand_as(cov3)], dim=-1),
            ], dim=-2)  # shape: [num_samples, (b1), (b2), 2, 2]
            coeff = covvec.transpose(-2, -1) @ inv(covmat)  # shape: [num_samples, (b1), (b2), 1, 2]
            diff = torch.stack([
                (Y - Y_mean), (opt_Y - opt_Y_mean).expand_as(Y)
            ], dim=-1).unsqueeze(-1)  # shape: [num_samples, (b1), (b2), 2, 1]
            m1 = trunc_Y_mean + (coeff @ diff).squeeze(-1).squeeze(-1)  # shape: [num_samples, (b1), (b2)]
            v1 = trunc_Y_var - (coeff @ covvec).squeeze(-1).squeeze(-1)  # shape: [num_samples, (b1), (b2)]
            m2 = trunc_Y_mean + (cov2 / opt_Y_var) * (opt_Y - opt_Y_mean)  # shape: [num_samples, (b1), 1]
            v2 = trunc_Y_var - (cov2**2 / opt_Y_var)  # shape: [num_samples, (b1), 1]
            m3 = Y_mean + (cov3 / opt_Y_var) * (opt_Y - opt_Y_mean)  # shape: [num_samples, (b1), (b2)]
            v3 = Y_var - (cov3**2 / opt_Y_var)  # shape: [num_samples, (b1), (b2)]

            if self.model_C_upper is not None:
                mC = trunc_C_mean + (covC / C_var) * (C - C_mean)
                vC = trunc_C_var + (covC**2 / C_var)

            # sanitize
            _opt_Y = opt_Y.expand_as(Y)
            m1[opt0_match] = _opt_Y[opt0_match]
            m2[opt0_match] = _opt_Y[opt0_match]
            v1[opt0_match] = 0.0
            v2[opt0_match] = 0.0
            if not self.noisy_obs:
                _trunc_Y = trunc_Y.expand_as(Y)
                m1[trunc_match] = _trunc_Y[trunc_match]
                m1[trunc_train_match] = _trunc_Y[trunc_train_match]
                m2[trunc_train_match] = _trunc_Y[trunc_train_match]
                m3[opt0_match & opt1_match] = _opt_Y[opt0_match & opt1_match]
                v1[trunc_match] = 0.0
                v1[trunc_train_match] = 0.0
                v2[trunc_train_match] = 0.0
                v3[opt0_match & opt1_match | train_match] = 0.0

            p1 = cdf((opt_Y - m1) / v1.sqrt()).clamp_min(1e-12)
            p2 = cdf((opt_Y - m2) / v2.sqrt()).clamp_min(1e-12)
            p3 = pdf((Y - m3) / v3.sqrt()) / v3.sqrt()
            p4 = pdf((Y - Y_mean) / Y_var.sqrt()) / Y_var.sqrt()
            if self.model_C_upper is not None:
                p1 = 1 - (1 - p1) * (1 - cdf(-mC / vC.sqrt())).prod(dim=-1)
                p2 = 1 - (1 - p2) * (1 - cdf(-C_mean / C_var.sqrt())).prod(dim=-1)
                p3 = p3 * (pdf((C - C_mean) / C_var.sqrt()) / C_var.sqrt()).prod(dim=-1)
                p4 = p4 * (pdf((C - C_mean) / C_var.sqrt()) / C_var.sqrt()).prod(dim=-1)
            # sanitize
            p1[opt0_match] = 1.0
            p2[opt0_match] = 1.0
            if not self.noisy_obs:
                p1[trunc_match | trunc_train_match] = 1.0
                p2[trunc_train_match] = 1.0
                p3[opt0_match & opt1_match | train_match] = 1.0
            LB = torch.log((p1 * p3) / (p2 * p4)).nanmean(dim=0)
        else:
            m1 = trunc_Y_mean + (cov1 / Y_var) * (Y - Y_mean)
            v1 = trunc_Y_var - (cov1**2 / Y_var)
            # sanitize
            if not self.noisy_obs:
                _trunc_Y = trunc_Y.expand_as(Y)
                m1[trunc_match] = _trunc_Y[trunc_match]
                m1[trunc_train_match] = _trunc_Y[trunc_train_match]
                v1[trunc_match] = 0.0
                v1[trunc_train_match] = 0.0
            p1 = cdf((opt_Y - m1) / v1.sqrt())
            p2 = cdf((opt_Y - trunc_Y_mean) / trunc_Y_var.sqrt())
            # sanitize
            if not self.noisy_obs:
                p1[trunc_match | trunc_train_match] = 1.0
                p2[trunc_train_match] = 1.0
            LB = torch.log(p1 / p2).nanmean(dim=0)
        return -LB


    def _lower_mi_lower_bound(
        self,
        X: Tensor,  # shape: [(b1, b2), d_in]
        trunc_X: Tensor,  # shape: [num_samples, (1, b2), d_in]
        opt_X: Tensor,  # shape: [num_samples, (1, 1), d_in]
    ) -> Tensor:  # shape: [(b1), (b2)]

        X = atleast_3d(X, dim=1)  # shape: [(b1), (b2), d_in]
        trunc_X = atleast_4d(trunc_X, dim=1)  # shape: [num_samples, 1, (b2), d_in]
        opt_X = atleast_4d(opt_X, dim=1)  # shape: [num_samples, 1, 1, d_in]
        Y = self.model_Y_lower.rsample(X, self.num_samples).squeeze(-1)  # shape: [num_samples, (b1), (b2)]
        print(X.shape, trunc_X.shape, opt_X.shape)
        trunc_Y = self.model_Y_upper.rsample(trunc_X, self.num_samples)  # shape: [num_samples, num_samples, 1, (b2)]
        trunc_Y = trunc_Y.squeeze(-1).diagonal(0, 0, 1).permute(2, 0, 1)  # shape: [num_samples, 1, (b2)]
        opt_Y = self.model_Y_lower.rsample(opt_X, self.num_samples).squeeze(-1)
        opt_Y = opt_Y.diagonal(0, 0, 1).permute(2, 0, 1)  # shape: [num_samples, 1, 1]
        Y_mean = self.model_Y_lower.mean(X).squeeze(-1).unsqueeze(0)
        Y_mean = Y_mean.expand(self.num_samples, -1, -1)  # shape: [num_samples, (b1), (b2)]
        Y_var = self.model_Y_lower.var(X).squeeze(-1).unsqueeze(0)
        Y_var = Y_var.expand(self.num_samples, -1, -1)  # shape: [num_samples, (b1), (b2)]
        trunc_Y_mean = self.model_Y_lower.mean(trunc_X).squeeze(-1)  # shape: [num_samples, 1, (b2)]
        trunc_Y_var = self.model_Y_lower.var(trunc_X).squeeze(-1)  # shape: [num_samples, 1, (b2)]
        opt_Y_mean = self.model_Y_lower.mean(opt_X).squeeze(-1)  # shape: [num_samples, 1, 1]
        opt_Y_var = self.model_Y_lower.var(opt_X).squeeze(-1)  # shape: [num_samples, 1, 1]
        cov1 = self.model_Y_lower.cov(trunc_X, X).squeeze(-1)
        cov1 = cov1.diagonal(0, 2, 4).permute(0, 2, 3, 1).squeeze(-1)  # shape: [num_samples, (b1), (b2)]
        cov2 = self.model_Y_lower.cov(trunc_X, opt_X).squeeze(-1)
        cov2 = cov2.diagonal(0, 0, 3).permute(4, 0, 1, 2, 3)
        cov2 = cov2.view(self.num_samples, 1, -1).expand_as(cov1) # shape: [num_samples, (b1), (b2)]
        cov3 = self.model_Y_lower.cov(opt_X, X).squeeze(-1)
        cov3 = cov3.view(self.num_samples, *X.shape[:-1])  # shape: [num_samples, (b1), (b2)]

        if self.noisy_obs:
            noise = self.model_Y_lower.rff_models[0].model.likelihood.noise
            Y = Y + noise.sqrt() * torch.randn_like(Y)
            Y_var = Y_var + noise

        if self.model_C_lower is not None:
            C = self.model_C_lower.rsample(X, self.num_samples)  # shape: [num_samples, (b1), (b2), num_constraints[0]]
            C_mean = self.model_C_lower.mean(X).unsqueeze(0)
            C_mean = C_mean.expand(self.num_samples, -1, -1, -1)  # shape: [num_samples, (b1), (b2), num_constraints[0]]
            C_var = self.model_C_lower.var(X).unsqueeze(0)
            C_var = C_var.expand(self.num_samples, -1, -1, -1)  # shape: [num_samples, (b1), (b2), num_constraints[0]]
            trunc_C_mean = self.model_C_lower.mean(trunc_X)  # shape: [num_samples, 1, (b2), num_constraints[0]]
            trunc_C_var = self.model_C_lower.var(trunc_X)  # shape: [num_samples, 1, (b2), num_constraints[0]]
            covC = self.model_C_lower.cov(trunc_X, X)  # shape: [num_samples, 1, (b2), (b1), (b2), num_constraints[0]]
            covC = covC.squeeze(1).diagonal(0, 1, 3).permute(0, 1, 3, 2)  # shape: [num_samples, (b1), (b2), num_constraints[0]]
            if self.noisy_obs:
                for c in range(self.num_constraints[1]):
                    noise = self.model_C_lower.rff_models[c].model.likelihood.noise
                    C[..., c] = C[..., c] + noise.sqrt() * torch.randn_like(C[..., c])
                    C_var = C_var.clone()
                    C_var[..., c] = C_var[..., c] + noise

        train_X = atleast_4d(self.model_Y_lower.rff_models[0].train_X, dim=1)  # shape: [n_train, 1, 1, d_in]
        train_match = train_X.isclose(X.unsqueeze(0)).all(dim=-1).any(dim=0)  # shape: [(b1), (b2)]
        trunc_match = trunc_X.isclose(X.unsqueeze(0)).all(dim=-1)  # shape: [num_samples, (b1), (b2)]
        trunc_train_match = trunc_X.unsqueeze(0).isclose(train_X.unsqueeze(1))
        trunc_train_match = trunc_train_match.all(dim=-1).any(dim=0)  # shape: [num_samples, 1, (b2)]
        trunc_train_match = trunc_train_match.expand_as(trunc_match)  # shape: [num_samples, (b1), (b2)]
        opt_X0, opt_X1 = opt_X.split(self.num_dims, dim=-1)
        X0, X1 = X.split(self.num_dims, dim=-1)
        opt0_match = opt_X0.isclose(X0.unsqueeze(0)).all(dim=-1)  # shape: [num_samples, (b1), (b2)]
        opt1_match = opt_X1.isclose(X1.unsqueeze(0)).all(dim=-1)  # shape: [num_samples, (b1), (b2)]

        if self.joint:
            covvec = torch.stack([cov1, cov2], dim=-1).unsqueeze(-1)  # shape: [num_samples, (b1), (b2), 2, 1]
            covmat = torch.stack([
                torch.stack([Y_var, cov3], dim=-1),
                torch.stack([cov3, opt_Y_var.expand_as(cov3)], dim=-1),
            ], dim=-2)  # shape: [num_samples, (b1), (b2), 2, 2]
            coeff = covvec.transpose(-2, -1) @ inv(covmat)  # shape: [num_samples, (b1), (b2), 1, 2]
            diff = torch.stack([
                (Y - Y_mean), (opt_Y - opt_Y_mean).expand_as(Y)
            ], dim=-1).unsqueeze(-1)  # shape: [num_samples, (b1), (b2), 2, 1]
            m1 = trunc_Y_mean + (coeff @ diff).squeeze(-1).squeeze(-1)  # shape: [num_samples, (b1), (b2)]
            v1 = trunc_Y_var - (coeff @ covvec).squeeze(-1).squeeze(-1)  # shape: [num_samples, (b1), (b2)]
            m2 = trunc_Y_mean + (cov2 / opt_Y_var) * (opt_Y - opt_Y_mean)  # shape: [num_samples, 1, (b2)]
            v2 = trunc_Y_var - (cov2**2 / opt_Y_var)  # shape: [num_samples, 1, (b2)]
            m3 = Y_mean + (cov3 / opt_Y_var) * (opt_Y - opt_Y_mean)  # shape: [num_samples, (b1), (b2)]
            v3 = Y_var - (cov3**2 / opt_Y_var)  # shape: [num_samples, (b1), (b2)]

            if self.model_C_lower is not None:
                mC = trunc_C_mean + (covC / C_var) * (C - C_mean)
                vC = trunc_C_var + (covC**2 / C_var)
            # sanitize
            _opt_Y = opt_Y.expand_as(Y)
            m1[opt1_match] = _opt_Y[opt1_match]
            m2[opt1_match] = _opt_Y[opt1_match]
            v1[opt1_match] = 0.0
            v2[opt1_match] = 0.0
            if not self.noisy_obs:
                _trunc_Y = trunc_Y.expand_as(Y)
                m1[opt0_match] = _trunc_Y[opt0_match]
                m1[trunc_train_match] = _trunc_Y[trunc_train_match]
                m2[trunc_train_match] = _trunc_Y[trunc_train_match]
                m3[opt0_match & opt1_match] = _opt_Y[opt0_match & opt1_match]
                v1[opt0_match | trunc_train_match] = 0.0
                v2[trunc_train_match] = 0.0
                v3[opt0_match & opt1_match | train_match] = 0.0
            
            p1 = cdf((opt_Y - m1) / v1.sqrt()).clamp_min(1e-12)
            p2 = cdf((opt_Y - m2) / v2.sqrt()).clamp_min(1e-12)
            p3 = pdf((Y - m3) / v3.sqrt()) / v3.sqrt()
            p4 = pdf((Y - Y_mean) / Y_var.sqrt()) / Y_var.sqrt()
            if self.model_C_lower is not None:
                p1 = 1 - (1 - p1) * (1 - cdf(-mC / vC.sqrt())).prod(dim=-1)
                p2 = 1 - (1 - p2) * (1 - cdf(-C_mean / C_var.sqrt())).prod(dim=-1)
                p3 = p3 * (pdf((C - C_mean) / C_var.sqrt()) / C_var.sqrt()).prod(dim=-1)
                p4 = p4 * (pdf((C - C_mean) / C_var.sqrt()) / C_var.sqrt()).prod(dim=-1)
            # sanitize
            _opt_Y = opt_Y.expand_as(Y)
            _trunc_Y = trunc_Y.expand_as(Y)
            p1[opt1_match] = 1.0
            p2[opt1_match] = 1.0
            if not self.noisy_obs:
                p1[opt0_match | trunc_train_match] = 1.0
                p2[trunc_train_match] = 1.0
                p3[opt0_match & opt1_match | train_match] = 1.0
            LB = torch.log((p1 * p3) / (p2 * p4)).nanmean(dim=0)
        else:
            m1 = trunc_Y_mean + (cov1 / Y_var) * (Y - Y_mean)
            v1 = trunc_Y_var - (cov1**2 / Y_var)
            # sanitize
            if not self.noisy_obs:
                _trunc_Y = trunc_Y.expand_as(Y)
                m1[opt0_match] = _trunc_Y[opt0_match]
                m1[trunc_train_match] = _trunc_Y[trunc_train_match]
                v1[opt0_match] = 0.0
                v1[trunc_train_match] = 0.0
            p1 = cdf((opt_Y - m1) / v1.sqrt())
            p2 = cdf((opt_Y - trunc_Y_mean) / trunc_Y_var.sqrt())
            # sanitize
            if not self.noisy_obs:
                p1[opt0_match | trunc_train_match] = 1.0
                p2[trunc_train_match] = 1.0
            LB = torch.log(p1 / p2).nanmean(dim=0)
        return -LB
        

    def plot(self, X: Tensor, values: Tensor, path: str):
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(figsize=(6, 5))
        c = ax.contourf(X[..., 0], X[..., 1], values.detach(), levels=50)
        fig.colorbar(c, ax=ax)
        fig.tight_layout()
        fig.savefig(path)
        plt.clf(); plt.close()


    def _optimize_samples_pool(
        self,
        X: Tensor,  # shape: [b1, b2, d_in]
        # mask_evaluated: Tensor | None = None,  # shape: [b1, b2, d_out]
    ) -> tuple[Tensor, ...]:  # shape: [num_samples, 1, 1, d_in], shape: [num_samples, b1, 1, d_in], shape: [num_samples, 1, b2, d_in]

        # optimize lower-level
        Y_lower = self._lower_sample(*X.split(self.num_dims, dim=-1))
        if self.model_C_lower is not None:
            C_lower = self.model_C_lower.rsample(X, self.num_samples)
            Y_lower = torch.where((C_lower >= 0.0).all(dim=-1), Y_lower, float("nan"))
        feas_idx = argmin(Y_lower, dim=2)
        # optimize upper-level
        Y_upper = self._upper_sample(*X.split(self.num_dims, dim=-1))
        if self.model_C_upper is not None:
            C_upper = self.model_C_upper.rsample(X, self.num_samples)
            Y_upper = torch.where((C_upper >= 0.0).all(dim=-1), Y_upper, float("nan"))
        opt_idx = argmin(Y_upper.gather(2, feas_idx), dim=1)
        # sample optimal solution
        _X = X.unsqueeze(0).expand(self.num_samples, -1, -1, -1)
        _feas_idx = feas_idx.unsqueeze(-1).expand(-1, -1, -1, self.d_in)
        _opt_idx = opt_idx.unsqueeze(-1).expand(-1, -1, -1, self.d_in)
        _trunc_idx = _opt_idx.expand(-1, -1, X.size(1), -1)
        feas_X = _X.gather(2, _feas_idx)  # shape: [num_samples, b1, 1, d_in]
        opt_X = feas_X.gather(1, _opt_idx)  # shape: [num_samples, 1, 1, d_in]
        trunc_X = _X.gather(1, _trunc_idx)  # shape: [num_samples, 1, b2, d_in]
        # trunc_X = torch.cat(opt_X[..., :self.num_dims[0]].expand(-1, -1, X.size(1), -1), _X[:, :1, :, self.num_dims[0]:], dim=-1)
        # print(opt_X[:, 0, 0, 0])
        return opt_X, feas_X, trunc_X


    def _optimize_samples_query(
        self,
        bounds: Tensor,  # shape: [2, d_in]
    ) -> Tensor:  # shape: [num_samples, d_in]
        """Refine per-sample bilevel optima of the sampled objectives within ``bounds``.

        Starting points come from ``_optimize_samples_pool`` evaluated on a
        regular grid with ``round(raw_samples ** (1 / d_in))`` points per input
        dimension. Ten outer iterations follow, each consisting of

        1. an L-BFGS solve of the lower-level sample in ``X_lower`` (then
           clamped to the lower-level bounds), and
        2. one SGD step on ``X_upper`` using the gradient ``g1 - Pv`` assembled
           below (then clamped to the upper-level bounds).

        Returns:
            The concatenated, detached ``[X_upper, X_lower]``.
        """

        def upper_sample(
            X_upper: Tensor,  # shape: [num_samples, num_dims[0]]
            X_lower: Tensor,  # shape: [num_samples, num_dims[1]]
        ) -> Tensor:  # shape: [num_samples]
            # Keep only the diagonal of the squeezed sample matrix
            # (presumably: sample i evaluated at input i — confirm against rsample).
            sample = self._upper_sample(X_upper, X_lower)
            return sample.squeeze().diag()

        def lower_sample(
            X_upper: Tensor,  # shape: [num_samples, num_dims[0]]
            X_lower: Tensor,  # shape: [num_samples, num_dims[1]]
        ) -> Tensor:  # shape: [num_samples]
            # Same diagonal selection as upper_sample, for the lower-level model.
            sample = self._lower_sample(X_upper, X_lower)
            return sample.squeeze().diag()

        # Regular grid with d points per input dimension, reshaped so that dim 0
        # enumerates the upper-level variables and dim 1 the lower-level ones.
        d = round(self.raw_samples ** (1 / self.d_in))
        grid = torch.stack(torch.meshgrid(
            [torch.linspace(*b, d) for b in bounds.T], indexing="ij"
        ), dim=-1).view(d**self.num_dims[0], d**self.num_dims[1], self.d_in)
        # initialize parameters
        init_X, _, _ = self._optimize_samples_pool(grid)
        # NOTE(review): a bare squeeze() also drops the sample axis when
        # num_samples == 1 — confirm that case is never used.
        init_X = init_X.squeeze()
        X_upper, X_lower = init_X.split(self.num_dims, dim=-1)
        b_upper, b_lower = bounds.split(self.num_dims, dim=-1)
        X_upper = X_upper.requires_grad_(True)
        X_lower = X_lower.requires_grad_(True)
        opt_upper = SGD([X_upper], lr=1e-2)
        for _ in range(10):
            # Lower level: a fresh L-BFGS optimizer is created on every outer iteration.
            opt_lower = LBFGS([X_lower], line_search_fn="strong_wolfe")
            def closure():
                opt_lower.zero_grad()
                loss_lower = lower_sample(X_upper, X_lower).sum()
                loss_lower.backward(retain_graph=True)
                return loss_lower
            opt_lower.step(closure)
            with torch.no_grad():
                X_lower.clamp_(b_lower[0], b_lower[1])
            # Upper level: assemble the gradient by hand instead of calling backward().
            opt_upper.zero_grad()
            loss_upper = upper_sample(X_upper, X_lower).sum()
            # g1: direct gradient of the upper loss w.r.t. X_upper.
            g1, = torch.autograd.grad(
                outputs=loss_upper, inputs=X_upper, retain_graph=True,
            )
            # g2: gradient of the upper loss w.r.t. X_lower.
            g2, = torch.autograd.grad(
                outputs=loss_upper, inputs=X_lower, create_graph=True,
            )
            # H: per-sample Hessian blocks of the lower loss w.r.t. X_lower
            # (the diagonal over the two sample axes is extracted).
            H = torch.autograd.functional.hessian(
                func=(lambda X: lower_sample(X_upper, X).sum()),
                inputs=X_lower,
            ).diagonal(0, 0, 2).permute(2, 0, 1)
            # NOTE(review): H is solved without any regularization here —
            # confirm it cannot be singular.
            v = torch.linalg.solve(H, g2)
            # g3: gradient of the lower loss w.r.t. X_lower, kept differentiable.
            g3, = torch.autograd.grad(
                outputs=lower_sample(X_upper, X_lower).sum(),
                inputs=X_lower, create_graph=True,
            )
            # Pv: derivative of g3 w.r.t. X_upper contracted with v.
            Pv, = torch.autograd.grad(
                outputs=g3, inputs=X_upper,
                grad_outputs=v, retain_graph=True,
            )
            # This has the form of an implicit-differentiation hypergradient
            # (presumably intended as such — confirm).
            X_upper.grad = g1 - Pv
            opt_upper.step()
            with torch.no_grad():
                X_upper.clamp_(b_upper[0], b_upper[1])
        return torch.cat([X_upper, X_lower], dim=-1).detach()


    def _optimize_alpha_query(
        self,
        bounds: Tensor,  # shape: [2, d_in]
        opt_X: Tensor,  # shape: [num_samples, d_in]
    ) -> Tensor:
        """Search for the input that minimizes the summed upper and lower bounds.

        The sum ``_upper_mi_lower_bound + _lower_mi_lower_bound`` is first
        evaluated on a regular grid with ``round(raw_samples ** (1 / d_in))``
        points per input dimension, and the grid minimizer is used as the
        starting point. Twenty outer iterations follow, each consisting of

        1. an L-BFGS solve of the lower-level sample in ``feas_X_lower`` (then
           clamped to the lower-level bounds), and
        2. one SGD step on ``X`` using the gradient ``g1 - Pv`` assembled below
           (then clamped to ``bounds``).

        Args:
            bounds: Lower and upper bounds of the inputs.
            opt_X: Per-sample optimal inputs.

        Returns:
            The detached refined input ``X``.
        """

        def alpha(
            X: Tensor,  # shape: [d_in]
            feas_X_lower: Tensor,  # shape: [num_samples, num_dims[1]]
            opt_X: Tensor,  # shape: [num_samples, d_in]
        ) -> Tensor:  # shape: [1]
            # Rebuild the per-sample truncation inputs for the current X:
            # feas_X pairs X's upper part with the per-sample lower solutions,
            # trunc_X pairs the optimum's upper part with X's lower part.
            _X = X.unsqueeze(0).expand(self.num_samples, -1)
            X_upper, X_lower = _X.split(self.num_dims, dim=-1)
            opt_X_upper = opt_X[:, :self.num_dims[0]]
            feas_X = torch.cat([X_upper, feas_X_lower], dim=-1)
            trunc_X = torch.cat([opt_X_upper, X_lower], dim=-1)
            mi_upper = self._upper_mi_lower_bound(X.unsqueeze(0), feas_X, opt_X)
            mi_lower = self._lower_mi_lower_bound(X.unsqueeze(0), trunc_X, opt_X)
            return (mi_upper + mi_lower).squeeze(-1)

        def lower_sample(
            X: Tensor,  # shape: [d_in]
            feas_X_lower: Tensor,  # shape: [num_samples, num_dims[1]]
        ) -> Tensor:  # shape: [num_samples]
            # Keep only the diagonal of the squeezed sample matrix.
            X_upper = X[:self.num_dims[0]].unsqueeze(0).expand_as(feas_X_lower)
            sample = self._lower_sample(X_upper, feas_X_lower)
            return sample.squeeze().diag()

        # Regular grid with d points per input dimension, reshaped so that dim 0
        # enumerates the upper-level variables and dim 1 the lower-level ones.
        d = round(self.raw_samples ** (1 / self.d_in))
        grid = torch.stack(torch.meshgrid(
            [torch.linspace(*b, d) for b in bounds.T], indexing="ij"
        ), dim=-1).view(d**self.num_dims[0], d**self.num_dims[1], self.d_in)
        # initialize parameters
        _, feas_X, trunc_X = self._optimize_samples_pool(grid)
        mi_upper = self._upper_mi_lower_bound(grid, feas_X, opt_X)
        mi_lower = self._lower_mi_lower_bound(grid, trunc_X, opt_X)
        # NOTE(review): writes "alpha.pdf" into the working directory on every
        # call — looks like a debugging aid; confirm it should stay.
        self.plot(grid, mi_upper+mi_lower, "alpha.pdf")
        indices = argmin((mi_upper + mi_lower), dim=None)
        indices = torch.unravel_index(indices, mi_upper.shape)
        # tuple(...) instead of a starred subscript, which needs Python >= 3.11.
        X = grid[tuple(indices)].squeeze().requires_grad_(True)
        feas_X_lower = feas_X[:, indices[0], 0, self.num_dims[0]:].squeeze(1).squeeze(1)
        feas_X_lower = feas_X_lower.requires_grad_(True)
        _, b_lower = bounds.split(self.num_dims, dim=-1)
        opt_upper = SGD([X], lr=1e-5)
        for _ in range(20):
            # Lower level: a fresh L-BFGS optimizer is created on every outer iteration.
            opt_lower = LBFGS([feas_X_lower], line_search_fn="strong_wolfe")
            def closure():
                opt_lower.zero_grad()
                loss_lower = lower_sample(X, feas_X_lower).sum()
                loss_lower.backward(retain_graph=True)
                return loss_lower
            opt_lower.step(closure)
            with torch.no_grad():
                feas_X_lower.clamp_(b_lower[0], b_lower[1])
            # Upper level: assemble the gradient by hand instead of calling backward().
            opt_upper.zero_grad()
            loss_upper = alpha(X, feas_X_lower, opt_X)
            # g1: direct gradient of the acquisition w.r.t. X.
            g1, = torch.autograd.grad(
                outputs=loss_upper, inputs=X, retain_graph=True,
            )
            # g2: gradient of the acquisition w.r.t. the lower-level solutions.
            g2, = torch.autograd.grad(
                outputs=loss_upper, inputs=feas_X_lower, create_graph=True,
            )
            # H: per-sample Hessian blocks of the lower loss w.r.t. feas_X_lower.
            H = torch.autograd.functional.hessian(
                func=(lambda X_lower: lower_sample(X, X_lower).sum()),
                inputs=feas_X_lower,
            ).diagonal(0, 0, 2).permute(2, 0, 1)
            # Small diagonal jitter keeps the solve well-posed; it is built with
            # H's dtype and device so it also works off the default device.
            jitter = 1e-6 * torch.eye(H.size(-1), dtype=H.dtype, device=H.device)
            v = torch.linalg.solve(H + jitter, g2)
            # g3: gradient of the lower loss w.r.t. feas_X_lower, kept differentiable.
            g3, = torch.autograd.grad(
                outputs=lower_sample(X, feas_X_lower).sum(),
                inputs=feas_X_lower, create_graph=True,
            )
            # Pv: derivative of g3 w.r.t. X contracted with v.
            Pv, = torch.autograd.grad(
                outputs=g3, inputs=X,
                grad_outputs=v, retain_graph=True,
            )
            X.grad = g1 - Pv
            opt_upper.step()
            with torch.no_grad():
                X.clamp_(bounds[0], bounds[1])
            # NOTE(review): per-iteration progress print — consider routing
            # through logging.
            print(X.detach(), loss_upper.detach())
        return X.detach()


    def optimize_pool(
        self,
        candidates: Tensor,  # shape: [*batch_shape, d_in]
        mask_evaluated: Tensor,  # shape: [*batch_shape, d_out]
        decoupled: bool = False,
    ) -> tuple[Tensor, Tensor]:  # shape: [2], shape: [d_out]
        """Pick the next pool candidate by minimizing the acquisition value.

        Posterior draws of the lower- and upper-level objectives are evaluated
        on the whole pool. Per draw, the lower-level minimizer along dim 2 and
        then the upper-level minimizer along dim 1 are located, and the two
        mutual-information bounds are evaluated against those sampled optima.

        Leftover debugging code used to slice the pool at the hard-coded index
        40, overwrite ``self.num_samples`` and return ``None`` before any
        candidate was selected; it has been removed so the method honours its
        declared return type and no longer mutates ``self.num_samples``.

        Args:
            candidates: Pool of inputs; entries containing NaN are never
                selected. The gathers below index dims 1 and 2 of the sampled
                tensors, so a two-dimensional batch shape ``[b1, b2]`` is
                assumed -- TODO confirm against the caller.
            mask_evaluated: Boolean mask of already evaluated outputs; column
                0 refers to the upper level, column 1 to the lower level.
            decoupled: If True, only the level whose bound attains the smaller
                minimum is queried; otherwise both levels are queried jointly.

        Returns:
            The stacked unravelled indices of the selected candidate and
            ``self.Y_mask``, which is updated in place to flag the outputs to
            evaluate.
        """
        mask = candidates.isnan().any(dim=-1)
        trusted = ~mask

        X = candidates.clone()

        # Lower level: per draw and per dim-1 entry, locate the minimizer
        # along dim 2 among (sampled-)feasible points.
        Y_lower = self._lower_sample(*X.split(self.num_dims, dim=-1))
        if self.model_C_lower is not None:
            C_lower = self.model_C_lower.rsample(X, self.num_samples)
            feas_lower = (C_lower >= 0.0).all(dim=-1)
            Y_lower = torch.where(feas_lower, Y_lower, float("nan"))
        feas_idx = argmin(Y_lower, dim=2)
        _feas_idx = feas_idx.unsqueeze(-1).expand(-1, -1, -1, self.d_in)
        _X = X.unsqueeze(0).expand(self.num_samples, -1, -1, -1)
        feas_X = _X.gather(2, _feas_idx)  # shape: [num_samples, b1, 1, d_in]

        # Upper level: evaluate the draws at the lower-level minimizers and
        # locate the best one along dim 1.
        Y_upper = self._upper_sample(*X.split(self.num_dims, dim=-1))
        if self.model_C_upper is not None:
            C_upper = self.model_C_upper.rsample(X, self.num_samples)
            feas_upper = (C_upper >= 0.0).all(dim=-1)
            Y_upper = torch.where(feas_upper, Y_upper, float("nan"))
        feas_Y_upper = Y_upper.gather(2, feas_idx)
        opt_idx = argmin(feas_Y_upper, dim=1)
        _opt_idx = opt_idx.unsqueeze(-1).expand(-1, -1, -1, self.d_in)
        opt_X = feas_X.gather(1, _opt_idx)  # shape: [num_samples, 1, 1, d_in]

        # Pool slice along dim 1 at each draw's optimal dim-1 index.
        _trunc_idx = _opt_idx.expand(-1, -1, mask.shape[1], -1)
        trunc_X = _X.gather(1, _trunc_idx)  # shape: [num_samples, 1, b2, d_in]

        mi_upper = self._upper_mi_lower_bound(X, feas_X, opt_X)
        mi_lower = self._lower_mi_lower_bound(X, trunc_X, opt_X)

        if decoupled:
            if mi_upper.min() < mi_lower.min():
                trusted &= ~mask_evaluated[..., 0].squeeze()
                alpha = torch.where(trusted, mi_upper, float("nan"))
                self.Y_mask[0] = True
            else:
                trusted &= ~mask_evaluated[..., 1].squeeze()
                alpha = torch.where(trusted, mi_lower, float("nan"))
                self.Y_mask[1] = True
        else:
            trusted &= ~mask_evaluated[..., 0].squeeze()
            alpha = torch.where(trusted, mi_upper + mi_lower, float("nan"))
            self.Y_mask.fill_(True)

        # NaN entries (untrusted candidates) are ignored by the NaN-aware
        # argmin helper.
        indices = torch.unravel_index(argmin(alpha, dim=None), alpha.shape)
        indices = torch.stack(indices)
        return indices, self.Y_mask


    def optimize_query(
        self,
        bounds: Tensor,
        decoupled: Tensor,
    ) -> tuple[Tensor, Tensor]:  # shape: [d_in], shape: [d_out]
        """Optimize the acquisition over ``bounds`` and return the next query.

        The sampled optima are obtained from ``_optimize_samples_query`` and
        passed to ``_optimize_alpha_query``. Every entry of ``self.Y_mask`` is
        then set to True, i.e. all outputs are requested.

        Args:
            bounds: Search-space bounds, forwarded to both helpers.
            decoupled: Accepted for interface compatibility; not read here.

        Returns:
            The next query point and ``self.Y_mask`` (updated in place).
        """
        sampled_optima = self._optimize_samples_query(bounds)
        print(sampled_optima)

        query = self._optimize_alpha_query(bounds, sampled_optima)
        self.Y_mask.fill_(True)
        print(query)

        return query, self.Y_mask


    



    



    
        

    def plot(self, X: Tensor, values: Tensor, path: str):
        """Save a filled contour plot of ``values`` to ``path``.

        The first two entries along the last dimension of ``X`` are used as
        the horizontal and vertical coordinates; ``values`` is detached from
        the autograd graph before plotting.
        """
        import matplotlib.pyplot as plt

        figure, axes = plt.subplots(figsize=(6, 5))
        horizontal, vertical = X[..., 0], X[..., 1]
        contours = axes.contourf(
            horizontal, vertical, values.detach(), levels=50
        )
        figure.colorbar(contours, ax=axes)
        figure.tight_layout()
        figure.savefig(path)

        # Release the figure so repeated calls do not accumulate open figures.
        plt.clf()
        plt.close()


    def _optimize_alpha(
        self,
        bounds: Tensor,  # shape: [2, d_in]
        opt_X: Tensor, # shape: [num_samples, d_in]
    ) -> Tensor:

        d = round(self.raw_samples ** (1 / self.d_in))
        grid = torch.stack(torch.meshgrid(
            [torch.linspace(*b, d) for b in bounds.T], indexing="ij"
        ), dim=-1).view(d**self.num_dims[0], d**self.num_dims[1], self.d_in)



    def _optimize_sample(
        self,
        bounds: Tensor,  # shape: [2, d_in]
        idx: int,
    ) -> Tensor:
        """Multi-start bilevel optimization of posterior draw ``idx``.

        For each of ``self.num_restarts`` restarts, a random upper-level start
        is drawn inside ``bounds`` and optimized with ``scipy_minimize``.
        Every evaluation of the upper-level closure re-solves the lower-level
        problem (random search over ``self.raw_samples`` points followed by a
        short ``scipy_minimize`` refinement) and returns the upper-level loss
        together with the hypergradient ``g1 - Pv`` computed below.

        Both objectives are the *negated* ``idx``-th draw returned by the
        respective model's ``rsample``.

        Args:
            bounds: Lower/upper bounds; split along the last dim by
                ``self.num_dims`` into upper- and lower-level parts.
            idx: Index into the first dimension of the ``rsample`` output.

        Returns:
            Concatenation of the upper- and lower-level variables from the
            restart with the smallest ``res.fval``; a zero vector of length
            ``self.d_in`` if no restart reports a value below ``inf``.
        """

        bounds_upper, bounds_lower = bounds.split(self.num_dims, dim=-1)
        # Bounds dicts use the same key ("X") as the parameter dicts that are
        # passed to scipy_minimize below.
        bounds_upper = {"X": (bounds_upper[0], bounds_upper[1])}
        bounds_lower = {"X": (bounds_lower[0], bounds_lower[1])}

        def Y_upper(X_upper: Tensor, X_lower: Tensor) -> Tensor:
            # Negated idx-th draw of the upper-level model at [X_upper, X_lower].
            X_upper = torch.atleast_2d(X_upper)
            X_lower = torch.atleast_2d(X_lower)
            X = torch.cat([X_upper, X_lower], dim=-1)
            Y = self.model_Y_upper.rsample(X, self.num_samples)
            return -Y[idx].squeeze()

        def Y_lower(X_upper: Tensor, X_lower: Tensor) -> Tensor:
            # Negated idx-th draw of the lower-level model at [X_upper, X_lower].
            X_upper = torch.atleast_2d(X_upper)
            X_lower = torch.atleast_2d(X_lower)
            X = torch.cat([X_upper, X_lower], dim=-1)
            Y = self.model_Y_lower.rsample(X, self.num_samples)
            return -Y[idx].squeeze()

        opt_X = torch.zeros(self.d_in, dtype=torch.double)
        best_val = float("inf")
        for n in range(self.num_restarts):
            # Random upper-level start, scaled into the upper-level bounds.
            X_upper = unnormalize(torch.rand(
                self.num_dims[0], dtype=torch.double
            ), bounds=bounds[:, :self.num_dims[0]])
            p_upper = {"X": X_upper.requires_grad_(True)}
            # Placeholder entry; replaced inside `closure` before the inner
            # optimization runs.
            p_lower = {"X": torch.rand(self.num_dims[1])}

            def closure():
                # Lower-level initialization: random search with the
                # upper-level variables held fixed.
                _X_upper = X_upper.unsqueeze(0).expand(self.raw_samples, -1)
                _X_lower = unnormalize(torch.rand(
                    self.raw_samples, self.num_dims[1], dtype=torch.double,
                ), bounds=bounds[:, self.num_dims[0]:])
                # NOTE: this `idx` is local to `closure`; Y_upper / Y_lower
                # still read the method argument `idx`.
                idx = Y_lower(_X_upper, _X_lower).argmin()
                p_lower["X"] = _X_lower[idx].requires_grad_(True)

                def _closure():
                    # Lower-level loss and its gradient w.r.t. the lower-level
                    # variables only.
                    _loss = Y_lower(X_upper, p_lower["X"])
                    _grad, = torch.autograd.grad(
                        outputs=_loss, inputs=p_lower["X"]
                    )
                    return _loss.detach(), _grad.detach()
                
                # Short refinement of the lower-level solution; the result
                # object is not inspected, only p_lower["X"] is used afterwards.
                _res = scipy_minimize(
                    closure=_closure,
                    parameters=p_lower,
                    bounds=bounds_lower,
                    options={"maxfun": 5},
                )

                # compute hypergrad
                # g1: gradient of the upper loss w.r.t. the upper variables.
                g1, = torch.autograd.grad(
                    outputs=Y_upper(p_upper["X"], p_lower["X"]),
                    inputs=p_upper["X"], retain_graph=True,
                )
                # g2: gradient of the upper loss w.r.t. the lower variables.
                g2, = torch.autograd.grad(
                    outputs=Y_upper(p_upper["X"], p_lower["X"]),
                    inputs=p_lower["X"], create_graph=True,
                )
                # H: Hessian of the lower loss w.r.t. the lower variables.
                H = torch.autograd.functional.hessian(
                    func=(lambda X: Y_lower(p_upper["X"], X)),
                    inputs=p_lower["X"],
                )
                # v solves H v = g2.
                v = torch.linalg.solve(H, g2)
                # g3: gradient of the lower loss w.r.t. the lower variables,
                # with the graph kept so it can be differentiated again.
                g3, = torch.autograd.grad(
                    outputs=Y_lower(p_upper["X"], p_lower["X"]),
                    inputs=p_lower["X"], create_graph=True,
                )
                # Pv: derivative of (g3 . v) w.r.t. the upper variables, i.e.
                # the mixed second derivative of the lower loss applied to v.
                Pv, = torch.autograd.grad(
                    outputs=g3, inputs=p_upper["X"],
                    grad_outputs=v, retain_graph=True,
                )
                hypergrad = g1 - Pv
                loss = Y_upper(p_upper["X"], p_lower["X"])
                p_upper["X"].grad = hypergrad
                return loss.detach(), hypergrad.detach()

            res = scipy_minimize(
                closure=closure,
                parameters=p_upper,
                bounds=bounds_upper,
                options={"maxfun": 5},
            )
            # Keep the restart with the smallest reported objective value.
            if res.fval < best_val:
                best_val = res.fval
                opt_X = torch.cat(
                    [p_upper["X"].detach(), p_lower["X"].detach()]
                )
                
        return opt_X


    def _optimize_alpha(
        self,
        bounds: Tensor,  # shape: [2, d_in]
        decoupled: bool,
        opt_X: Tensor,  # shape: [num_samples, d_in]
    ) -> Tensor:
        """Multi-start hypergradient optimization of the acquisition value.

        The objective is the sum of ``_upper_mi_lower_bound`` and
        ``_lower_mi_lower_bound`` at a full input ``X``. Each evaluation of
        the outer closure first re-solves one lower-level problem per
        posterior draw (random search plus a short ``scipy_minimize``
        refinement) and then forms the hypergradient ``g1 - Pv``.

        NOTE: an earlier method with this same name exists in this class
        body; because this definition comes later, it is the one bound on the
        class.

        Args:
            bounds: Lower/upper bounds of the full input.
            decoupled: Not read anywhere in this method.
            opt_X: Sampled optima, one row per posterior draw.

        Returns:
            The optimized ``X`` of the restart with the smallest ``res.fval``;
            a zero vector of length ``self.d_in`` if no restart reports a
            value below ``inf``.
        """

        _, bounds_lower = bounds.split(self.num_dims, dim=-1)
        # The outer problem optimizes the full input, hence the full bounds;
        # the inner problems only move the lower-level variables.
        bounds_upper = {"X": (bounds[0], bounds[1])}
        bounds_lower = {"X": (bounds_lower[0], bounds_lower[1])}

        def Y_upper(X: Tensor, X_lower: Tensor, opt_X: Tensor) -> Tensor:
            # Acquisition value at X given the per-draw lower-level solutions.
            X = torch.atleast_2d(X)
            X_lower = torch.atleast_2d(X_lower)
            # Upper part of X repeated per draw, paired with that draw's
            # lower-level solution.
            feas_X = torch.cat([
                X[:, :self.num_dims[0]].expand(self.num_samples, -1),
                X_lower
            ], dim=-1)
            # Upper part of each sampled optimum, paired with the lower part
            # of X repeated per draw.
            perp_X = torch.cat([
                opt_X[:, :self.num_dims[0]],
                X[:, self.num_dims[0]:].expand(self.num_samples, -1),
            ], dim=-1)
            mi_upper = self._upper_mi_lower_bound(X, feas_X, opt_X)
            mi_lower = self._lower_mi_lower_bound(X, perp_X, opt_X)
            return mi_upper + mi_lower

        def Y_lower(X: Tensor, X_lower: Tensor, idx: int) -> Tensor:
            # Negated idx-th draw of the lower-level model, evaluated at the
            # upper part of X combined with X_lower.
            X_upper = torch.atleast_2d(X)[:, :self.num_dims[0]]
            X_lower = torch.atleast_2d(X_lower)
            X = torch.cat([X_upper, X_lower], dim=-1)
            Y = self.model_Y_lower.rsample(X, self.num_samples)
            return -Y[idx].squeeze()

        next_X = torch.zeros(self.d_in, dtype=torch.double)
        best_val = float("inf")
        for n in range(self.num_restarts):
            print(f"restart: {n}")
            # Random start for the full input, scaled into the bounds.
            X = unnormalize(torch.rand(self.d_in, dtype=torch.double),
                            bounds=bounds)
            p_upper = {"X": X.requires_grad_(True)}
            # Placeholder entry; replaced inside `closure` before each inner
            # optimization runs.
            p_lower = {"X": torch.rand(self.num_dims[1])}

            def closure():
                # Solve one lower-level problem per posterior draw.
                X_lower = []
                for i in range(self.num_samples):
                    # Random-search initialization for draw i.
                    _X = X.unsqueeze(0).expand(self.raw_samples, -1)
                    _X_lower = unnormalize(torch.rand(
                        self.raw_samples, self.num_dims[1], dtype=torch.double,
                    ), bounds=bounds[:, self.num_dims[0]:])
                    idx = Y_lower(_X, _X_lower, i).argmin()
                    p_lower["X"] = _X_lower[idx].requires_grad_(True)

                    def _closure():
                        # Lower-level loss of draw i and its gradient w.r.t.
                        # the lower-level variables.
                        _loss = Y_lower(X, p_lower["X"], i)
                        _grad, = torch.autograd.grad(
                            outputs=_loss, inputs=p_lower["X"]
                        )
                        return _loss.detach(), _grad.detach()

                    # NOTE(review): from the second outer closure call on,
                    # p_lower also holds the "X_all" entry stored below, and
                    # the whole dict is passed as `parameters` here — confirm
                    # this is intended.
                    _res = scipy_minimize(
                        closure=_closure,
                        parameters=p_lower,
                        bounds=bounds_lower,
                        options={"maxfun": 5},
                    )
                    X_lower.append(p_lower["X"].detach())
                # One row per draw.
                X_lower = torch.stack(X_lower)
                p_lower["X_all"] = X_lower.requires_grad_(True)
                
                # compute hypergrad
                # g1: gradient of the acquisition w.r.t. X.
                g1, = torch.autograd.grad(
                    outputs=Y_upper(p_upper["X"], p_lower["X_all"], opt_X),
                    inputs=p_upper["X"], retain_graph=True,
                )
                # g2: gradient of the acquisition w.r.t. all lower-level
                # solutions (same shape as p_lower["X_all"]).
                g2, = torch.autograd.grad(
                    outputs=Y_upper(p_upper["X"], p_lower["X_all"], opt_X),
                    inputs=p_lower["X_all"], create_graph=True,
                )

                def hessian(X: Tensor, X_lower: Tensor) -> Tensor:
                    # Block-diagonal matrix of the per-draw Hessians of the
                    # lower loss w.r.t. that draw's lower-level variables.
                    hes = []
                    for idx in range(self.num_samples):
                        h = torch.autograd.functional.hessian(
                            func=(lambda x_lower: Y_lower(X, x_lower, idx)),
                            inputs=X_lower[idx],
                        ) # shape: [num_dims[1], num_dims[1]]
                        hes.append(h)
                    hes = torch.block_diag(*hes)
                    return hes

                H = hessian(p_upper["X"], p_lower["X_all"])
                # NOTE(review): H is square with num_samples * num_dims[1]
                # rows while g2 has one row per draw; these shapes appear to
                # be compatible for linalg.solve only when num_dims[1] == 1 —
                # confirm, or flatten g2 before solving.
                v = torch.linalg.solve(H, g2)

                def grad(X: Tensor, X_lower: Tensor) -> Tensor:
                    # Per-draw gradients of the lower loss w.r.t. the
                    # lower-level variables, with the graph kept so they can
                    # be differentiated again w.r.t. X.
                    grads = []
                    for idx in range(self.num_samples):
                        x_lower = X_lower[idx]
                        g, = torch.autograd.grad(
                            outputs=Y_lower(X, x_lower, idx),
                            inputs=x_lower, create_graph=True,
                        )
                        grads.append(g)
                    grads = torch.stack(grads)
                    return grads
                
                g3 = grad(p_upper["X"], p_lower["X_all"])
                # Pv: derivative of (g3 . v) w.r.t. X.
                Pv, = torch.autograd.grad(
                    outputs=g3, inputs=p_upper["X"],
                    grad_outputs=v, retain_graph=True,
                )
                hypergrad = g1 - Pv
                loss = Y_upper(p_upper["X"], p_lower["X_all"], opt_X)
                p_upper["X"].grad = hypergrad
                print(p_upper["X"].detach(), loss.detach())
                return loss.detach(), hypergrad.detach()

            res = scipy_minimize(
                closure=closure,
                parameters=p_upper,
                bounds=bounds_upper,
                options={"maxfun": 10},
            )
            # Keep the restart with the smallest reported objective value.
            if res.fval < best_val:
                best_val = res.fval
                next_X = p_upper["X"].detach()

        return next_X