import torch 


def compute_lsh_conditional_attention(z1, z2, raw_score, lsh_attention):
    """Split a joint score matrix into four quadrants and run LSH attention on each.

    ``raw_score`` is cut at ``n = raw_score.shape[0] // 2`` along its first two
    dimensions into four quadrants. Each quadrant is exponentiated and passed
    as the third argument of ``lsh_attention(z1, z2, K)``.

    Args:
        z1: First input, forwarded unchanged to every ``lsh_attention`` call.
        z2: Second input, forwarded unchanged to every ``lsh_attention`` call.
        raw_score: Tensor of unnormalized scores. Presumably square with an
            even leading size (2n x 2n) -- TODO confirm; an odd size yields
            quadrants of unequal shape.
        lsh_attention: Callable ``(z1, z2, K)`` returning a pair; only the
            first element of the pair is kept.

    Returns:
        Tuple ``(Mxy, Mxx, Myy, Myx)``, computed from these quadrants of
        ``raw_score``:
            Mxy <- top-left      raw_score[:n, :n]
            Mxx <- top-right     raw_score[:n, n:]
            Myy <- bottom-left   raw_score[n:, :n]
            Myx <- bottom-right  raw_score[n:, n:]

    NOTE(review): if raw_score is laid out as [[xx, xy], [yx, yy]] (scores of a
    concatenated [x; y] batch), these labels do not match their quadrants
    (e.g. the top-left block would be xx, not xy). Kept unchanged so the
    returned order that callers rely on is preserved -- confirm the intended
    mapping. Every call also receives (z1, z2), whichever block it handles.
    """
    # Integer floor division avoids the float round-trip of int(shape / 2).
    n = raw_score.shape[0] // 2

    # Exponentiate each quadrant into an (unnormalized) kernel.
    # NOTE(review): torch.exp of large raw scores can overflow to inf.
    Kxy = torch.exp(raw_score[:n, :n])  # top-left
    Kxx = torch.exp(raw_score[:n, n:])  # top-right
    Kyy = torch.exp(raw_score[n:, :n])  # bottom-left
    Kyx = torch.exp(raw_score[n:, n:])  # bottom-right

    # Run LSH attention on each kernel; the second returned value is discarded.
    Mxy, _ = lsh_attention(z1, z2, Kxy)
    Mxx, _ = lsh_attention(z1, z2, Kxx)
    Myy, _ = lsh_attention(z1, z2, Kyy)
    Myx, _ = lsh_attention(z1, z2, Kyx)

    return (Mxy, Mxx, Myy, Myx)

