#!/usr/bin/env python
# coding: utf-8

# In[1]:


import json
import numpy as np
import pandas as pd
import os


# In[4]:


import torch
import clip
from PIL import Image

# Run on the GPU when one is visible to torch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load returns the network together with the image preprocessing
# callable that is applied to every PIL image before encoding.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)


# In[5]:


import json
  
# Load the caption file. The rest of the script reads
# data['concept2caption'], a mapping from each concept to its list of
# caption sentences.
# The context manager closes the file handle as soon as parsing finishes
# (the handle was previously left open for the lifetime of the process).
with open('concept2caption.json') as f:
    # json.load returns the parsed JSON object as a dictionary.
    data = json.load(f)


# In[6]:


# Flatten all captions into one list, visiting concepts in sorted order so
# the position of each sentence is reproducible across runs.
concept_captions = data['concept2caption']
text_sent = []
for concept in sorted(concept_captions):
    captions = concept_captions[concept]
    # Report any concept that does not have exactly six captions.
    if len(captions) != 6:
        print(len(captions))
        print(concept)
    text_sent.extend(captions)


# In[17]:


# Positions (in the sorted image listing) that are deleted from img_dir
# further down via np.delete.
# NOTE(review): presumably these are images without a matching caption -
# confirm against the dataset.
remove_indices = [
    264, 379, 558, 610, 674,
    675, 692, 758, 780, 782,
    866, 897, 1005, 1009, 1013,
]


# In[18]:


# Collect the path of every file in every sub-folder of the stimulus
# directory. Both listing levels are sorted so that the resulting positions
# are deterministic; the three known non-folder entries are skipped.
image_root = './IARPA_expt1_stim_images/images/'
skipped_entries = ('.DS_Store', 'image_concepts.mat', 'images_concepts.txt')
img_dir = [
    image_root + folder + '/' + filename
    for folder in sorted(os.listdir(image_root))
    if folder not in skipped_entries
    for filename in sorted(os.listdir(image_root + folder + '/'))
]


# In[19]:


# Drop the excluded positions. np.delete returns a NumPy array, so from here
# on img_dir is an array of path strings rather than a Python list.
img_dir = np.delete(arr=img_dir, obj=remove_indices, axis=0)
# Number of image paths that remain after the deletion.
print(len(img_dir))


# In[20]:


# Encode every image and the caption at the same position with CLIP, storing
# each result as a NumPy array (one list entry per image / per sentence).
# NOTE(review): this pairs img_dir[i] with text_sent[i]; it assumes the two
# sequences are index-aligned after the remove_indices deletion - confirm.
img_feat = []
text_feat = []
# Inference only: a single no_grad context around the whole loop avoids
# building autograd graphs for any of the forward passes.
with torch.no_grad():
    for i in range(len(img_dir)):
        # Close the image file once it has been converted to a tensor so
        # file handles do not accumulate over the whole dataset.
        with Image.open(img_dir[i]) as pil_image:
            image = preprocess(pil_image).unsqueeze(0).to(device)
        text = clip.tokenize(text_sent[i]).to(device)

        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

        # Tensor.numpy() only works on CPU tensors; move the features to the
        # host first so the script also runs when device == "cuda".
        img_feat.append(image_features.detach().cpu().numpy())
        text_feat.append(text_features.detach().cpu().numpy())


# In[22]:


# Average the image features per concept. Concepts are visited in sorted
# order and each one consumes as many consecutive img_feat entries as it has
# captions.
caption_table = data['concept2caption']
vision_embeddings = []
start = 0
for concept in sorted(caption_table):
    stop = start + len(caption_table[concept])
    vision_embeddings.append(np.mean(img_feat[start:stop], axis=0))
    start = stop
vision_embeddings = np.array(vision_embeddings)
print(vision_embeddings.shape)


# In[24]:


# Collapse the middle axis (keeping axes 0 and 2) and write the resulting 2-D
# array to disk; np.save appends the ".npy" extension to the given name.
vision_matrix = vision_embeddings.reshape(
    (vision_embeddings.shape[0], vision_embeddings.shape[2])
)
np.save('img_feat_periera_clip', vision_matrix)


# In[25]:


# Average the text features per concept. Concepts are visited in sorted
# order and each one consumes as many consecutive text_feat entries as it has
# captions.
caption_table = data['concept2caption']
language_embeddings = []
start = 0
for concept in sorted(caption_table):
    stop = start + len(caption_table[concept])
    language_embeddings.append(np.mean(text_feat[start:stop], axis=0))
    start = stop
language_embeddings = np.array(language_embeddings)
print(language_embeddings.shape)


# In[26]:


# Collapse the middle axis (keeping axes 0 and 2) and write the resulting 2-D
# array to disk; np.save appends the ".npy" extension to the given name.
language_matrix = language_embeddings.reshape(
    (language_embeddings.shape[0], language_embeddings.shape[2])
)
np.save('text_feat_periera_clip', language_matrix)

