#!/usr/bin/env python
# coding: utf-8

# In[1]:


import json
import numpy as np
import pandas as pd
import os
import torch.nn as nn
import torch


# In[2]:


from transformers import ViTFeatureExtractor, ViTModel, DeiTFeatureExtractor, DeiTModel, BeitFeatureExtractor, BeitModel
from PIL import Image


# In[3]:


# Load the BEiT preprocessor and encoder from the same hub checkpoint so the
# preprocessing settings match the weights they feed.
beit_checkpoint = 'microsoft/beit-base-patch16-224-pt22k-ft22k'
feature_extractor = BeitFeatureExtractor.from_pretrained(beit_checkpoint)
model = BeitModel.from_pretrained(beit_checkpoint)


# In[4]:


import json

# Load the concept -> list-of-captions mapping used to group images below.
# The `with` block closes the file handle (previously it stayed open for the
# whole session); JSON text is UTF-8, so don't depend on the locale default.
with open('concept2caption.json', encoding='utf-8') as caption_file:
    data = json.load(caption_file)


# In[5]:


# Flatten every caption into one list, concept by concept in sorted key order.
# Concepts whose caption count is not 6 are reported (count first, then name).
caption_lists = data['concept2caption']
text_sent = []
for concept in sorted(caption_lists):
    concept_captions = caption_lists[concept]
    if len(concept_captions) != 6:
        print(len(concept_captions))
        print(concept)
    text_sent.extend(concept_captions)


# In[6]:


# Positions (in the sorted image listing built below) that are dropped before
# feature extraction. Why these particular images are excluded is not recorded
# in this file.
remove_indices = [264, 379, 558, 610, 674, 675, 692, 758, 780, 782, 866, 897, 1005, 1009, 1013]

# Collect image paths: one sub-directory per concept, both levels sorted so the
# ordering is deterministic. Non-directory entries at the top level are skipped.
image_root = './IARPA_expt1_stim_images/images/'
non_concept_entries = {'.DS_Store', 'image_concepts.mat', 'images_concepts.txt'}
img_dir = [
    image_root + concept_dir + '/' + image_name
    for concept_dir in sorted(os.listdir(image_root))
    if concept_dir not in non_concept_entries
    for image_name in sorted(os.listdir(image_root + concept_dir + '/'))
]

# np.delete returns a NumPy string array, not a list.
img_dir = np.delete(img_dir, remove_indices, axis=0)
print(len(img_dir))


# In[9]:


# Run every image through BEiT once, keeping per image the full token sequence
# (last_hidden_state) and the pooled summary vector (pooler_output).
img_feat = []
img_avg = []
import cv2

# Inference only: no_grad avoids building an autograd graph per image.
with torch.no_grad():
    for i in np.arange(len(img_dir)):
        try:
            image = cv2.imread(img_dir[i])
            if image is None:
                # cv2.imread reports unreadable/missing files by returning None.
                raise FileNotFoundError(img_dir[i])
            # OpenCV decodes to BGR; the Hugging Face feature extractor expects
            # RGB, so swap the channel order before preprocessing.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            inputs = feature_extractor(images=image, return_tensors="pt")
            outputs = model(**inputs)
            # Extract both outputs before appending anything, so a failure can
            # never leave img_feat and img_avg with different lengths.
            token_states = outputs['last_hidden_state'].numpy()
            pooled = outputs['pooler_output'].numpy()
        except Exception as err:
            # Best-effort: report and skip the image.
            # NOTE(review): a skipped image shifts every later image into the
            # wrong concept when results are grouped by caption counts below.
            print(i, img_dir[i], err)
            continue
        img_feat.append(token_states)
        img_avg.append(pooled)


# In[10]:


# One embedding per concept from BEiT's pooled output. The code assumes images
# are ordered concept by concept (sorted keys) and that each concept owns as
# many consecutive images as it has captions; each run is averaged.
caption_map = data['concept2caption']
concept_means = []
offset = 0
for concept in sorted(caption_map):
    run_length = len(caption_map[concept])
    # img_avg is still a Python list here, so this slices the per-image arrays.
    concept_means.append(np.mean(img_avg[offset:offset + run_length], axis=0))
    offset += run_length
vision_embeddings = np.array(concept_means)
print(vision_embeddings.shape)

img_avg = np.array(img_avg)
print(img_avg.shape)

# Collapse axis 1 (the per-image batch axis, expected to be size 1) so the
# saved array is (concepts, hidden).
np.save(
    'beit_img_feat_periera',
    vision_embeddings.reshape(vision_embeddings.shape[0], vision_embeddings.shape[2]),
)


# In[13]:


# Stack the per-image last_hidden_state arrays. Hugging Face returns them as
# (batch, seq_len, hidden) with batch = 1 here, so the stack is
# (n_images, 1, seq_len, hidden).
img_feat = np.asarray(img_feat)
img_feat.shape  # notebook display only; a no-op when run as a plain script

# Average over the token axis -> (n_images, 1, hidden), then save without the
# singleton axis as (n_images, hidden).
img_feat = img_feat.mean(axis=2)
np.save('beit_img_feat_final_periera', img_feat.reshape(img_feat.shape[0], img_feat.shape[2]))

# Per-concept averages: each concept (sorted keys) owns a consecutive run of
# images whose length equals its caption count.
caption_map = data['concept2caption']
run_edges = np.cumsum([0] + [len(caption_map[c]) for c in sorted(caption_map)])
vision_embeddings = np.array(
    [np.mean(img_feat[lo:hi], axis=0) for lo, hi in zip(run_edges[:-1], run_edges[1:])]
)
print(vision_embeddings.shape)

# NOTE(review): unlike 'beit_img_feat_periera', this is saved with the
# singleton axis kept, i.e. (concepts, 1, hidden) — confirm consumers expect it.
np.save('beit_img_feat_final_periera12', vision_embeddings)


# In[16]:


# Mean-pool the token axis of every intermediate encoder output
# (hidden_states[1] .. hidden_states[11]) for each image. Results are exposed
# as img_avg1 .. img_avg11; the per-layer cells that follow read
# img_avg1/3/5/7/10 as well as img_avg2/4/6/8/9/11, so all of them are built.
import cv2

POOLED_LAYERS = range(1, 12)
layer_pools = {layer: [] for layer in POOLED_LAYERS}

# Inference only: no_grad avoids building an autograd graph per image.
with torch.no_grad():
    for i in np.arange(len(img_dir)):
        try:
            image = cv2.imread(img_dir[i])
            if image is None:
                # cv2.imread reports unreadable/missing files by returning None.
                raise FileNotFoundError(img_dir[i])
            # OpenCV decodes to BGR; the feature extractor expects RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            inputs = feature_extractor(images=image, return_tensors="pt")
            hidden_states = model(**inputs, output_hidden_states=True)['hidden_states']
            # adaptive_avg_pool2d to (1, hidden) averages over the token axis
            # and leaves the hidden axis untouched (output size == input size).
            # All layers are pooled before any append so a failure cannot leave
            # the per-layer lists with different lengths.
            pooled = {
                layer: nn.functional.adaptive_avg_pool2d(
                    hidden_states[layer], (1, hidden_states[layer].shape[-1])
                ).numpy()
                for layer in POOLED_LAYERS
            }
        except Exception as err:
            # Best-effort: report and skip. NOTE(review): a skipped image shifts
            # later images into the wrong concept when grouped downstream.
            print(i, img_dir[i], err)
            continue
        for layer, layer_vec in pooled.items():
            layer_pools[layer].append(layer_vec)

img_avg1 = layer_pools[1]
img_avg2 = layer_pools[2]
img_avg3 = layer_pools[3]
img_avg4 = layer_pools[4]
img_avg5 = layer_pools[5]
img_avg6 = layer_pools[6]
img_avg7 = layer_pools[7]
img_avg8 = layer_pools[8]
img_avg9 = layer_pools[9]
img_avg10 = layer_pools[10]
img_avg11 = layer_pools[11]


# In[10]:


# Per-concept embeddings for the pooled hidden states of layers 1, 3, 5, 7 and
# 10. The lists img_avg1, img_avg3, img_avg5, img_avg7 and img_avg10 must
# already be defined when this cell runs; otherwise it raises NameError.


def _concept_means(per_image):
    """Average consecutive runs of per-image rows, one run per concept.

    Concepts are visited in sorted key order of data['concept2caption'], and
    each concept consumes as many rows as it has captions.
    """
    caption_map = data['concept2caption']
    means = []
    start = 0
    for concept in sorted(caption_map):
        stop = start + len(caption_map[concept])
        means.append(np.mean(per_image[start:stop], axis=0))
        start = stop
    return np.array(means)


# Each stack is (n_images, 1, 1, hidden); keep axes 0 and 3 -> (n_images, hidden).
img_avg1 = np.array(img_avg1)
img_avg1 = img_avg1.reshape(img_avg1.shape[0], img_avg1.shape[3])
vision_embeddings1 = _concept_means(img_avg1)
print(vision_embeddings1.shape)

img_avg3 = np.array(img_avg3)
img_avg3 = img_avg3.reshape(img_avg3.shape[0], img_avg3.shape[3])
vision_embeddings3 = _concept_means(img_avg3)
print(vision_embeddings3.shape)

img_avg5 = np.array(img_avg5)
img_avg5 = img_avg5.reshape(img_avg5.shape[0], img_avg5.shape[3])
vision_embeddings5 = _concept_means(img_avg5)
print(vision_embeddings5.shape)

img_avg7 = np.array(img_avg7)
img_avg7 = img_avg7.reshape(img_avg7.shape[0], img_avg7.shape[3])
vision_embeddings7 = _concept_means(img_avg7)
print(vision_embeddings7.shape)

img_avg10 = np.array(img_avg10)
img_avg10 = img_avg10.reshape(img_avg10.shape[0], img_avg10.shape[3])
vision_embeddings10 = _concept_means(img_avg10)
print(vision_embeddings10.shape)

np.save('beit_img_layer1_periera', vision_embeddings1)
np.save('beit_img_layer3_periera', vision_embeddings3)
np.save('beit_img_layer5_periera', vision_embeddings5)
np.save('beit_img_layer7_periera', vision_embeddings7)
np.save('beit_img_layer10_periera', vision_embeddings10)


# In[17]:


# Per-concept embeddings for the pooled hidden states of layers 2, 4, 6, 8, 9
# and 11, read from the img_avg2 .. img_avg11 lists filled by the
# layer-extraction loop above.


def _concept_means(per_image):
    """Average consecutive runs of per-image rows, one run per concept.

    Concepts are visited in sorted key order of data['concept2caption'], and
    each concept consumes as many rows as it has captions.
    """
    caption_map = data['concept2caption']
    means = []
    start = 0
    for concept in sorted(caption_map):
        stop = start + len(caption_map[concept])
        means.append(np.mean(per_image[start:stop], axis=0))
        start = stop
    return np.array(means)


# Each stack is (n_images, 1, 1, hidden); keep axes 0 and 3 -> (n_images, hidden).
img_avg2 = np.array(img_avg2)
img_avg2 = img_avg2.reshape(img_avg2.shape[0], img_avg2.shape[3])
vision_embeddings2 = _concept_means(img_avg2)
print(vision_embeddings2.shape)

img_avg4 = np.array(img_avg4)
img_avg4 = img_avg4.reshape(img_avg4.shape[0], img_avg4.shape[3])
vision_embeddings4 = _concept_means(img_avg4)
print(vision_embeddings4.shape)

img_avg6 = np.array(img_avg6)
img_avg6 = img_avg6.reshape(img_avg6.shape[0], img_avg6.shape[3])
vision_embeddings6 = _concept_means(img_avg6)
print(vision_embeddings6.shape)

img_avg8 = np.array(img_avg8)
img_avg8 = img_avg8.reshape(img_avg8.shape[0], img_avg8.shape[3])
vision_embeddings8 = _concept_means(img_avg8)
print(vision_embeddings8.shape)

img_avg9 = np.array(img_avg9)
img_avg9 = img_avg9.reshape(img_avg9.shape[0], img_avg9.shape[3])
vision_embeddings9 = _concept_means(img_avg9)
print(vision_embeddings9.shape)

img_avg11 = np.array(img_avg11)
img_avg11 = img_avg11.reshape(img_avg11.shape[0], img_avg11.shape[3])
vision_embeddings11 = _concept_means(img_avg11)
print(vision_embeddings11.shape)

np.save('beit_img_layer2_periera', vision_embeddings2)
np.save('beit_img_layer4_periera', vision_embeddings4)
np.save('beit_img_layer6_periera', vision_embeddings6)
np.save('beit_img_layer8_periera', vision_embeddings8)
np.save('beit_img_layer9_periera', vision_embeddings9)
np.save('beit_img_layer11_periera', vision_embeddings11)


# In[ ]:




