import torch

from kvpress.attention_patch import search_hyperplane

# def fast_search_hyperplane(X, max_iter: int = 1000):
#     """
#     Given a tensor X of shape (bsz, seq_len, head_dim), search for a hyperplane Y (bsz, head_dim)
#     such that for every i, <X[:, i], Y> <= 0. Returns -1e5 * Y / ||Y|| ** 2 to ensure exp(<X, Y>) = 0.
#     Raises a ValueError if no such hyperplane is found.
#     """
#
#     y = X.mean(1)  # this initialization is enough for most cases
#     mask = torch.bmm(X, y.unsqueeze(-1)) <= 0
#     if not mask.any():
#         return -1e5 * y / y.norm(dim=-1, keepdim=True) ** 2
#     Y = torch.zeros_like(X)
#     Y[mask.squeeze(-1)] = y
#     Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1)
#     return Y
    # raise ValueError("Could not find fake keys such that for every query q, exp(<q, k>) = 0")
def fast_search_hyperplane(X: torch.Tensor):
    """
    Build, for every row of X, a vector whose dot product with that row is <= 0.

    The mean of X over dim 1 is used as the candidate direction. For each row the
    direction is kept when the row's dot product with the mean is positive and
    negated otherwise; the result is then scaled by -1e5 / ||direction|| ** 2.

    Args:
        X: Tensor of shape (bsz, seq_len, head_dim)

    Returns:
        Tensor of shape (bsz, seq_len, head_dim); entry [b, i] is the vector
        paired with X[b, i]
    """
    centroid = X.mean(dim=1, keepdim=True)  # (bsz, 1, head_dim)
    projection = torch.sum(X * centroid, dim=-1, keepdim=True)  # (bsz, seq_len, 1)

    # -1 where the row has a non-positive dot product with the mean, +1 otherwise
    orientation = torch.where(projection <= 0, -1.0, 1.0)  # (bsz, seq_len, 1)

    # After the flip, <X[b, i], aligned[b, i]> = |projection[b, i]| >= 0
    aligned = orientation * centroid  # (bsz, seq_len, head_dim)
    squared_norm = aligned.norm(dim=-1, keepdim=True) ** 2  # (bsz, seq_len, 1)

    # The negative scale turns the non-negative dot products into non-positive ones
    return -1e5 * aligned / squared_norm

def test_search_hyperplane():
    """
    Check on several random batches that fast_search_hyperplane returns, for every
    row X[b, i], a vector Y[b, i] with <X[b, i], Y[b, i]> <= 0.

    The number of trials is bounded so the test terminates when every check passes.
    Only the row-wise pairing is checked: fast_search_hyperplane picks a separate
    vector per row, so nothing is promised about <X[b, i], Y[b, j]> for i != j.
    exp(<X, Y>) == 0 is deliberately not asserted: a row that is almost orthogonal
    to the batch mean yields a dot product close to 0, whose exp does not underflow.
    """
    bsz, seq_len, head_dim = 50, 500, 128
    num_trials = 10
    # The outputs are scaled by 1e5, so float32 rounding on a dot product that is
    # ~0 can surface as a small positive value; a real sign error would be ~1e5.
    atol = 1.0
    # Local generator: reproducible samples without touching the global RNG state
    generator = torch.Generator().manual_seed(0)

    for _ in range(num_trials):
        X = torch.randn(bsz, seq_len, head_dim, generator=generator) - 0.2
        Y = fast_search_hyperplane(X)
        assert Y.shape == X.shape

        # Pair each X[b, i] with its own Y[b, i] only
        row_dots = (X * Y).sum(dim=-1)  # (bsz, seq_len)
        assert row_dots.max() <= atol

def test_fast_search_hyperplane():
    """
    Check fast_search_hyperplane on uniform inputs shifted by -0.1.

    fast_search_hyperplane returns a single tensor with the same shape as X (one
    vector per row), so each X[b, i] is paired with its own Y[b, i] and the exp of
    that dot product is expected to underflow to exactly 0.
    """
    bsz, seq_len, head_dim = 50, 500, 128
    X = torch.rand(bsz, seq_len, head_dim) - 0.1
    Y = fast_search_hyperplane(X)
    assert Y.shape == X.shape

    # Row-wise dot products <X[b, i], Y[b, i]>
    row_dots = (X * Y).sum(dim=-1)  # (bsz, seq_len)
    assert torch.exp(row_dots).max() == 0

# Script entry point: when this file is executed directly, run the hyperplane test.
if __name__ == "__main__":
    test_search_hyperplane()
    # The second test is currently disabled:
    # test_fast_search_hyperplane()