# =============================================================================
# BILBO
# =============================================================================

import math
import torch
from torch import Tensor

from acquisitions.base import BiLevelAcquisition
from utils import RFFModelList


# Utility functions -----------------------------------------------------------

def argmax(input: Tensor, dim: int | None = None) -> Tensor:
    _input = input.clone()
    idx = _input.nan_to_num(float("-inf")).argmax(dim=dim, keepdim=True)
    return idx


# -----------------------------------------------------------------------------
# Bi-Level Upper Confidence Bound (BILBO)
# -----------------------------------------------------------------------------

class BiLevelUpperConfidenceBound(BiLevelAcquisition):

    def __init__(
        self,
        num_dims: list[int],
        model_Y_upper: RFFModelList,
        model_Y_lower: RFFModelList,
        model_C_upper: RFFModelList | None = None,
        model_C_lower: RFFModelList | None = None,
        delta: float = 0.05
    ) -> None:

        super().__init__(
            num_dims=num_dims,
            model_Y_upper=model_Y_upper,
            model_Y_lower=model_Y_lower,
            model_C_upper=model_C_upper,
            model_C_lower=model_C_lower,
        )
        self.delta = delta


    def optimize_pool(
        self,
        candidates: Tensor,  # shape: [*batch_shape, d_in]
        mask_evaluated: Tensor,  # shape: [*batch_shape, d_out]
        decoupled: bool = False,
    ) -> tuple[Tensor, Tensor]:  # shape: [2], shape: [d_out]

        num_pts = self.d_out * mask_evaluated.numel()
        coeff = (mask_evaluated.sum() * math.pi) ** 2 / (6 * self.delta)
        beta = 2 * math.log(coeff * num_pts)
        mask = candidates.isnan().any(dim=-1)
        trusted = ~mask.clone()

        if self.model_C_upper is not None:
            m_C_upper = self.model_C_upper.mean(candidates)
            v_C_upper = self.model_C_upper.var(candidates)
            UCB_C_upper = m_C_upper + math.sqrt(beta) * v_C_upper.sqrt()
            trusted &= (UCB_C_upper >= 0.0).all(dim=-1)
            
        if self.model_C_lower is not None:
            m_C_lower = self.model_C_lower.mean(candidates)
            v_C_lower = self.model_C_lower.var(candidates)
            UCB_C_lower = m_C_lower + math.sqrt(beta) * v_C_lower.sqrt()
            trusted &= (UCB_C_lower >= 0.0).all(dim=-1)

        m_Y_lower = self.model_Y_lower.mean(candidates).squeeze(-1)
        v_Y_lower = self.model_Y_lower.var(candidates).squeeze(-1)
        UCB_Y_lower = m_Y_lower + math.sqrt(beta) * v_Y_lower.sqrt()
        LCB_Y_lower = m_Y_lower - math.sqrt(beta) * v_Y_lower.sqrt()
        feas_idx = argmax(UCB_Y_lower, dim=1)
        trusted &= (UCB_Y_lower >= LCB_Y_lower.gather(1, feas_idx))

        if trusted.sum() == 0:
            idx = torch.randint(self.d_out, (1,))
            mask |= mask_evaluated[..., idx].squeeze()
            valid_idx = (~mask).nonzero(as_tuple=False)
            perm = torch.randperm(valid_idx.size(0))
            indices = valid_idx[perm[:1]].view(-1)
            if decoupled:
                self.Y_mask[idx] = True
            else:
                self.Y_mask.fill_(True)
            return indices, self.Y_mask

        m_Y_upper = self.model_Y_upper.mean(candidates).squeeze(-1)
        v_Y_upper = self.model_Y_upper.var(candidates).squeeze(-1)
        UCB_Y_upper = m_Y_upper + math.sqrt(beta) * v_Y_upper.sqrt()
        trusted &= ~mask_evaluated.all(dim=-1)
        alpha = torch.where(trusted, UCB_Y_upper, float("nan"))
        # self.plot(candidates, alpha, "alpha.pdf")
        indices = torch.unravel_index(argmax(alpha), alpha.shape)
        indices = torch.stack(indices)
        feas_v_Y_lower = v_Y_lower.gather(1, feas_idx)

        if decoupled:
            _feas_v_Y_lower = feas_v_Y_lower.expand_as(v_Y_lower).clone()
            _feas_v_Y_lower.scatter_(1, feas_idx, 0.0)
            r_Y_upper = 2 * math.sqrt(beta) * v_Y_upper.sqrt()
            r_Y_lower = 2 * math.sqrt(beta) * v_Y_lower.sqrt()
            r_Y_lower += 2 * math.sqrt(beta) * _feas_v_Y_lower.sqrt()
            regrets = [r_Y_upper[*indices], r_Y_lower[*indices]]
            regrets = torch.tensor(regrets)
            if self.model_C_upper is not None:
                r_C_upper = 2 * math.sqrt(beta) * v_C_upper.sqrt()
                regrets = torch.cat([regrets, r_C_upper[*indices, :]], dim=-1)
            if self.model_C_lower is not None:
                r_C_lower = 2 * math.sqrt(beta) * v_C_lower.sqrt()
                regrets = torch.cat([regrets, r_C_lower[*indices, :]], dim=-1)
            self.Y_mask[regrets.argmax()] = True
        else:
            self.Y_mask.fill_(True)
        if self.Y_mask[1] & (feas_v_Y_lower[indices[0]] > v_Y_lower[*indices]):
            indices[1] = feas_idx[indices[0]]
        return indices, self.Y_mask


    def plot(self, X: Tensor, values: Tensor, path: str):
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(figsize=(6, 5))
        c = ax.contourf(X[..., 0], X[..., 1], values.detach(), levels=50)
        fig.colorbar(c, ax=ax)
        fig.tight_layout()
        fig.savefig(path)
        plt.clf(); plt.close()