import math

import torch


def get_2d_attention_keys(im_size, key_dim, normalize=False):
    """Build fixed sinusoidal attention keys for every pixel of a 2D image.

    Pixels are enumerated in two orders: row-major (index ``h * W + w``) and
    column-major (index ``w * H + h``). A 1D sinusoidal encoding of size
    ``key_dim // 2`` is computed over ``H * W`` positions. The key of pixel
    ``(h, w)`` is the encoding of its row-major index concatenated with the
    encoding of its column-major index.

    Args:
        im_size: ``(H, W)`` image size.
        key_dim: Size of each key. Must be divisible by 4 so that each half
            splits evenly into interleaved sine/cosine channels.
        normalize: If True, rescale every key to unit L2 norm.

    Returns:
        Float tensor of shape ``(H, W, key_dim)``.

    Raises:
        AssertionError: If ``key_dim`` is not divisible by 4.
    """
    assert key_dim % 4 == 0, f"key dim (currently = {key_dim}) has to be divisible by 4. "
    height, width = im_size[0], im_size[1]
    max_len = height * width
    half_key_dim = key_dim // 2

    # ____ Construct global (1D) position encodings over all H*W positions. ____
    pe = torch.zeros(max_len, half_key_dim)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, half_key_dim, 2).float() * (-math.log(2 * max_len) / half_key_dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    # ____ Concatenate positional encodings obtained using different ways of "unflattening" the image. ____
    # Row i of ``pe`` encodes flat index i. Reading ``pe`` as a (W, H) grid and
    # transposing to (H, W) places pe[w * H + h] at pixel (h, w), i.e. the
    # column-major encoding. Viewing as (H, W) first would only be correct
    # for square images.
    target_shape = list(im_size) + [-1]
    pe_col_major = pe.view(width, height, -1).transpose(0, 1).reshape(-1, half_key_dim)
    pe = torch.cat((pe, pe_col_major), dim=1)

    # ____ Normalize each key to unit L2 norm. ____
    if normalize:
        pe = pe / torch.norm(input=pe, p="fro", dim=1, keepdim=True)

    # ____ Reshape to (H, W, key_dim) and return. ____
    return pe.view(target_shape)


if __name__ == "__main__":
    # Visual sanity check: plot the pairwise dot-product similarity between
    # the attention keys of all pixels of a non-square image.
    # Run from the repository root:
    #     python -m src.dl.models.utils
    test_num = 0

    if test_num == 0:
        import matplotlib.pyplot as plt
        key_dim = 128
        im_size = (32, 35)

        keys = get_2d_attention_keys(im_size=im_size, key_dim=key_dim)

        # (H*W, key_dim) @ (key_dim, H*W) -> (H*W, H*W) similarity matrix.
        flattened_keys = keys.view(-1, key_dim)
        pairwise_similarities = flattened_keys @ flattened_keys.T
        pairwise_similarities = pairwise_similarities.cpu().numpy()
        plt.imshow(pairwise_similarities, cmap="coolwarm")
        plt.colorbar()
        plt.show()


