import clip
import torch

"""
tokens = clip.tokenize([description]).numpy()
token_tensor = torch.from_numpy(tokens).to(device)
with torch.no_grad():
    lang_feats, lang_embs = _clip_encode_text(clip_model, token_tensor)
"""

# extract CLIP language features for goal string
def clip_encode_text(clip_model, text):
    x = clip_model.token_embedding(text).type(
        clip_model.dtype
    )  # [batch_size, n_ctx, d_model]

    x = x + clip_model.positional_embedding.type(clip_model.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = clip_model.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = clip_model.ln_final(x).type(clip_model.dtype)

    emb = x.clone()
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ clip_model.text_projection

    return x, emb