import torchvision.transforms as transforms

import torch
import lib.models.layers.clip.clip as clip

from torch import nn



class nlp_embedding(nn.Module):
    """Encode text descriptions with the text branch of a CLIP model.

    The CLIP weights are always evaluated under ``torch.no_grad()``, so no
    gradients flow into the text encoder from this module, in either
    training or evaluation mode.
    """

    # Number of leading token positions whose projected features are returned.
    _NUM_TOKENS = 24

    def __init__(self):
        super().__init__()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): upstream CLIP's ``load`` also accepts a path to a local
        # checkpoint instead of a model name -- confirm for this vendored copy
        # before relying on it for offline use.
        self.clip_model, self.preprocess = clip.load('ViT-B/32', self.device)
        # Not read by ``forward``; kept because code outside this class may
        # still set or inspect it.
        self.tem_nlp = None
        self.softmax = nn.Softmax(dim=-1)

    def _encode_tokens(self, tokens):
        """Run tokenized text through the CLIP text transformer.

        Args:
            tokens: token ids produced by ``clip.tokenize``, already on
                ``self.device``.

        Returns:
            The projected features of the first ``_NUM_TOKENS`` token
            positions of every sequence (batch dimension first).
        """
        model = self.clip_model
        dtype = model.dtype
        with torch.no_grad():
            x = model.token_embedding(tokens).type(dtype)  # [batch_size, n_ctx, d_model]
            x = x + model.positional_embedding.type(dtype)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = model.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = model.ln_final(x).type(dtype)
            # Keep per-token features for the leading positions (not only a
            # single pooled token) and map them through the text projection.
            return x[:, :self._NUM_TOKENS] @ model.text_projection

    def forward(self, nlp):
        """Return CLIP text features for one description or a list of them.

        Args:
            nlp: a single string or a list of strings.

        Returns:
            A tensor on ``self.device`` holding, for every input string, the
            projected features of its first ``_NUM_TOKENS`` token positions.
        """
        # text-encoder
        if isinstance(nlp, str):
            nlp = [nlp]
        tokens = clip.tokenize(nlp).to(self.device)

        # The same computation is used in training and evaluation mode.  The
        # earlier evaluation path skipped it whenever ``self.tem_nlp`` was
        # set, which left the result undefined and raised UnboundLocalError.
        hidden = self._encode_tokens(tokens)

        # Concatenating onto an empty tensor is kept so the output is built
        # exactly as before; the empty tensor now lives on ``self.device``
        # instead of an unconditional ``.cuda()``, which failed on CPU-only
        # machines.
        nlp_des = torch.Tensor().to(self.device)
        return torch.cat([nlp_des, hidden], dim=0)
