#!/usr/bin/env python
# coding: utf-8

# In[1]:


import json
import numpy as np
import pandas as pd
import os
from torch import nn


# In[2]:


from transformers import ViTFeatureExtractor, ViTModel
from PIL import Image


# In[3]:


# Hugging Face checkpoint name shared by the image preprocessor and the encoder,
# so the two can never drift apart.
_VIT_CHECKPOINT = 'google/vit-base-patch16-224-in21k'

# Image preprocessor paired with the checkpoint, followed by the ViT encoder itself.
feature_extractor = ViTFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)
model = ViTModel.from_pretrained(_VIT_CHECKPOINT)


# In[4]:


# Stimulus list (the file name suggests subject CSI01): one image filename per line.
# pandas rejects sep='\n' because the newline is already the line terminator and
# cannot double as the field separator. Read the lines directly instead, skipping
# blank ones as read_csv would, and wrap them in a single-column DataFrame whose
# column is labelled 0, so later cells can keep indexing `sub1file[0]`.
with open('./stim_list/stim_lists/CSI01_stim_lists.txt', encoding='utf-8') as _stim_fh:
    sub1file = pd.DataFrame(
        {0: [_line.rstrip('\r\n') for _line in _stim_fh if _line.strip()]}
    )


# In[9]:


import torch

# Root folder of the presented stimuli. Each image lives in one of the
# 'COCO', 'ImageNet' or 'Scene' sub-folders below it.
_STIM_ROOT = "BOLD5000_Stimuli/Scene_Stimuli/Presented_Stimuli/"


def _stim_subdir(name):
    """Return the sub-folder ('COCO', 'ImageNet' or 'Scene') that holds *name*.

    The ImageNet test is this script's filename heuristic: names containing
    'n0', or names containing 'n1' but neither 'n1.' nor 'n11.'.
    """
    if 'COCO' in name:
        return 'COCO'
    if 'n0' in name or ('n1' in name and 'n1.' not in name and 'n11.' not in name):
        return 'ImageNet'
    return 'Scene'


def _vit_outputs(name):
    """Run the ViT model on one stimulus and return the model outputs.

    *name* is an entry of the stimulus list. Any 'rep_' marker is removed before
    the file is looked up. Hidden states of every layer are requested.
    """
    name = name.replace('rep_', '')
    path = _STIM_ROOT + _stim_subdir(name) + "/" + name
    # `with` closes the file handle. convert('RGB') hands grayscale / palette /
    # CMYK files to the extractor as 3-channel images (RGB inputs are unchanged).
    with Image.open(path) as image:
        inputs = feature_extractor(images=image.convert('RGB'), return_tensors="pt")
    # Inference only: skip building an autograd graph for every image.
    with torch.no_grad():
        return model(**inputs, output_hidden_states=True)


img_feat = []  # last_hidden_state of each stimulus, in list order
img_avg1 = []  # pooler_output of each stimulus, in list order
# Every stimulus is processed. The export stopped after the first image via a
# leftover debug `break`, which made the saved feature files a single row.
for stim in sub1file[0]:
    outputs = _vit_outputs(stim)
    img_feat.append(outputs['last_hidden_state'].detach().numpy())
    img_avg1.append(outputs['pooler_output'].detach().numpy())


# In[7]:


# Stack the per-image pooler outputs collected by the extraction loop in
# `img_avg1`. The original read the never-defined `img_avg` here (NameError).
# The result keeps the name `img_avg` because the save cell below reads it.
img_avg = np.array(img_avg1)
print(img_avg.shape)


# In[8]:


# Collapse axis 1 (presumably the size-1 batch axis of each pooler output; TODO
# confirm) so each stimulus becomes one row, then write vit_img_feat.npy.
_n_rows, _width = img_avg.shape[0], img_avg.shape[2]
np.save('vit_img_feat', img_avg.reshape(_n_rows, _width))


# In[9]:


# Turn the list of per-image last_hidden_state arrays into one ndarray.
img_feat = np.asarray(img_feat)


# In[11]:


# Average over axis 2. This is presumably the token axis of
# (n_images, 1, tokens, hidden); TODO confirm.
img_feat = np.asarray(img_feat).mean(axis=2)


# In[12]:


# Drop the singleton axis 1 so each stimulus is one row, then write
# vit_img_feat_final.npy.
_rows, _cols = img_feat.shape[0], img_feat.shape[2]
np.save('vit_img_feat_final', img_feat.reshape(_rows, _cols))


# In[5]:


import torch

# Root folder of the presented stimuli. Each image lives in one of the
# 'COCO', 'ImageNet' or 'Scene' sub-folders below it.
_STIM_ROOT = "BOLD5000_Stimuli/Scene_Stimuli/Presented_Stimuli/"


def _stim_subdir(name):
    """Return the sub-folder ('COCO', 'ImageNet' or 'Scene') that holds *name*.

    The ImageNet test is this script's filename heuristic: names containing
    'n0', or names containing 'n1' but neither 'n1.' nor 'n11.'.
    """
    if 'COCO' in name:
        return 'COCO'
    if 'n0' in name or ('n1' in name and 'n1.' not in name and 'n11.' not in name):
        return 'ImageNet'
    return 'Scene'


def _vit_outputs(name):
    """Run the ViT model on one stimulus and return the model outputs.

    *name* is an entry of the stimulus list. Any 'rep_' marker is removed before
    the file is looked up. Hidden states of every layer are requested.
    """
    name = name.replace('rep_', '')
    path = _STIM_ROOT + _stim_subdir(name) + "/" + name
    # `with` closes the file handle. convert('RGB') hands grayscale / palette /
    # CMYK files to the extractor as 3-channel images (RGB inputs are unchanged).
    with Image.open(path) as image:
        inputs = feature_extractor(images=image.convert('RGB'), return_tensors="pt")
    # Inference only: skip building an autograd graph for every image.
    with torch.no_grad():
        return model(**inputs, output_hidden_states=True)


# One list per hidden-state index kept. Each receives one pooled array per stimulus.
img_avg2 = []
img_avg4 = []
img_avg6 = []
img_avg8 = []
img_avg9 = []
img_avg11 = []
_pooled_by_layer = {
    2: img_avg2,
    4: img_avg4,
    6: img_avg6,
    8: img_avg8,
    9: img_avg9,
    11: img_avg11,
}

for stim in sub1file[0]:
    outputs = _vit_outputs(stim)
    for layer, bucket in _pooled_by_layer.items():
        hidden = outputs['hidden_states'][layer]
        # adaptive_avg_pool2d pools over the last two dims. An output size of
        # (1, last_dim) averages dim -2 (presumably tokens; TODO confirm) and
        # keeps every feature. last_dim replaces the hard-coded 768.
        pooled = nn.functional.adaptive_avg_pool2d(hidden, (1, hidden.shape[-1]))
        bucket.append(pooled.detach().numpy())


# In[17]:


# Inspect the shape of hidden_states[0] from the last processed stimulus. In an
# exported script a bare expression is silently discarded (there is no notebook
# auto-display), so print it like the other shape-check cells do.
print(outputs['hidden_states'][0].shape)


# In[27]:


# Shape check of adaptive average pooling applied to hidden_states[0].
# NOTE(review): the export read `nn.(...)`, a SyntaxError that stopped the whole
# file from parsing. The function name is reconstructed as adaptive_avg_pool2d,
# the pooling op used for the per-layer features; confirm the intended call.
print(nn.functional.adaptive_avg_pool2d(outputs['hidden_states'][0], (1, 1)).detach().numpy().shape)


# In[6]:


# Stack the pooled hidden_states[2] features into one array and report its shape.
img_avg2 = np.asarray(img_avg2)
print(np.shape(img_avg2))


# In[7]:


# Stack the pooled hidden_states[4] features into one array and report its shape.
img_avg4 = np.asarray(img_avg4)
print(np.shape(img_avg4))


# In[8]:


# Stack the pooled hidden_states[6] features into one array and report its shape.
img_avg6 = np.asarray(img_avg6)
print(np.shape(img_avg6))


# In[9]:


# Stack the pooled features of hidden_states[8], [9] and [11], one array each,
# printing every shape as soon as it is built.
img_avg8 = np.asarray(img_avg8)
print(np.shape(img_avg8))

img_avg9 = np.asarray(img_avg9)
print(np.shape(img_avg9))

img_avg11 = np.asarray(img_avg11)
print(np.shape(img_avg11))


# In[10]:


# Write one .npy per hidden-state index, in the same order as before. Each array
# is flattened to (n_images, shape[3]) by dropping the two singleton middle axes.
for _layer, _feats in ((2, img_avg2), (4, img_avg4), (6, img_avg6),
                       (8, img_avg8), (9, img_avg9), (11, img_avg11)):
    np.save('vit_img_feat_bold_layer' + str(_layer),
            _feats.reshape(_feats.shape[0], _feats.shape[3]))


# In[ ]:




