#!/usr/bin/env python
# coding: utf-8

# In[10]:


import json
import numpy as np
import pandas as pd
import os
from transformers import BertTokenizer, VisualBertModel
from transformers import ViTFeatureExtractor, ViTModel, LxmertTokenizer, LxmertModel
import requests
import torch.nn as nn
import torch


# In[2]:


import torch
import clip
from PIL import Image


# In[3]:


# The tokenizer and the encoder are both loaded from the same pretrained
# checkpoint name.
LXMERT_CHECKPOINT = 'unc-nlp/lxmert-base-uncased'
tokenizer = LxmertTokenizer.from_pretrained(LXMERT_CHECKPOINT)
model = LxmertModel.from_pretrained(LXMERT_CHECKPOINT)


# In[4]:


# The image features are stored as two .npy shards.  Load both and join them
# along the first axis, which is the axis the later loops index per sample.
_img_feat_shards = [
    np.load('../vilbert-multi-task/periera_img1.npy'),
    np.load('../vilbert-multi-task/periera_img2.npy'),
]
img_feat2 = _img_feat_shards[1]
img_feat = np.concatenate(_img_feat_shards, axis=0)
# Release the temporary list so the first shard is not kept alive separately.
del _img_feat_shards
print(img_feat.shape)


# In[5]:


# Box data is sharded the same way as the image features: two .npy files
# joined along the first axis.
_box_shards = [
    np.load('../vilbert-multi-task/periera_boxes_img1.npy'),
    np.load('../vilbert-multi-task/periera_boxes_img2.npy'),
]
img_feat_box, img_feat2_box = _box_shards
img_feat_boxes = np.concatenate(_box_shards, axis=0)
del _box_shards
print(img_feat_boxes.shape)


# In[6]:


import json
  
# Load the caption file.  The context manager closes the file handle even if
# parsing raises, and the explicit encoding keeps the result independent of
# the platform's default text encoding.
with open('concept2caption.json', encoding='utf-8') as f:
    # json.load returns the parsed document; the rest of the script indexes
    # it as a dictionary via data['concept2caption'].
    data = json.load(f)


# In[7]:


# Flatten the captions into one list, visiting concepts in sorted order so the
# caption order is deterministic.
captions_by_concept = data['concept2caption']
text_sent = []
for concept in sorted(captions_by_concept):
    concept_captions = captions_by_concept[concept]
    # Report every concept that does not have exactly six captions.
    if len(concept_captions) != 6:
        print(len(concept_captions))
        print(concept)
    text_sent.extend(concept_captions)


# In[8]:


# Positions along the first axis that are deleted from both the image feature
# array and the box array.
remove_indices = [
    264, 379, 558, 610, 674,
    675, 692, 758, 780, 782,
    866, 897, 1005, 1009, 1013,
]


# In[9]:


# Delete the same positions from features and boxes so the two arrays keep
# matching first-axis indices.
img_feat, img_feat_boxes = (
    np.delete(per_sample_array, remove_indices, axis=0)
    for per_sample_array in (img_feat, img_feat_boxes)
)
print(img_feat.shape, img_feat_boxes.shape)


# In[10]:


# First pass over the data: feed each (caption, image) pair to the model and
# collect its language, vision and pooled outputs as numpy arrays.
language_output = []      # raw 'language_output' per sample
language_avg_output = []  # 'language_output' averaged over axis 1
vision_output = []        # raw 'vision_output' per sample
vision_avg_output = []    # 'vision_output' averaged over axis 1
pooled_output = []        # 'pooled_output' per sample

# Inference only: no_grad stops autograd from recording a graph for every
# forward pass, which would otherwise cost memory and time for nothing.
with torch.no_grad():
    for i in range(img_feat.shape[0]):
        inputs = tokenizer(text_sent[i], return_tensors="pt", padding=True)

        # Convert to float32 tensors and add a leading batch axis of size 1.
        visual_feats = torch.tensor(img_feat[i], dtype=torch.float).unsqueeze(0)
        # NOTE(review): the LXMERT documentation describes visual_pos as
        # bounding boxes normalized to [0, 1] -- confirm the saved boxes are.
        visual_pos = torch.tensor(img_feat_boxes[i], dtype=torch.float).unsqueeze(0)
        # All-ones mask: every entry along the second axis is attended to.
        visual_attention_mask = torch.ones(visual_feats.shape[:-1], dtype=torch.float)
        inputs.update({
            "visual_feats": visual_feats,
            "visual_pos": visual_pos,
            "visual_attention_mask": visual_attention_mask,
        })

        # Attentions and hidden states are requested here but not read by
        # this loop; the flags are kept so the model call is unchanged.
        outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

        # Convert each output once and reuse it for the raw and averaged
        # lists.  Under no_grad the tensors carry no graph, so .numpy() is
        # valid without .detach().
        language_np = outputs['language_output'].numpy()
        vision_np = outputs['vision_output'].numpy()
        language_output.append(language_np)
        language_avg_output.append(np.mean(language_np, axis=1))
        vision_output.append(vision_np)
        vision_avg_output.append(np.mean(vision_np, axis=1))
        pooled_output.append(outputs['pooled_output'].numpy())


# In[11]:


# Stack the per-sample language averages into one array for the slicing and
# saving steps that follow.
language_avg_output = np.array(language_avg_output)
print(language_avg_output.shape)


# In[12]:


# Average the per-caption vision vectors within each concept.  Concepts are
# visited in sorted order and each one consumes as many consecutive entries of
# `vision_avg_output` as it has captions.
captions_by_concept = data['concept2caption']
vision_embeddings = []
start = 0
for concept in sorted(captions_by_concept):
    stop = start + len(captions_by_concept[concept])
    vision_embeddings.append(np.mean(vision_avg_output[start:stop], axis=0))
    start = stop


# In[13]:


vision_embeddings = np.array(vision_embeddings)
print(vision_embeddings.shape)


# In[14]:


# Save as a 2-D array of shape (shape[0], shape[2]); the reshape only succeeds
# when axis 1 has length 1.
n_concepts = vision_embeddings.shape[0]
hidden_dim = vision_embeddings.shape[2]
np.save('lxmert_periera_visualembeddings', vision_embeddings.reshape(n_concepts, hidden_dim))


# In[16]:


# Average the per-caption language vectors within each concept, walking the
# concepts in sorted order and taking one consecutive slice per concept.
concept_captions = data['concept2caption']
language_embeddings = []
cursor = 0
for concept in sorted(concept_captions):
    n_captions = len(concept_captions[concept])
    concept_slice = language_avg_output[cursor:cursor + n_captions]
    language_embeddings.append(np.mean(concept_slice, axis=0))
    cursor += n_captions


# In[17]:


language_embeddings = np.array(language_embeddings)
print(language_embeddings.shape)


# In[18]:


# Save as a 2-D array of shape (shape[0], shape[2]); the reshape only succeeds
# when axis 1 has length 1.
flat_shape = (language_embeddings.shape[0], language_embeddings.shape[2])
np.save('lxmert_periera_languageembeddings', language_embeddings.reshape(flat_shape))


# In[19]:


# Average the pooled outputs within each concept, using the same sorted
# concept order and consecutive slicing as the other per-concept averages.
concept_to_captions = data['concept2caption']
pooled_embeddings = []
offset = 0
for concept in sorted(concept_to_captions):
    next_offset = offset + len(concept_to_captions[concept])
    pooled_embeddings.append(np.mean(pooled_output[offset:next_offset], axis=0))
    offset = next_offset


# In[20]:


pooled_embeddings = np.array(pooled_embeddings)
print(pooled_embeddings.shape)


# In[21]:


# Save as a 2-D array of shape (shape[0], shape[2]); the reshape only succeeds
# when axis 1 has length 1.
pooled_rows = pooled_embeddings.shape[0]
pooled_cols = pooled_embeddings.shape[2]
np.save('lxmert_periera_pooled', pooled_embeddings.reshape(pooled_rows, pooled_cols))


# In[11]:


# Per-caption exports (no averaging over concepts).  Each array is written as
# 2-D with shape (shape[0], shape[2]).
language_flat_shape = (language_avg_output.shape[0], language_avg_output.shape[2])
np.save('lxmert_coco_language_avg', language_avg_output.reshape(language_flat_shape))


# In[12]:


# The vision averages are still a Python list at this point; stack them.
vision_avg_output = np.array(vision_avg_output)
print(vision_avg_output.shape)


# In[13]:


vision_flat_shape = (vision_avg_output.shape[0], vision_avg_output.shape[2])
np.save('lxmert_coco_vision_avg', vision_avg_output.reshape(vision_flat_shape))


# In[14]:


# Stack the pooled outputs and rebind the name to the 2-D form.
pooled_output = np.array(pooled_output)
pooled_flat_shape = (pooled_output.shape[0], pooled_output.shape[2])
pooled_output = pooled_output.reshape(pooled_flat_shape)


# In[16]:


np.save('lxmert_coco_common', pooled_output)


# In[18]:


# Second pass over the data: re-run the model and, for selected entries of
# `vision_hidden_states`, keep each sample's hidden state averaged over its
# second-to-last axis (presumably the image regions -- confirm).
vision_avg_output = []  # reset here as before; nothing in this pass refills it
vision_hidden1_output = []
vision_hidden5_output = []
vision_hidden7_output = []
vision_hidden8_output = []
vision_hidden9_output = []

# Hidden-state index -> list that collects the pooled array for that index.
hidden_state_buffers = {
    1: vision_hidden1_output,
    5: vision_hidden5_output,
    7: vision_hidden7_output,
    8: vision_hidden8_output,
    9: vision_hidden9_output,
}

# Inference only: no_grad stops autograd from recording a graph for every
# forward pass, which would otherwise cost memory and time for nothing.
with torch.no_grad():
    for i in range(img_feat.shape[0]):
        inputs = tokenizer(text_sent[i], return_tensors="pt", padding=True)

        # Convert to float32 tensors and add a leading batch axis of size 1.
        visual_feats = torch.tensor(img_feat[i], dtype=torch.float).unsqueeze(0)
        visual_pos = torch.tensor(img_feat_boxes[i], dtype=torch.float).unsqueeze(0)
        visual_attention_mask = torch.ones(visual_feats.shape[:-1], dtype=torch.float)
        inputs.update({
            "visual_feats": visual_feats,
            "visual_pos": visual_pos,
            "visual_attention_mask": visual_attention_mask,
        })

        outputs = model(**inputs, output_attentions=True, output_hidden_states=True)
        vision_hidden_states = outputs['vision_hidden_states']

        for state_index, buffer in hidden_state_buffers.items():
            hidden = vision_hidden_states[state_index]
            # Pool the last two axes to (1, width).  Using the tensor's own
            # width leaves the last axis untouched, so this is a mean over the
            # second-to-last axis; for a width of 768 it equals the previous
            # hard-coded (1, 768) target.
            pooled = nn.functional.adaptive_avg_pool2d(hidden, (1, hidden.shape[-1]))
            # Under no_grad the result carries no graph, so .numpy() is valid.
            buffer.append(pooled.numpy())


# In[19]:


# Hidden-state index 1: stack the pooled per-sample arrays, keep axes 0 and 3
# as a 2-D matrix, then average the rows belonging to each concept.
img_avg1 = np.array(vision_hidden1_output)
img_avg1 = img_avg1.reshape(img_avg1.shape[0], img_avg1.shape[3])

per_concept_rows = []
cursor = 0
for concept in sorted(data['concept2caption']):
    n_captions = len(data['concept2caption'][concept])
    # Each concept owns the next `n_captions` consecutive rows.
    per_concept_rows.append(np.mean(img_avg1[cursor:cursor + n_captions], axis=0))
    cursor += n_captions
vision_embeddings1 = np.array(per_concept_rows)
print(vision_embeddings1.shape)


# In[20]:


# Hidden-state index 5 (the "3" in these variable names does not match the
# index): stack the pooled per-sample arrays, keep axes 0 and 3 as a 2-D
# matrix, then average the rows belonging to each concept.
img_avg3 = np.array(vision_hidden5_output)
img_avg3 = img_avg3.reshape(img_avg3.shape[0], img_avg3.shape[3])

per_concept_rows = []
cursor = 0
for concept in sorted(data['concept2caption']):
    n_captions = len(data['concept2caption'][concept])
    # Each concept owns the next `n_captions` consecutive rows.
    per_concept_rows.append(np.mean(img_avg3[cursor:cursor + n_captions], axis=0))
    cursor += n_captions
vision_embeddings3 = np.array(per_concept_rows)
print(vision_embeddings3.shape)


# In[21]:


# Hidden-state index 7 (the "5" in these variable names does not match the
# index): stack the pooled per-sample arrays, keep axes 0 and 3 as a 2-D
# matrix, then average the rows belonging to each concept.
img_avg5 = np.array(vision_hidden7_output)
img_avg5 = img_avg5.reshape(img_avg5.shape[0], img_avg5.shape[3])

per_concept_rows = []
cursor = 0
for concept in sorted(data['concept2caption']):
    n_captions = len(data['concept2caption'][concept])
    # Each concept owns the next `n_captions` consecutive rows.
    per_concept_rows.append(np.mean(img_avg5[cursor:cursor + n_captions], axis=0))
    cursor += n_captions
vision_embeddings5 = np.array(per_concept_rows)
print(vision_embeddings5.shape)


# In[22]:


# Hidden-state index 8 (the "7" in these variable names does not match the
# index): stack the pooled per-sample arrays, keep axes 0 and 3 as a 2-D
# matrix, then average the rows belonging to each concept.
img_avg7 = np.array(vision_hidden8_output)
img_avg7 = img_avg7.reshape(img_avg7.shape[0], img_avg7.shape[3])

per_concept_rows = []
cursor = 0
for concept in sorted(data['concept2caption']):
    n_captions = len(data['concept2caption'][concept])
    # Each concept owns the next `n_captions` consecutive rows.
    per_concept_rows.append(np.mean(img_avg7[cursor:cursor + n_captions], axis=0))
    cursor += n_captions
vision_embeddings7 = np.array(per_concept_rows)
print(vision_embeddings7.shape)


# In[23]:


# Hidden-state index 9 (the "10" in these variable names does not match the
# index): stack the pooled per-sample arrays, keep axes 0 and 3 as a 2-D
# matrix, then average the rows belonging to each concept.
img_avg10 = np.array(vision_hidden9_output)
img_avg10 = img_avg10.reshape(img_avg10.shape[0], img_avg10.shape[3])

per_concept_rows = []
cursor = 0
for concept in sorted(data['concept2caption']):
    n_captions = len(data['concept2caption'][concept])
    # Each concept owns the next `n_captions` consecutive rows.
    per_concept_rows.append(np.mean(img_avg10[cursor:cursor + n_captions], axis=0))
    cursor += n_captions
vision_embeddings10 = np.array(per_concept_rows)
print(vision_embeddings10.shape)


# In[24]:


# Write one file per extracted hidden state.  The number in each file name is
# the `vision_hidden_states` index the array was pooled from, which differs
# from the numeric suffix of some of the variables.
for out_name, concept_matrix in (
    ('lxmert_img_layer1_periera', vision_embeddings1),
    ('lxmert_img_layer5_periera', vision_embeddings3),
    ('lxmert_img_layer7_periera', vision_embeddings5),
    ('lxmert_img_layer8_periera', vision_embeddings7),
    ('lxmert_img_layer9_periera', vision_embeddings10),
):
    np.save(out_name, concept_matrix)


# In[ ]:




