# =============================================================================
# Joint Entropy Search
# =============================================================================

import torch
from torch import Tensor
from botorch.utils.transforms import unnormalize
from botorch.optim.core import scipy_minimize

from acquisitions.base import BiLevelAcquisition
from utils import RFFHybridModel, RFFModelList


# -----------------------------------------------------------------------------
# Bi-Level Joint Entropy Search (BLJES)
# -----------------------------------------------------------------------------

class BiLevelJointEntropySearch(BiLevelAcquisition):

    def __init__(
        self,
        num_dims: list[int],
        model_Y_upper: RFFModelList,
        model_Y_lower: RFFModelList,
        model_C_upper: RFFModelList | None = None,
        model_C_lower: RFFModelList | None = None,
        num_samples: int = 64,
        noisy_obs: bool = True,
        num_restarts: int | None = None,
        raw_samples: int | None = None,
    ) -> None:
        """Set up the acquisition.

        The dimension spec, the objective/constraint models and the
        optimiser settings (``num_restarts``, ``raw_samples``) are forwarded
        unchanged to the base class; only the two settings specific to this
        acquisition are stored here.

        Args:
            num_dims: Input dimensions per level; used elsewhere in this class
                as ``[upper, lower]`` split sizes.
            model_Y_upper: Upper-level objective model.
            model_Y_lower: Lower-level objective model.
            model_C_upper: Optional upper-level constraint model.
            model_C_lower: Optional lower-level constraint model.
            num_samples: Number of posterior samples requested from the
                models via ``rsample``.
            noisy_obs: Whether the likelihood noise of the objective models
                is added to samples and variances.
            num_restarts: Forwarded to the base class.
            raw_samples: Forwarded to the base class.
        """
        base_kwargs = {
            "num_dims": num_dims,
            "model_Y_upper": model_Y_upper,
            "model_Y_lower": model_Y_lower,
            "model_C_upper": model_C_upper,
            "model_C_lower": model_C_lower,
            "num_restarts": num_restarts,
            "raw_samples": raw_samples,
        }
        super().__init__(**base_kwargs)

        # Settings that belong to this acquisition only.
        self.noisy_obs = noisy_obs
        self.num_samples = num_samples

    
    def optimize_pool(
        self,
        candidates: Tensor,  # shape: [*batch_shape, d_in]
        mask_evaluated: Tensor,  # shape: [*batch_shape, d_out]
        decoupled: bool = False,
    ) -> tuple[Tensor, Tensor]:  # shape: [2], shape: [d_out]
        """Pick the next candidate from a finite pool.

        Posterior samples of both objectives are drawn on the pool, a
        per-sample optimum is read off the samples, and a sample-averaged
        log-ratio score (``mi_upper`` / ``mi_lower``) is computed for every
        pool entry. The entry with the largest score among the still
        admissible ones is returned.

        NOTE(review): the ``gather``/``argmax`` axes below imply that the
        sample tensors are ``[num_samples, b1, b2]``, i.e. that ``candidates``
        is a 2-D grid ``[b1, b2, d_in]`` - confirm against the caller.

        Args:
            candidates: Pool of inputs; entries containing NaN are excluded.
            mask_evaluated: Boolean flags of outputs already evaluated per
                pool entry; flagged entries are excluded from selection.
            decoupled: If True, only the level with the larger maximum score
                is selected and only its outputs are flagged in ``Y_mask``;
                otherwise the summed score is used and all outputs are
                flagged.

        Returns:
            The index of the selected pool entry (stacked result of
            ``torch.unravel_index``) and ``self.Y_mask``.
        """

        mask = candidates.isnan().any(dim=-1)
        trusted = ~mask.clone()

        # Lower level: per sample, argmax of the trusted lower-level sample
        # along axis 2.
        Y_lower = self.model_Y_lower.rsample(candidates, self.num_samples)
        if self.model_C_lower is not None:
            C_lower = self.model_C_lower.rsample(candidates, self.num_samples)
            trusted &= (C_lower >= 0.0).all(dim=-1)
        _Y_lower = torch.where(trusted, Y_lower.squeeze(-1), float("-inf"))
        feas_idx = _Y_lower.argmax(2, keepdim=True)
        feas_Y_lower = _Y_lower.gather(2, feas_idx)

        # Upper level: among the lower-level argmax entries, argmax of the
        # upper-level sample along axis 1.
        Y_upper = self.model_Y_upper.rsample(candidates, self.num_samples)
        if self.model_C_upper is not None:
            C_upper = self.model_C_upper.rsample(candidates, self.num_samples)
            trusted &= (C_upper >= 0.0).all(dim=-1)
        _Y_upper = torch.where(trusted, Y_upper.squeeze(-1), float("-inf"))
        feas_Y_upper = _Y_upper.gather(2, feas_idx)
        opt_idx = feas_Y_upper.argmax(1, keepdim=True)
        opt_Y_upper = feas_Y_upper.gather(1, opt_idx)
        opt_Y_lower = feas_Y_lower.gather(1, opt_idx)

        # Posterior moments, repeated along a leading sample axis.
        m_Y_upper = self.model_Y_upper.mean(candidates).squeeze(-1)
        v_Y_upper = self.model_Y_upper.var(candidates).squeeze(-1)
        m_Y_lower = self.model_Y_lower.mean(candidates).squeeze(-1)
        v_Y_lower = self.model_Y_lower.var(candidates).squeeze(-1)
        _m_Y_upper = m_Y_upper.unsqueeze(0).expand(self.num_samples, -1, -1)
        _v_Y_upper = v_Y_upper.unsqueeze(0).expand(self.num_samples, -1, -1)
        _m_Y_lower = m_Y_lower.unsqueeze(0).expand(self.num_samples, -1, -1)
        _v_Y_lower = v_Y_lower.unsqueeze(0).expand(self.num_samples, -1, -1)

        if self.noisy_obs:
            n_Y_upper = self.model_Y_upper.rff_models[0].model.likelihood.noise
            n_Y_lower = self.model_Y_lower.rff_models[0].model.likelihood.noise
        else:
            # Zero noise in the dtype/device of the variances: these tensors
            # are used as ``src`` of ``scatter_`` below, which requires the
            # same dtype as the destination.
            n_Y_upper = _v_Y_upper.new_zeros((1, 1, 1))
            n_Y_lower = _v_Y_lower.new_zeros((1, 1, 1))
        # Observation noise is zero-mean Gaussian with variance ``n_Y_*``
        # (the same variance is added to the predictive variances right
        # below), so it has to be drawn with ``randn_like``.
        _Y_upper = _Y_upper + n_Y_upper.sqrt() * torch.randn_like(_Y_upper)
        _Y_lower = _Y_lower + n_Y_lower.sqrt() * torch.randn_like(_Y_lower)
        _v_Y_upper = _v_Y_upper + n_Y_upper
        _v_Y_lower = _v_Y_lower + n_Y_lower

        # NOTE: the constraint models only enter through ``trusted`` above;
        # they are not used in the score computed below.

        feas_m_Y_upper = _m_Y_upper.gather(2, feas_idx)  # shape: [num_samples, b1, 1]
        feas_v_Y_upper = _v_Y_upper.gather(2, feas_idx)  # shape: [num_samples, b1, 1]
        feas_m_Y_lower = _m_Y_lower.gather(2, feas_idx)  # shape: [num_samples, b1, 1]
        feas_v_Y_lower = _v_Y_lower.gather(2, feas_idx)  # shape: [num_samples, b1, 1]
        opt_m_Y_upper = feas_m_Y_upper.gather(1, opt_idx)  # shape: [num_samples, 1, 1]
        opt_v_Y_upper = feas_v_Y_upper.gather(1, opt_idx)  # shape: [num_samples, 1, 1]
        opt_m_Y_lower = feas_m_Y_lower.gather(1, opt_idx)  # shape: [num_samples, 1, 1]
        opt_v_Y_lower = feas_v_Y_lower.gather(1, opt_idx)  # shape: [num_samples, 1, 1]
        perp_idx = opt_idx.expand(-1, -1, mask.shape[1])
        perp_m_Y_lower = _m_Y_lower.gather(1, perp_idx)  # shape: [num_samples, 1, b2]
        perp_v_Y_lower = _v_Y_lower.gather(1, perp_idx)  # shape: [num_samples, 1, b2]

        # Inputs matching the index sets above (needed for covariances).
        cands_X = candidates.unsqueeze(0).expand(self.num_samples, -1, -1, -1)
        _feas_idx = feas_idx.unsqueeze(-1).expand(-1, -1, -1, cands_X.size(-1))
        _opt_idx = opt_idx.unsqueeze(-1).expand(-1, -1, -1, cands_X.size(-1))
        _perp_idx = _opt_idx.expand(-1, -1, mask.shape[1], -1)
        feas_X = cands_X.gather(2, _feas_idx)
        opt_X = feas_X.gather(1, _opt_idx)
        perp_X = cands_X.gather(1, _perp_idx)

        # ---- upper level ----------------------------------------------------
        cov1_upper, cov2_upper, cov3_upper = [], [], []
        for cands_x, feas_x, opt_x in zip(cands_X, feas_X, opt_X):
            cov1 = self.model_Y_upper.cov(feas_x, cands_x).squeeze(-1)
            cov2 = self.model_Y_upper.cov(feas_x, opt_x).squeeze(-1)
            cov3 = self.model_Y_upper.cov(opt_x, cands_x).squeeze(-1)
            cov1 = cov1.diagonal(0, 0, 2).permute(2, 0, 1)
            cov2 = cov2.squeeze(2).expand_as(cov1)
            cov1_upper.append(cov1.squeeze())
            cov2_upper.append(cov2.squeeze())
            cov3_upper.append(cov3.squeeze())
        cov1_upper = torch.stack(cov1_upper)
        cov2_upper = torch.stack(cov2_upper)
        cov3_upper = torch.stack(cov3_upper)
        print(f"cov1_upper - nan: {cov1_upper.isnan().sum()}, inf: {cov1_upper.isinf().sum()}")
        print(f"cov2_upper - nan: {cov2_upper.isnan().sum()}, inf: {cov2_upper.isinf().sum()}")
        print(f"cov3_upper - nan: {cov3_upper.isnan().sum()}, inf: {cov3_upper.isinf().sum()}")
        _opt_v_Y_upper = opt_v_Y_upper.expand(-1, *mask.shape)
        # (m1, v1): moments at the lower-level argmax entry given both the
        # candidate observation and the sampled optimum (2x2 system).
        covvec = torch.stack([cov1_upper, cov2_upper], dim=-1).unsqueeze(-1)
        covmat = torch.stack([
            torch.stack([_v_Y_upper, cov3_upper], dim=-1),
            torch.stack([cov3_upper, _opt_v_Y_upper], dim=-1),
        ], dim=-2)
        coeff = covvec.transpose(-2, -1) @ torch.linalg.pinv(covmat)
        diff = torch.stack([
            (_Y_upper - _m_Y_upper),
            (opt_Y_upper - opt_m_Y_upper).expand(-1, *mask.shape),
        ], dim=-1).unsqueeze(-1)
        m1_upper = feas_m_Y_upper + (coeff @ diff).squeeze()
        m1_upper.scatter_(1, perp_idx, opt_Y_upper.expand_as(perp_idx))
        if not self.noisy_obs:
            m1_upper.scatter_(2, feas_idx, feas_Y_upper)
        v1_upper = feas_v_Y_upper - (coeff @ covvec).squeeze()
        v1_upper.scatter_(2, feas_idx, n_Y_upper.expand_as(feas_idx))
        v1_upper.scatter_(1, perp_idx, 0.0)
        print(f"m1_upper - nan: {m1_upper.isnan().sum()}, inf: {m1_upper.isinf().sum()}")
        print(f"v1_upper - nan: {v1_upper.isnan().sum()}, neg: {(v1_upper < 0.0).sum()}, zero: {(v1_upper == 0.0).sum()}")
        # (m2, v2): same entry, given the sampled optimum only.
        coeff = cov2_upper / _opt_v_Y_upper
        diff = opt_Y_upper - opt_m_Y_upper
        m2_upper = feas_m_Y_upper + coeff * diff
        m2_upper.scatter_(1, perp_idx, opt_Y_upper.expand_as(perp_idx))
        v2_upper = feas_v_Y_upper - coeff * cov2_upper
        v2_upper.scatter_(1, perp_idx, 0.0)
        print(f"m2_upper - nan: {m2_upper.isnan().sum()}, inf: {m2_upper.isinf().sum()}")
        print(f"v2_upper - nan: {v2_upper.isnan().sum()}, neg: {(v2_upper < 0.0).sum()}, zero: {(v2_upper == 0.0).sum()}")
        # (m3, v3): the candidate itself, given the sampled optimum.
        coeff = cov3_upper / _opt_v_Y_upper
        m3_upper = _m_Y_upper + coeff * diff
        m3_upper.scatter_(1, perp_idx, opt_Y_upper.expand_as(perp_idx))
        v3_upper = _v_Y_upper - coeff * cov3_upper
        v3_upper.scatter_(1, perp_idx, n_Y_upper.expand_as(perp_idx))
        print(f"m3_upper - nan: {m3_upper.isnan().sum()}, inf: {m3_upper.isinf().sum()}")
        print(f"v3_upper - nan: {v3_upper.isnan().sum()}, neg: {(v3_upper < 0.0).sum()}, zero: {(v3_upper == 0.0).sum()}")
        g1_upper = (opt_Y_upper - m1_upper) / v1_upper.sqrt()
        g2_upper = (opt_Y_upper - m2_upper) / v2_upper.sqrt()
        g3_upper = (_Y_upper - m3_upper) / v3_upper.sqrt()
        g4_upper = (_Y_upper - _m_Y_upper) / _v_Y_upper.sqrt()
        print(f"g1_upper - nan: {g1_upper.isnan().sum()}, inf: {g1_upper.isinf().sum()}")
        print(f"g2_upper - nan: {g2_upper.isnan().sum()}, inf: {g2_upper.isinf().sum()}")
        print(f"g3_upper - nan: {g3_upper.isnan().sum()}, inf: {g3_upper.isinf().sum()}")
        print(f"g4_upper - nan: {g4_upper.isnan().sum()}, inf: {g4_upper.isinf().sum()}")
        normal = torch.distributions.Normal(0.0, 1.0)
        # NaNs (e.g. from negative variances) are mapped to -inf, i.e. cdf 0.
        p1_upper = normal.cdf(g1_upper.nan_to_num(nan=float("-inf")))
        p2_upper = normal.cdf(g2_upper.nan_to_num(nan=float("-inf")))
        p3_upper = normal.cdf(g3_upper.nan_to_num(nan=float("-inf"))) / v3_upper.sqrt()
        p4_upper = normal.cdf(g4_upper.nan_to_num(nan=float("-inf"))) / _v_Y_upper.sqrt()
        q_upper = (p1_upper * p3_upper) / (p2_upper * p4_upper)
        mi_upper = torch.log(q_upper.clamp_min_(1e-12)).nanmean(dim=0)
        mi_upper.clamp_min_(0.0)

        # ---- lower level ----------------------------------------------------
        cov1_lower, cov2_lower, cov3_lower = [], [], []
        for cands_x, perp_x, opt_x in zip(cands_X, perp_X, opt_X):
            cov1 = self.model_Y_lower.cov(perp_x, cands_x).squeeze(-1)
            cov2 = self.model_Y_lower.cov(perp_x, opt_x).squeeze(-1)
            cov3 = self.model_Y_lower.cov(opt_x, cands_x).squeeze(-1)
            cov1 = cov1.diagonal(0, 1, 3).permute(1, 2, 0)
            cov2 = cov2.squeeze(2).expand_as(cov1)
            cov1_lower.append(cov1.squeeze())
            cov2_lower.append(cov2.squeeze())
            cov3_lower.append(cov3.squeeze())
        cov1_lower = torch.stack(cov1_lower)
        cov2_lower = torch.stack(cov2_lower)
        cov3_lower = torch.stack(cov3_lower)
        print(f"cov1_lower - nan: {cov1_lower.isnan().sum()}, inf: {cov1_lower.isinf().sum()}")
        print(f"cov2_lower - nan: {cov2_lower.isnan().sum()}, inf: {cov2_lower.isinf().sum()}")
        print(f"cov3_lower - nan: {cov3_lower.isnan().sum()}, inf: {cov3_lower.isinf().sum()}")
        hori_idx = feas_idx.gather(1, opt_idx).expand(-1, mask.shape[0], -1)
        _opt_v_Y_lower = opt_v_Y_lower.expand(-1, *mask.shape)
        covvec = torch.stack([cov1_lower, cov2_lower], dim=-1).unsqueeze(-1)
        covmat = torch.stack([
            torch.stack([_v_Y_lower, cov3_lower], dim=-1),
            torch.stack([cov3_lower, _opt_v_Y_lower], dim=-1),
        ], dim=-2)
        coeff = covvec.transpose(-2, -1) @ torch.linalg.pinv(covmat)
        diff = torch.stack([
            (_Y_lower - _m_Y_lower),
            (opt_Y_lower - opt_m_Y_lower).expand(-1, *mask.shape),
        ], dim=-1).unsqueeze(-1)
        m1_lower = perp_m_Y_lower + (coeff @ diff).squeeze()
        if not self.noisy_obs:
            m1_lower.scatter_(1, perp_idx, _Y_lower.gather(1, perp_idx))
        m1_lower.scatter_(2, hori_idx, opt_Y_lower.expand_as(hori_idx))
        v1_lower = perp_v_Y_lower - (coeff @ covvec).squeeze()
        v1_lower.scatter_(1, perp_idx, n_Y_lower.expand_as(perp_idx))
        v1_lower.scatter_(2, hori_idx, 0.0).clamp_min_(0.0)
        print(f"m1_lower - nan: {m1_lower.isnan().sum()}, inf: {m1_lower.isinf().sum()}")
        print(f"v1_lower - nan: {v1_lower.isnan().sum()}, neg: {(v1_lower < 0.0).sum()}")
        coeff = cov2_lower / _opt_v_Y_lower
        diff = opt_Y_lower - opt_m_Y_lower
        m2_lower = perp_m_Y_lower + coeff * diff
        m2_lower.scatter_(2, hori_idx, opt_Y_lower.expand_as(hori_idx))
        v2_lower = perp_v_Y_lower - coeff * cov2_lower
        v2_lower.scatter_(2, hori_idx, 0.0)
        print(f"m2_lower - nan: {m2_lower.isnan().sum()}, inf: {m2_lower.isinf().sum()}")
        print(f"v2_lower - nan: {v2_lower.isnan().sum()}, neg: {(v2_lower < 0.0).sum()}")
        coeff = cov3_lower / _opt_v_Y_lower
        m3_lower = _m_Y_lower + coeff * diff
        m3_lower.scatter_(1, opt_idx, opt_Y_lower)
        v3_lower = _v_Y_lower - coeff * cov3_lower
        v3_lower.scatter_(1, opt_idx, n_Y_lower.expand_as(opt_idx))
        v3_lower.clamp_min_(0.0)
        print(f"m3_lower - nan: {m3_lower.isnan().sum()}, inf: {m3_lower.isinf().sum()}")
        print(f"v3_lower - nan: {v3_lower.isnan().sum()}, neg: {(v3_lower < 0.0).sum()}")
        g1_lower = (opt_Y_lower - m1_lower) / v1_lower.sqrt()
        g2_lower = (opt_Y_lower - m2_lower) / v2_lower.sqrt()
        g3_lower = (_Y_lower - m3_lower) / v3_lower.sqrt()
        g4_lower = (_Y_lower - _m_Y_lower) / _v_Y_lower.sqrt()
        print(f"g1_lower - nan: {g1_lower.isnan().sum()}, inf: {g1_lower.isinf().sum()}")
        print(f"g2_lower - nan: {g2_lower.isnan().sum()}, inf: {g2_lower.isinf().sum()}")
        print(f"g3_lower - nan: {g3_lower.isnan().sum()}, inf: {g3_lower.isinf().sum()}")
        print(f"g4_lower - nan: {g4_lower.isnan().sum()}, inf: {g4_lower.isinf().sum()}")
        normal = torch.distributions.Normal(0.0, 1.0)
        p1_lower = normal.cdf(g1_lower.nan_to_num(nan=float("-inf")))
        p2_lower = normal.cdf(g2_lower.nan_to_num(nan=float("-inf")))
        p3_lower = normal.cdf(g3_lower.nan_to_num(nan=float("-inf"))) / v3_lower.sqrt()
        p4_lower = normal.cdf(g4_lower.nan_to_num(nan=float("-inf"))) / _v_Y_lower.sqrt()
        q_lower = (p1_lower * p3_lower) / (p2_lower * p4_lower)
        mi_lower = torch.log(q_lower.clamp_min_(1e-12)).nanmean(dim=0)
        mi_lower.clamp_min_(0.0)

        # ---- selection ------------------------------------------------------
        if decoupled:
            # Query only the level whose best score is larger.
            alpha = mi_upper if mi_upper.max() > mi_lower.max() else mi_lower
            idx = 0 if mi_upper.max() > mi_lower.max() else 1
            idx_con = (list(range(2, 2+self.num_constraints[0])) if idx == 0
                       else list(range(2+self.num_constraints[0], self.d_out)))
            trusted &= ~mask_evaluated[..., idx].squeeze()
            alpha = torch.where(trusted, alpha, float("-inf"))
            indices = torch.unravel_index(alpha.argmax(), alpha.shape)
            indices = torch.stack(indices)
            self.Y_mask[idx] = True
            self.Y_mask[idx_con] = True
        else:
            trusted &= ~mask_evaluated[..., 0].squeeze()
            alpha = torch.where(trusted, mi_upper + mi_lower, float("-inf"))
            indices = torch.unravel_index(alpha.argmax(), alpha.shape)
            indices = torch.stack(indices)
            self.Y_mask.fill_(True)
        return indices, self.Y_mask


    def optimize_query(
        self,
        bounds: Tensor,  # shape: [2, d_in]
        decoupled: bool = False,
    ) -> tuple[Tensor, Tensor]:  # shape: [d_in], shape: [d_out]
        """Propose the next query point inside ``bounds``.

        One optimiser is computed per posterior sample with
        ``_optimize_sample``; the stacked optimisers are then passed to
        ``_optimize_alpha``, whose result is returned together with
        ``self.Y_mask`` (set to all-True here, regardless of ``decoupled``).
        """
        # One bilevel optimiser per posterior sample, stacked along dim 0.
        sample_optima = torch.stack([
            self._optimize_sample(bounds, sample_idx)
            for sample_idx in range(self.num_samples)
        ])
        print(sample_optima)

        query = self._optimize_alpha(bounds, decoupled, sample_optima)
        self.Y_mask.fill_(True)
        return query, self.Y_mask


    def _optimize_sample(
        self,
        bounds: Tensor,  # shape: [2, d_in]
        idx: int,
    ) -> Tensor:
        """Approximate the bilevel optimiser of posterior sample ``idx``.

        For each of ``self.num_restarts`` random upper-level starts, the
        upper-level variables are optimised with ``scipy_minimize``. Every
        evaluation of the outer closure first re-solves the lower-level
        problem (best of ``self.raw_samples`` random lower-level points,
        refined by a short inner ``scipy_minimize`` run) and then forms a
        hypergradient of the upper-level sample with respect to the
        upper-level variables.

        Args:
            bounds: Box bounds of the joint input; split into upper/lower
                parts according to ``self.num_dims``.
            idx: Index of the posterior sample to optimise.

        Returns:
            The concatenated ``[upper, lower]`` point of the restart with the
            smallest final objective value, or zeros of size ``self.d_in`` if
            no restart improved on ``inf``.
        """

        bounds_upper, bounds_lower = bounds.split(self.num_dims, dim=-1)
        bounds_upper = {"X": (bounds_upper[0], bounds_upper[1])}
        bounds_lower = {"X": (bounds_lower[0], bounds_lower[1])}

        # Both helpers return the NEGATED sample value, so minimising them
        # maximises the sampled objective.
        def Y_upper(X_upper: Tensor, X_lower: Tensor) -> Tensor:
            X_upper = torch.atleast_2d(X_upper)
            X_lower = torch.atleast_2d(X_lower)
            X = torch.cat([X_upper, X_lower], dim=-1)
            Y = self.model_Y_upper.rsample(X, self.num_samples)
            return -Y[idx].squeeze()

        def Y_lower(X_upper: Tensor, X_lower: Tensor) -> Tensor:
            X_upper = torch.atleast_2d(X_upper)
            X_lower = torch.atleast_2d(X_lower)
            X = torch.cat([X_upper, X_lower], dim=-1)
            Y = self.model_Y_lower.rsample(X, self.num_samples)
            return -Y[idx].squeeze()

        opt_X = torch.zeros(self.d_in, dtype=torch.double)
        best_val = float("inf")
        for _ in range(self.num_restarts):
            X_upper = unnormalize(torch.rand(
                self.num_dims[0], dtype=torch.double
            ), bounds=bounds[:, :self.num_dims[0]])
            p_upper = {"X": X_upper.requires_grad_(True)}
            # Placeholder; replaced on every call of ``closure``.
            p_lower = {"X": torch.rand(self.num_dims[1])}

            def closure():
                # Warm start for the lower level: best of ``raw_samples``
                # random lower-level points at the current upper-level point.
                _X_upper = X_upper.unsqueeze(0).expand(self.raw_samples, -1)
                _X_lower = unnormalize(torch.rand(
                    self.raw_samples, self.num_dims[1], dtype=torch.double,
                ), bounds=bounds[:, self.num_dims[0]:])
                # Named ``best`` so it does not shadow the sample index
                # ``idx`` of the enclosing method.
                best = Y_lower(_X_upper, _X_lower).argmin()
                p_lower["X"] = _X_lower[best].requires_grad_(True)

                def _closure():
                    _loss = Y_lower(X_upper, p_lower["X"])
                    _grad, = torch.autograd.grad(
                        outputs=_loss, inputs=p_lower["X"]
                    )
                    # The optimiser pairs parameters with a SEQUENCE of
                    # gradient tensors; a bare 1-D tensor would be iterated
                    # element-wise, so wrap it in a tuple.
                    return _loss.detach(), (_grad.detach(),)

                scipy_minimize(
                    closure=_closure,
                    parameters=p_lower,
                    bounds=bounds_lower,
                    options={"maxfun": 5},
                )

                # Hypergradient: direct gradient of the upper objective minus
                # the mixed second derivative of the lower objective applied
                # to v = H^{-1} g2 (H: lower-level Hessian, g2: gradient of
                # the upper objective w.r.t. the lower-level variables).
                g1, = torch.autograd.grad(
                    outputs=Y_upper(p_upper["X"], p_lower["X"]),
                    inputs=p_upper["X"], retain_graph=True,
                )
                g2, = torch.autograd.grad(
                    outputs=Y_upper(p_upper["X"], p_lower["X"]),
                    inputs=p_lower["X"], create_graph=True,
                )
                H = torch.autograd.functional.hessian(
                    func=(lambda X: Y_lower(p_upper["X"], X)),
                    inputs=p_lower["X"],
                )
                v = torch.linalg.solve(H, g2)
                g3, = torch.autograd.grad(
                    outputs=Y_lower(p_upper["X"], p_lower["X"]),
                    inputs=p_lower["X"], create_graph=True,
                )
                Pv, = torch.autograd.grad(
                    outputs=g3, inputs=p_upper["X"],
                    grad_outputs=v, retain_graph=True,
                )
                hypergrad = g1 - Pv
                loss = Y_upper(p_upper["X"], p_lower["X"])
                p_upper["X"].grad = hypergrad
                # Tuple for the same reason as in ``_closure``.
                return loss.detach(), (hypergrad.detach(),)

            res = scipy_minimize(
                closure=closure,
                parameters=p_upper,
                bounds=bounds_upper,
                options={"maxfun": 5},
            )
            if res.fval < best_val:
                best_val = res.fval
                opt_X = torch.cat(
                    [p_upper["X"].detach(), p_lower["X"].detach()]
                )

        return opt_X


    def _optimize_alpha(
        self,
        bounds: Tensor,  # shape: [2, d_in]
        decoupled: bool,
        opt_X: Tensor,  # shape: [num_samples, d_in]
    ) -> Tensor:
        """Optimise the acquisition over the continuous joint input.

        For each of ``self.num_restarts`` random starts, the joint input is
        optimised with ``scipy_minimize``. Every evaluation of the outer
        closure re-solves the lower-level problem once per posterior sample
        and then forms a hypergradient of the (negated) acquisition with
        respect to the joint input.

        Args:
            bounds: Box bounds of the joint input.
            decoupled: Accepted for interface compatibility; not used in
                this method.
            opt_X: Per-sample optimisers as produced by ``_optimize_sample``.

        Returns:
            The joint input of the restart with the smallest final objective
            value, or zeros of size ``self.d_in`` if no restart improved on
            ``inf``.
        """

        _, bounds_lower = bounds.split(self.num_dims, dim=-1)
        # The outer problem is over the FULL joint input, hence full bounds.
        bounds_upper = {"X": (bounds[0], bounds[1])}
        bounds_lower = {"X": (bounds_lower[0], bounds_lower[1])}

        def Y_upper(X: Tensor, X_lower: Tensor, opt_X: Tensor) -> Tensor:
            # Negated acquisition value at X.
            # NOTE(review): tensor shapes in this helper depend on the
            # conventions of the project models (mean/var/cov/rsample) and
            # could not be verified from this file - confirm.
            X = torch.atleast_2d(X)
            X_lower = torch.atleast_2d(X_lower)
            print(torch.cat([X.expand(self.num_samples, -1), X_lower, opt_X], dim=-1))
            # Upper part of X paired with the per-sample lower-level solution.
            feas_X = torch.cat([
                X[:, :self.num_dims[0]].expand(self.num_samples, -1),
                X_lower
            ], dim=-1)
            # Upper part of the per-sample optimiser paired with lower part of X.
            perp_X = torch.cat([
                opt_X[:, :self.num_dims[0]],
                X[:, self.num_dims[0]:].expand(self.num_samples, -1),
            ], dim=-1)

            Y_upper = self.model_Y_upper.rsample(X, self.num_samples).squeeze(-1)

            opt_Y_upper = self.model_Y_upper.rsample(opt_X, self.num_samples)
            print(opt_Y_upper.shape)
            opt_Y_upper = opt_Y_upper.squeeze().diag().unsqueeze(-1)
            m_Y_upper = self.model_Y_upper.mean(X).expand(self.num_samples, -1)
            v_Y_upper = self.model_Y_upper.var(X).expand(self.num_samples, -1)
            feas_m_Y_upper = self.model_Y_upper.mean(feas_X)
            feas_v_Y_upper = self.model_Y_upper.var(feas_X)
            opt_m_Y_upper = self.model_Y_upper.mean(opt_X)
            opt_v_Y_upper = self.model_Y_upper.var(opt_X)
            cov1_upper = self.model_Y_upper.cov(feas_X, X)
            cov2_upper = self.model_Y_upper.cov(feas_X, opt_X)
            cov2_upper = cov2_upper.diagonal(0, 0, 1).permute(1, 0)
            cov3_upper = self.model_Y_upper.cov(opt_X, X).squeeze(-1)
            print(f"cov1_upper - nan: {cov1_upper.isnan().sum()}, inf: {cov1_upper.isinf().sum()}"
                + f"cov2_upper - nan: {cov2_upper.isnan().sum()}, inf: {cov2_upper.isinf().sum()}"
                + f"cov3_upper - nan: {cov3_upper.isnan().sum()}, inf: {cov3_upper.isinf().sum()}")

            if self.noisy_obs:
                n_Y_upper = self.model_Y_upper.rff_models[0].model.likelihood.noise
                # Zero-mean Gaussian noise with variance ``n_Y_upper``,
                # matching the variance added on the next line.
                Y_upper = Y_upper + n_Y_upper.sqrt() * torch.randn_like(Y_upper)
                v_Y_upper = v_Y_upper + n_Y_upper

            covvec = torch.cat([cov1_upper, cov2_upper.unsqueeze(-1)], dim=-2)
            covmat = torch.stack([
                torch.cat([v_Y_upper, cov3_upper], dim=-1),
                torch.cat([cov3_upper, opt_v_Y_upper], dim=-1),
            ], dim=-2)
            coeff = covvec.transpose(-2, -1) @ torch.linalg.pinv(covmat)
            diff = torch.cat([
                (Y_upper - m_Y_upper), (opt_Y_upper - opt_m_Y_upper)
            ], dim=-1).unsqueeze(-1)
            m1_upper = feas_m_Y_upper + (coeff @ diff).squeeze(-1)
            # Conditional variance: prior variance minus k^T K^+ k. The
            # reduction term uses ``covvec``, not the residual ``diff``.
            v1_upper = feas_v_Y_upper - (coeff @ covvec).squeeze(-1)
            print(f"m1_upper - nan: {m1_upper.isnan().sum()}, inf: {m1_upper.isinf().sum()}"
                + f"v1_upper - nan: {v1_upper.isnan().sum()}, neg: {(v1_upper < 0.0).sum()}, zero: {(v1_upper == 0.0).sum()}")
            v1_upper.clamp_min_(1e-6)
            coeff = cov2_upper / opt_v_Y_upper
            diff = opt_Y_upper - opt_m_Y_upper
            m2_upper = feas_m_Y_upper + coeff * diff
            v2_upper = feas_v_Y_upper - coeff * cov2_upper
            print(f"m2_upper - nan: {m2_upper.isnan().sum()}, inf: {m2_upper.isinf().sum()}"
                + f"v2_upper - nan: {v2_upper.isnan().sum()}, neg: {(v2_upper < 0.0).sum()}, zero: {(v2_upper == 0.0).sum()}")
            v2_upper.clamp_min_(1e-6)
            coeff = cov3_upper / opt_v_Y_upper
            m3_upper = m_Y_upper + coeff * diff
            v3_upper = v_Y_upper - coeff * cov3_upper
            print(f"m3_upper - nan: {m3_upper.isnan().sum()}, inf: {m3_upper.isinf().sum()}"
                + f"v3_upper - nan: {v3_upper.isnan().sum()}, neg: {(v3_upper < 0.0).sum()}, zero: {(v3_upper == 0.0).sum()}")
            v3_upper.clamp_min_(1e-6)
            g1_upper = (opt_Y_upper - m1_upper) / v1_upper.sqrt()
            g2_upper = (opt_Y_upper - m2_upper) / v2_upper.sqrt()
            g3_upper = (Y_upper - m3_upper) / v3_upper.sqrt()
            g4_upper = (Y_upper - m_Y_upper) / v_Y_upper.sqrt()
            normal = torch.distributions.Normal(0.0, 1.0)
            p1_upper = normal.cdf(g1_upper)
            p2_upper = normal.cdf(g2_upper)
            p3_upper = normal.cdf(g3_upper) / v3_upper.sqrt()
            p4_upper = normal.cdf(g4_upper) / v_Y_upper.sqrt()
            q_upper = (p1_upper * p3_upper) / (p2_upper * p4_upper)
            mi_upper = torch.log(q_upper).nanmean(dim=0)

            Y_lower = self.model_Y_lower.rsample(X, self.num_samples).squeeze(-1)
            opt_Y_lower = self.model_Y_lower.rsample(opt_X, self.num_samples)
            opt_Y_lower = opt_Y_lower.squeeze().diag().unsqueeze(-1)
            m_Y_lower = self.model_Y_lower.mean(X).expand(self.num_samples, -1)
            v_Y_lower = self.model_Y_lower.var(X).expand(self.num_samples, -1)
            opt_m_Y_lower = self.model_Y_lower.mean(opt_X)
            opt_v_Y_lower = self.model_Y_lower.var(opt_X)
            perp_m_Y_lower = self.model_Y_lower.mean(perp_X)
            perp_v_Y_lower = self.model_Y_lower.var(perp_X)
            cov1_lower = self.model_Y_lower.cov(perp_X, X)
            cov2_lower = self.model_Y_lower.cov(perp_X, opt_X)
            cov2_lower = cov2_lower.diagonal(0, 0, 1).permute(1, 0)
            cov3_lower = self.model_Y_lower.cov(opt_X, X).squeeze(-1)

            if self.noisy_obs:
                n_Y_lower = self.model_Y_lower.rff_models[0].model.likelihood.noise
                Y_lower = Y_lower + n_Y_lower.sqrt() * torch.randn_like(Y_lower)
                v_Y_lower = v_Y_lower + n_Y_lower

            covvec = torch.cat([cov1_lower, cov2_lower.unsqueeze(-1)], dim=-2)
            covmat = torch.stack([
                torch.cat([v_Y_lower, cov3_lower], dim=-1),
                torch.cat([cov3_lower, opt_v_Y_lower], dim=-1),
            ], dim=-2)
            coeff = covvec.transpose(-2, -1) @ torch.linalg.pinv(covmat)
            diff = torch.cat([
                (Y_lower - m_Y_lower), (opt_Y_lower - opt_m_Y_lower)
            ], dim=-1).unsqueeze(-1)
            m1_lower = perp_m_Y_lower + (coeff @ diff).squeeze(-1)
            # Same variance formula as for the upper level.
            v1_lower = perp_v_Y_lower - (coeff @ covvec).squeeze(-1)
            v1_lower.clamp_min_(1e-6)
            coeff = cov2_lower / opt_v_Y_lower
            diff = opt_Y_lower - opt_m_Y_lower
            m2_lower = perp_m_Y_lower + coeff * diff
            v2_lower = perp_v_Y_lower - coeff * cov2_lower
            v2_lower.clamp_min_(1e-6)
            coeff = cov3_lower / opt_v_Y_lower
            m3_lower = m_Y_lower + coeff * diff
            v3_lower = v_Y_lower - coeff * cov3_lower
            v3_lower.clamp_min_(1e-6)
            g1_lower = (opt_Y_lower - m1_lower) / v1_lower.sqrt()
            g2_lower = (opt_Y_lower - m2_lower) / v2_lower.sqrt()
            g3_lower = (Y_lower - m3_lower) / v3_lower.sqrt()
            g4_lower = (Y_lower - m_Y_lower) / v_Y_lower.sqrt()
            normal = torch.distributions.Normal(0.0, 1.0)
            p1_lower = normal.cdf(g1_lower)
            p2_lower = normal.cdf(g2_lower)
            p3_lower = normal.cdf(g3_lower) / v3_lower.sqrt()
            p4_lower = normal.cdf(g4_lower) / v_Y_lower.sqrt()
            q_lower = (p1_lower * p3_lower) / (p2_lower * p4_lower)
            mi_lower = torch.log(q_lower).nanmean(dim=0)
            return -(mi_upper + mi_lower)

        def Y_lower(X: Tensor, X_lower: Tensor, idx: int) -> Tensor:
            # Negated lower-level sample ``idx`` at (upper part of X, X_lower).
            X_upper = torch.atleast_2d(X)[:, :self.num_dims[0]]
            X_lower = torch.atleast_2d(X_lower)
            X = torch.cat([X_upper, X_lower], dim=-1)
            Y = self.model_Y_lower.rsample(X, self.num_samples)
            return -Y[idx].squeeze()

        next_X = torch.zeros(self.d_in, dtype=torch.double)
        best_val = float("inf")
        for n in range(self.num_restarts):
            print(f"restart: {n}")
            X = unnormalize(torch.rand(self.d_in, dtype=torch.double),
                            bounds=bounds)
            p_upper = {"X": X.requires_grad_(True)}
            # Placeholder; replaced for every sample inside ``closure``.
            p_lower = {"X": torch.rand(self.num_dims[1])}

            def closure():
                # Re-solve the lower-level problem once per posterior sample.
                X_lower = []
                for i in range(self.num_samples):
                    _X = X.unsqueeze(0).expand(self.raw_samples, -1)
                    _X_lower = unnormalize(torch.rand(
                        self.raw_samples, self.num_dims[1], dtype=torch.double,
                    ), bounds=bounds[:, self.num_dims[0]:])
                    best = Y_lower(_X, _X_lower, i).argmin()
                    p_lower["X"] = _X_lower[best].requires_grad_(True)

                    def _closure():
                        _loss = Y_lower(X, p_lower["X"], i)
                        _grad, = torch.autograd.grad(
                            outputs=_loss, inputs=p_lower["X"]
                        )
                        # The optimiser pairs parameters with a SEQUENCE of
                        # gradient tensors; wrap the single gradient.
                        return _loss.detach(), (_grad.detach(),)

                    scipy_minimize(
                        closure=_closure,
                        parameters=p_lower,
                        bounds=bounds_lower,
                        options={"maxfun": 5},
                    )
                    X_lower.append(p_lower["X"].detach())
                # Kept as a local (not inside ``p_lower``) so that it is not
                # picked up as an extra parameter by the inner optimiser on
                # later calls of this closure.
                X_lower_all = torch.stack(X_lower).requires_grad_(True)

                # compute hypergrad
                g1, = torch.autograd.grad(
                    outputs=Y_upper(p_upper["X"], X_lower_all, opt_X),
                    inputs=p_upper["X"], retain_graph=True,
                )
                g2, = torch.autograd.grad(
                    outputs=Y_upper(p_upper["X"], X_lower_all, opt_X),
                    inputs=X_lower_all, create_graph=True,
                )

                def hessian(X: Tensor, X_lower: Tensor) -> Tensor:
                    # Block-diagonal Hessian: one lower-level block per sample.
                    hes = []
                    for idx in range(self.num_samples):
                        h = torch.autograd.functional.hessian(
                            func=(lambda x_lower: Y_lower(X, x_lower, idx)),
                            inputs=X_lower[idx],
                        ) # shape: [num_dims[1], num_dims[1]]
                        hes.append(h)
                    hes = torch.block_diag(*hes)
                    return hes

                H = hessian(p_upper["X"], X_lower_all)
                # H is [S*d, S*d] while g2 is [S, d]: solve against the
                # flattened right-hand side and restore the [S, d] layout,
                # which is what ``grad_outputs`` below has to match.
                v = torch.linalg.solve(H, g2.reshape(-1)).reshape_as(g2)

                def grad(X: Tensor, X_lower: Tensor) -> Tensor:
                    # Per-sample lower-level gradients, stacked to [S, d].
                    grads = []
                    for idx in range(self.num_samples):
                        x_lower = X_lower[idx]
                        g, = torch.autograd.grad(
                            outputs=Y_lower(X, x_lower, idx),
                            inputs=x_lower, create_graph=True,
                        )
                        grads.append(g)
                    grads = torch.stack(grads)
                    return grads

                g3 = grad(p_upper["X"], X_lower_all)
                Pv, = torch.autograd.grad(
                    outputs=g3, inputs=p_upper["X"],
                    grad_outputs=v, retain_graph=True,
                )
                hypergrad = g1 - Pv
                loss = Y_upper(p_upper["X"], X_lower_all, opt_X)
                p_upper["X"].grad = hypergrad
                print(p_upper["X"].detach(), loss.detach())
                return loss.detach(), (hypergrad.detach(),)

            res = scipy_minimize(
                closure=closure,
                parameters=p_upper,
                bounds=bounds_upper,
                options={"maxfun": 5},
            )
            if res.fval < best_val:
                best_val = res.fval
                next_X = p_upper["X"].detach()

        return next_X
                    