import numpy as np

import torch

def get_2d_sincos_pos_embed(embed_dim, grid_size, nos_token=False, cls_token=False):
    """Build a fixed 2D sine-cosine positional embedding table for a square grid.

    Positions are enumerated row-major over (height, width): the width
    coordinate varies fastest along the flattened axis.

    Args:
        embed_dim: Total embedding width. Half of it encodes the width
            coordinate and half encodes the height coordinate, so it must be
            even (checked downstream).
        grid_size: Number of positions along each side of the square grid.
        nos_token: If true, append one extra position whose width and height
            coordinates are both ``grid_size``. That point lies just past the
            last grid cell. The name suggests a special token; confirm its
            meaning with callers.
        cls_token: If true, prepend a row of zeros to the table.

    Returns:
        A NumPy array of shape
        ``(grid_size**2 + nos_token + cls_token, embed_dim)``.
    """
    coords = np.arange(grid_size, dtype=np.float32)

    # With the default 'xy' indexing, `xs` varies along columns (width) and
    # `ys` varies along rows (height).
    xs, ys = np.meshgrid(coords, coords)
    grid = np.stack([xs, ys]).reshape(2, 1, grid_size * grid_size)

    if nos_token:
        # One extra (w, h) = (grid_size, grid_size) coordinate pair.
        extra = np.full((2, 1, 1), grid_size, dtype=np.float64)
        grid = np.concatenate((grid, extra), axis=-1)

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

    if not cls_token:
        return table
    return np.vstack([np.zeros((1, embed_dim)), table])


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a stack of 2D coordinates into sine-cosine embeddings.

    Args:
        embed_dim: Total output width. It must be even, because each of the
            two axes gets ``embed_dim // 2`` channels.
        grid: Array indexable as ``grid[0]`` (width coordinates) and
            ``grid[1]`` (height coordinates). Each half is flattened by the
            1D encoder.

    Returns:
        Array of shape ``(M, embed_dim)``, where M is the number of
        coordinates in ``grid[0]``. The width encoding occupies the first
        ``embed_dim // 2`` columns and the height encoding the rest.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # The order matters: width (axis 0) first, then height (axis 1).
    per_axis = [get_1d_sincos_pos_embed_from_grid(half, grid[axis]) for axis in (0, 1)]
    return np.hstack(per_axis)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with fixed sine-cosine features.

    For channel index ``d`` in ``[0, embed_dim // 2)``, the frequency is
    ``omega_d = 1 / 10000 ** (d / (embed_dim / 2))``. Each position ``p``
    maps to ``[sin(p * omega), cos(p * omega)]``.

    Args:
        embed_dim: Output width per position. It must be even: half the
            channels are sines and half are cosines.
        pos: Positions as an array (any shape) or an array-like. It is
            flattened to M values.

    Returns:
        A float64 array of shape ``(M, embed_dim)``. Columns
        ``[:embed_dim // 2]`` hold sines and the remaining columns hold
        cosines.
    """
    assert embed_dim % 2 == 0
    # `np.float` was removed in NumPy 1.24. Use `np.float64` (what the alias
    # meant) so the output dtype is unchanged.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2) outer product

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

