#!/usr/bin/env python
# coding: utf-8

# In[1]:


import json
import numpy as np
import pandas as pd
import os
from transformers import BertTokenizer, VisualBertModel
from transformers import ViTFeatureExtractor, ViTModel, LxmertTokenizer, LxmertModel
import requests
from torch import nn


# In[2]:


import torch
import clip
from PIL import Image


# In[3]:


# Pretrained LXMERT checkpoint; the tokenizer and the encoder must come from
# the same checkpoint so the vocabulary matches the embedding table.
_LXMERT_CHECKPOINT = 'unc-nlp/lxmert-base-uncased'
tokenizer = LxmertTokenizer.from_pretrained(_LXMERT_CHECKPOINT)
model = LxmertModel.from_pretrained(_LXMERT_CHECKPOINT)


# In[4]:


import pandas as pd

# Subject CSI01's stimulus list: one image file name per line, in presentation
# order. pandas.read_csv refuses sep='\n' (the line terminator cannot also be
# the field separator), so read the lines directly and wrap them in a
# one-column DataFrame whose column 0 holds the names, as later cells expect
# (sub1file[0]). Blank lines are skipped, like read_csv's default
# skip_blank_lines=True; strip() also removes stray '\r' from CRLF files.
with open('./stim_list/stim_lists/CSI01_stim_lists.txt', encoding='utf-8') as stim_list_file:
    sub1file = pd.DataFrame([line.strip() for line in stim_list_file if line.strip()])


# In[5]:


# Caption lookup keyed by COCO image file name; later cells take data[name][0]
# (presumably the first of several captions per image - confirm against the
# JSON). JSON text is UTF-8 (RFC 8259), so decode explicitly rather than with
# the platform's default encoding.
with open('COCO_images_captions.json', encoding='utf-8') as f:
    data = json.load(f)


# In[6]:


# First caption of every COCO stimulus, in stimulus-list order. Repeated
# presentations are listed with a 'rep_' prefix, which is stripped to recover
# the caption key. Any name containing 'rep_COCO_train' also contains
# 'COCO_train', so a single substring test is enough.
# NOTE: a later cell rebuilds text_sent for all stimuli, replacing this list.
text_sent = [
    data[stim.replace('rep_', '')][0]
    for stim in sub1file[0]
    if 'COCO_train' in stim
]


# In[7]:


# Label table: maps the token before the first space on each line (presumably
# an ImageNet synset id, i.e. the prefix of an ImageNet file name) to the rest
# of the line, stripped. When an id repeats, the first label seen wins.
text_imagenet_data = {}
with open('./BOLD5000_Stimuli/Image_Labels/imagenet_final_labels.txt', 'r') as text_imagenet:
    for line_no, line in enumerate(text_imagenet, start=1):
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing empty line)
        synset_id, sep, label = line.partition(' ')
        if not sep:
            raise ValueError(
                f'imagenet_final_labels.txt line {line_no}: '
                f'expected "<id> <label>", got {line!r}'
            )
        text_imagenet_data.setdefault(synset_id.strip(), label.strip())


# In[ ]:


# Precomputed per-image region features and their boxes (the file names
# suggest Faster R-CNN outputs - assumption). The coco_* arrays are consumed
# by the COCO stimuli, the imagenet_bold_* arrays by all other stimuli.
img_feat, img_feat1, img_feat_boxes, img_feat_boxes1 = (
    np.load(npy_path)
    for npy_path in (
        './coco_frcnn.npy',
        './imagenet_bold_frcnn.npy',
        './coco_frcnn_boxes.npy',
        './imagenet_bold_frcnn_boxes.npy',
    )
)


# In[9]:


# Build three lists aligned with the stimulus list (one entry per stimulus,
# in presentation order):
#   text_sent       - the text fed to LXMERT's language stream
#   img_feat2       - that image's region features
#   img_feat_boxes2 - that image's region boxes
# COCO stimuli consume rows of the coco_* arrays in order; every other
# stimulus consumes rows of the imagenet_bold_* arrays in order. This assumes
# each .npy file was written in stimulus-list order - TODO confirm.
text_sent = []
img_feat2 = []
img_feat_boxes2 = []
count = 0   # next unused row of img_feat / img_feat_boxes (COCO)
count1 = 0  # next unused row of img_feat1 / img_feat_boxes1 (ImageNet + scenes)
for i in sub1file[0]:
    # Repeated presentations carry a 'rep_' prefix; strip it so the name
    # matches its caption / label key.
    i = i.replace('rep_', '')
    if 'COCO_train' in i:
        text_sent.append(data[i][0])  # first caption of this COCO image
        img_feat2.append(img_feat[count])
        img_feat_boxes2.append(img_feat_boxes[count])
        count += 1
        continue
    if 'n0' in i or ('n1' in i and 'n1.' not in i and 'n11.' not in i):
        # ImageNet image: the text before the first '_' is looked up in the
        # label table. The 'n1.' / 'n11.' exclusions presumably keep scene
        # names ending in 'n' + index (e.g. '...n1.jpg') out of this branch.
        text_sent.append(text_imagenet_data[i.split('_')[0]])
    else:
        # Scene image: the file stem with its trailing index digits removed,
        # e.g. 'scene12.jpg' -> 'scene'. Slicing off a single character would
        # leave a digit behind for multi-digit indices.
        text_sent.append(i.split('.')[0].rstrip('0123456789'))
    img_feat2.append(img_feat1[count1])
    img_feat_boxes2.append(img_feat_boxes1[count1])
    count1 += 1


# In[10]:


# Turn the per-stimulus lists into arrays with the stimulus axis first, and
# print both shapes as a quick sanity check.
img_feat2, img_feat_boxes2 = (np.array(per_stim) for per_stim in (img_feat2, img_feat_boxes2))
print(img_feat2.shape, img_feat_boxes2.shape)


# In[11]:


# Run LXMERT once per stimulus (its text plus its image's region features and
# boxes) and keep three summaries per stimulus:
#   language_avg_output - language_output averaged over the token axis
#   vision_avg_output   - vision_output averaged over the region axis
#   pooled_output       - the model's pooled_output
# adaptive_avg_pool2d with output size (1, 768) on a (1, seq, hidden) tensor
# averages over seq and, assuming hidden == 768 (lxmert-base), leaves the
# feature axis untouched, giving (1, 1, 768).
language_output = []  # per-token outputs are not retained; left empty
vision_output = []    # per-region outputs are not retained; left empty
language_avg_output = []
vision_avg_output = []
pooled_output = []
with torch.no_grad():  # inference only: don't build autograd graphs
    for i in range(img_feat2.shape[0]):
        inputs = tokenizer(text_sent[i], return_tensors="pt", padding=True)
        # Add a leading batch axis of size 1 to this stimulus's region
        # features and boxes.
        visual_embeds = torch.Tensor(img_feat2[i]).unsqueeze(0)
        visual_pos = torch.Tensor(img_feat_boxes2[i]).unsqueeze(0)
        # Attend to every region.
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
        inputs.update({
            "visual_feats": visual_embeds,
            "visual_pos": visual_pos,
            "visual_attention_mask": visual_attention_mask,
        })
        # Attentions and hidden states are not used in this pass.
        outputs = model(**inputs)
        language_avg_output.append(
            nn.functional.adaptive_avg_pool2d(outputs['language_output'], (1, 768)).numpy())
        vision_avg_output.append(
            nn.functional.adaptive_avg_pool2d(outputs['vision_output'], (1, 768)).numpy())
        pooled_output.append(outputs['pooled_output'].numpy())


# In[12]:


# Stack the per-stimulus token-averaged language features and show the shape.
language_avg_output = np.asarray(language_avg_output)
print(language_avg_output.shape)


# In[13]:


# Collapse (stimulus, 1, 1, hidden) to (stimulus, hidden) before saving.
np.save(
    'lxmert_bold5000_language_avg',
    language_avg_output.reshape(len(language_avg_output), language_avg_output.shape[3]),
)


# In[14]:


# Stack the per-stimulus region-averaged vision features and show the shape.
vision_avg_output = np.asarray(vision_avg_output)
print(vision_avg_output.shape)


# In[15]:


# Collapse (stimulus, 1, 1, hidden) to (stimulus, hidden) before saving.
np.save(
    'lxmert_bold5000_vision_avg',
    vision_avg_output.reshape(len(vision_avg_output), vision_avg_output.shape[3]),
)


# In[16]:


# Stack the per-stimulus pooled outputs; the reshape requires the middle
# (batch) axis to be 1 and yields a (stimulus, hidden) matrix.
pooled_output = np.asarray(pooled_output)
pooled_output = pooled_output.reshape(len(pooled_output), pooled_output.shape[2])


# In[17]:


np.save('lxmert_bold5000_common', pooled_output)


# In[16]:


# vision_output only has data if the first LXMERT pass appends per-region
# outputs to it; as currently written that pass leaves it empty. Saving an
# empty array would silently overwrite any earlier lxmert_imagenet_vision.npy,
# so skip the save in that case.
if len(vision_output) > 0:
    np.save('lxmert_imagenet_vision', np.array(vision_output))
else:
    print('vision_output is empty; not saving lxmert_imagenet_vision')


# In[11]:


# Second pass over all stimuli: for selected indices k of LXMERT's
# outputs['vision_hidden_states'], keep the average over the region axis
# (same pooling as the first pass: (1, seq, hidden) -> (1, 1, 768)) in
# vision_hidden{k}_output.
# NOTE(review): this reset discards the vision averages built by the first
# pass, and nothing below refills vision_avg_output.
vision_avg_output = []
vision_hidden0_output = []
vision_hidden2_output = []
vision_hidden3_output = []
vision_hidden4_output = []
vision_hidden6_output = []
# (index into vision_hidden_states, list that collects its pooled output)
vision_layer_stores = (
    (0, vision_hidden0_output),
    (2, vision_hidden2_output),
    (3, vision_hidden3_output),
    (4, vision_hidden4_output),
    (6, vision_hidden6_output),
)
with torch.no_grad():  # inference only: don't build autograd graphs
    for i in range(img_feat2.shape[0]):
        inputs = tokenizer(text_sent[i], return_tensors="pt", padding=True)
        # Add a leading batch axis of size 1 to this stimulus's region
        # features and boxes.
        visual_embeds = torch.Tensor(img_feat2[i]).unsqueeze(0)
        visual_pos = torch.Tensor(img_feat_boxes2[i]).unsqueeze(0)
        # Attend to every region.
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
        inputs.update({
            "visual_feats": visual_embeds,
            "visual_pos": visual_pos,
            "visual_attention_mask": visual_attention_mask,
        })
        # Hidden states are needed here; attentions are not.
        outputs = model(**inputs, output_hidden_states=True)
        vision_hidden_states = outputs['vision_hidden_states']
        for layer_idx, store in vision_layer_stores:
            store.append(
                nn.functional.adaptive_avg_pool2d(vision_hidden_states[layer_idx], (1, 768)).numpy())


# In[51]:


# Stack the pooled vision hidden states gathered by the second LXMERT pass.
# Each stack is (stimulus, 1, 1, hidden); the file suffix layer{k} is the
# index k into outputs['vision_hidden_states'] that the pass collected.
img_avg0 = np.array(vision_hidden0_output)
print(img_avg0.shape)

img_avg2 = np.array(vision_hidden2_output)
print(img_avg2.shape)

img_avg3 = np.array(vision_hidden3_output)
print(img_avg3.shape)

img_avg4 = np.array(vision_hidden4_output)
print(img_avg4.shape)

img_avg6 = np.array(vision_hidden6_output)
print(img_avg6.shape)

img_avg = np.array(vision_avg_output)
print(img_avg.shape)


# In[52]:


# Collapse each stack to (stimulus, hidden) and save it.
np.save('lxmert_img_feat_bold_layer0', np.reshape(img_avg0, (img_avg0.shape[0], img_avg0.shape[3])))
np.save('lxmert_img_feat_bold_layer2', np.reshape(img_avg2, (img_avg2.shape[0], img_avg2.shape[3])))
np.save('lxmert_img_feat_bold_layer3', np.reshape(img_avg3, (img_avg3.shape[0], img_avg3.shape[3])))
np.save('lxmert_img_feat_bold_layer4', np.reshape(img_avg4, (img_avg4.shape[0], img_avg4.shape[3])))
np.save('lxmert_img_feat_bold_layer6', np.reshape(img_avg6, (img_avg6.shape[0], img_avg6.shape[3])))
# vision_avg_output is reset to [] before the second pass and not refilled
# there, in which case img_avg is empty (shape (0,)) and has no axis 3; only
# save it when it actually holds data.
if img_avg.size > 0:
    np.save('lxmert_img_feat_bold', np.reshape(img_avg, (img_avg.shape[0], img_avg.shape[3])))
else:
    print('vision_avg_output is empty; not saving lxmert_img_feat_bold')


# In[ ]:




