#!/usr/bin/env python
# coding: utf-8

# In[1]:


import json
import os

import numpy as np
import pandas as pd
import torch
from torch import nn


# In[2]:


from transformers import ViTFeatureExtractor, ViTModel, DeiTFeatureExtractor, DeiTModel, BeitFeatureExtractor, BeitModel
from PIL import Image


# In[3]:


# Load the preprocessing pipeline and the encoder from the same pretrained
# checkpoint so both always agree on input size and normalisation.
beit_checkpoint = 'microsoft/beit-base-patch16-224-pt22k-ft22k'
feature_extractor = BeitFeatureExtractor.from_pretrained(beit_checkpoint)
model = BeitModel.from_pretrained(beit_checkpoint)


# In[4]:


# Stimulus list: one image file name per line.
# Newer pandas releases reject '\n' as a column separator in read_csv, so the
# file is read line by line instead. Surrounding whitespace is stripped and
# blank lines are skipped. The result is still a DataFrame whose column 0
# holds the names, which is what the loops below iterate over.
with open('./stim_list/stim_lists/CSI01_stim_lists.txt', encoding='utf-8') as stim_list_file:
    stim_names = [line.strip() for line in stim_list_file]
sub1file = pd.DataFrame({0: [name for name in stim_names if name]})


# In[6]:


# Extract two feature vectors per stimulus from the model's final output:
#   * img_avg1 / img_avg : the model's 'pooler_output'
#   * img_feat           : 'last_hidden_state' averaged over the token axis
# Both are written to disk as .npy files below.

# Folder containing one sub-folder of presented images per source dataset.
stim_root = "BOLD5000_Stimuli/Scene_Stimuli/Presented_Stimuli/"

img_feat = []
img_avg1 = []
for stim_name in sub1file[0]:
    # Drop the 'rep_' marker from the list entry before using it as a file name.
    stim_name = stim_name.replace('rep_', '')

    # Pick the dataset sub-folder from the file name pattern.
    # NOTE(review): the 'n1' / 'n1.' / 'n11.' test is a name heuristic that
    # separates ImageNet ids from scene names; confirm it covers every entry
    # in the stimulus list.
    if 'COCO' in stim_name:
        subdir = "COCO/"
    elif 'n0' in stim_name or ('n1' in stim_name and 'n1.' not in stim_name and 'n11.' not in stim_name):
        subdir = "ImageNet/"
    else:
        subdir = "Scene/"

    # Close the file handle deterministically and force three channels so
    # grayscale / CMYK / palette files are preprocessed like RGB ones.
    with Image.open(stim_root + subdir + stim_name) as stim_file:
        image = stim_file.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Pure inference: no autograd graph is needed. Intermediate hidden states
    # are not requested because only the final outputs are read here.
    with torch.no_grad():
        outputs = model(**inputs)

    # Reduce over tokens right away (axis 1 is the sequence axis of
    # last_hidden_state per the transformers docs) so the full per-token
    # hidden states of every image never have to be kept in memory.
    last_hidden = outputs['last_hidden_state'].detach().numpy()
    img_feat.append(np.mean(last_hidden, axis=1))
    img_avg1.append(outputs['pooler_output'].detach().numpy())


# In[8]:


img_avg = np.array(img_avg1)
print(img_avg.shape)


# In[9]:


# Each per-image entry carries a leading batch axis of size 1; drop it so the
# saved array is (n_images, feature_dim).
np.save('beit_img_feat', np.reshape(img_avg, (img_avg.shape[0], img_avg.shape[2])))


# In[10]:


# Stack the per-image token means; the per-image batch axis of size 1 is kept,
# so the result is 3-D: (n_images, 1, feature_dim).
img_feat = np.array(img_feat)


# In[12]:


np.save('beit_img_feat_final', img_feat)


# In[6]:


# Collect token-averaged hidden states of selected encoder layers for every
# stimulus. One list per layer; they are stacked and saved further below.
img_avg2 = []
img_avg4 = []
img_avg6 = []
img_avg8 = []
img_avg10 = []
img_avg12 = []

# Index into outputs['hidden_states'] -> list that receives that layer's
# pooled features. (Per the transformers docs index 0 is the embedding output,
# so index k is the output of encoder layer k.)
pooled_by_layer = {
    2: img_avg2,
    4: img_avg4,
    6: img_avg6,
    8: img_avg8,
    10: img_avg10,
    12: img_avg12,
}

# Folder containing one sub-folder of presented images per source dataset.
stim_root = "BOLD5000_Stimuli/Scene_Stimuli/Presented_Stimuli/"

for stim_name in sub1file[0]:
    # Drop the 'rep_' marker from the list entry before using it as a file name.
    stim_name = stim_name.replace('rep_', '')

    # Pick the dataset sub-folder from the file name pattern.
    # NOTE(review): the 'n1' / 'n1.' / 'n11.' test is a name heuristic that
    # separates ImageNet ids from scene names; confirm it covers every entry
    # in the stimulus list.
    if 'COCO' in stim_name:
        subdir = "COCO/"
    elif 'n0' in stim_name or ('n1' in stim_name and 'n1.' not in stim_name and 'n11.' not in stim_name):
        subdir = "ImageNet/"
    else:
        subdir = "Scene/"

    # Close the file handle deterministically and force three channels so
    # grayscale / CMYK / palette files are preprocessed like RGB ones.
    with Image.open(stim_root + subdir + stim_name) as stim_file:
        image = stim_file.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    hidden_states = outputs['hidden_states']
    for layer, pooled in pooled_by_layer.items():
        # adaptive_avg_pool2d reads a 3-D tensor as an unbatched (C, H, W)
        # input; with a (1, 768) target it averages the middle axis and keeps
        # the last one — assumes hidden states are (1, tokens, 768), TODO
        # confirm if the checkpoint changes.
        pooled.append(
            nn.functional.adaptive_avg_pool2d(hidden_states[layer], (1, 768)).detach().numpy()
        )


# In[7]:


# Turn each per-layer list of pooled features into a single stacked array and
# report its shape as a quick sanity check.
img_avg2 = np.asarray(img_avg2)
print(np.shape(img_avg2))


# In[8]:


img_avg4 = np.asarray(img_avg4)
print(np.shape(img_avg4))


# In[9]:


img_avg6 = np.asarray(img_avg6)
print(np.shape(img_avg6))


# In[10]:


img_avg8 = np.asarray(img_avg8)
print(np.shape(img_avg8))


# In[11]:


img_avg10 = np.asarray(img_avg10)
print(np.shape(img_avg10))


# In[12]:


img_avg12 = np.asarray(img_avg12)
print(np.shape(img_avg12))


# In[13]:


# Write one .npy file per collected layer. Each stacked array is 4-D; it is
# flattened to 2-D using its first and last axis lengths before saving.
layer_outputs = (
    ('beit_img_feat_bold_layer2', img_avg2),
    ('beit_img_feat_bold_layer4', img_avg4),
    ('beit_img_feat_bold_layer6', img_avg6),
    ('beit_img_feat_bold_layer8', img_avg8),
    ('beit_img_feat_bold_layer10', img_avg10),
    ('beit_img_feat_bold_layer12', img_avg12),
)
for out_name, pooled_feats in layer_outputs:
    n_rows = pooled_feats.shape[0]
    n_cols = pooled_feats.shape[3]
    np.save(out_name, pooled_feats.reshape((n_rows, n_cols)))


# In[ ]:




