#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
from scipy.io import loadmat
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from scipy.stats import pearsonr
from scipy import spatial
import torch
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.figure_factory as ff


# In[2]:


# Pre-extracted stimulus features, one .npy file per model / variant.
img_feat_lxmert_text = np.load('lxmert_periera_languageembeddings.npy')
final_image_feat = np.load('img_feat_periera_clip.npy')  # CLIP
img_feat_vit = np.load('./vit_img_feat_periera.npy')
img_feat_vit_layers = np.load('./vit_img_feat_final_periera12.npy')
img_feat_deit = np.load('./deit_img_feat_periera.npy')
img_feat_deit_layers = np.load('./deit_img_feat_final_periera12.npy')
img_feat_beit = np.load('./beit_img_feat_periera.npy')
# BEiT counterpart of the *_final_periera12 files above. It must not reuse the
# DeiT file, or the "BEiT+Patch" results silently duplicate "DEiT+Patch".
img_feat_beit_layers = np.load('./beit_img_feat_final_periera12.npy')
img_feat_lxmert_layer = np.load('./lxmert_periera_visualembeddings.npy')


# In[4]:


# Vision-language model features (VisualBERT last layer / avg-patch,
# LXMERT pooled output, ViLBERT), loaded in this order.
img_feat_visual, img_feat_visual_layer, img_feat_lxmert, img_feat_vilbert = [
    np.load(fname)
    for fname in (
        'visualbert_periera_lastlayer.npy',
        'visualbert_periera_lastlayer_avgpatch.npy',
        'lxmert_periera_pooled.npy',
        'img_feat_periera_vilbert.npy',
    )
]


# In[5]:


# @title Helper functions
from IPython.display import display, Image # to visualize images

# @markdown Function to test a custom torch RSM function: `test_custom_torch_RSM_fct()`
def test_custom_torch_RSM_fct(custom_torch_RSM_fct):
  """
  Run an RSM function on a random feature matrix and return its output.

  Required args:
  - custom_torch_RSM_fct (callable): function taking a 2D torch Tensor
      (nbr items x nbr features) and returning an RSM.

  Returns:
  - whatever custom_torch_RSM_fct returns for a uniform-random
      100 x 1000 input tensor.
  """
  return custom_torch_RSM_fct(torch.rand(100, 1000))


# In[6]:


def custom_torch_RSM_fct(features):
  """
  custom_torch_RSM_fct(features)

  Compute a representational similarity matrix (RSM): the cosine similarity
  between every pair of rows (items) of a feature matrix, via
  torch.nn.functional.cosine_similarity().

  Required args:
  - features (2D torch Tensor): feature matrix (nbr items x nbr features)

  Returns:
  - rsm (2D torch Tensor): similarity matrix (nbr items x nbr items);
      entry [i, j] is the cosine similarity of items i and j.

  Raises:
  - ValueError: if the computed matrix is not nbr items x nbr items.
  """
  n_items, n_feats = features.shape

  # Broadcast a (1, N, F) view against an (N, 1, F) view so that
  # cosine_similarity compares every pair of items along the feature axis.
  as_row = features.reshape(1, n_items, n_feats)
  as_col = features.reshape(n_items, 1, n_feats)
  sim = torch.nn.functional.cosine_similarity(as_row, as_col, dim=2)

  if sim.shape != (n_items, n_items):
    raise ValueError(f"RSM should be of shape ({n_items}, {n_items})")

  return sim


# add event to airtable
#atform.add_event('Coding Exercise 2.1.1: Complete a function that calculates RSMs')


# In[7]:


def plot_RSMs(rsms, titles=None):
    """
    plot_RSMs(rsms)

    Plots representational similarity matrices side by side, sharing one
    colour scale and one colour bar.

    Required args:
    - rsms (list or 2D array): list of 2D RSM arrays, or a single 2D RSM.

    Optional args:
    - titles (list): title for each RSM, or a single title when a single RSM
        is passed. If None, the RSMs are left untitled. (default: None)

    Returns:
    - fig (plt.Figure): figure
    - axes (np.ndarray): 1 x len(rsms) array of plt.Axes

    Raises:
    - ValueError: if rsms is an empty list, or if titles does not provide
        one entry per RSM.
    """

    if not isinstance(rsms, list):
        rsms = [rsms]
        titles = [titles]
    elif titles is None:
        # Titles are optional: pad with None (blank title) so the length check
        # and the zip() below work instead of failing on len(None).
        titles = [None] * len(rsms)

    if not rsms:
        raise ValueError("Must provide at least one RSM.")

    if len(rsms) != len(titles):
        raise ValueError("If providing titles, must provide as many "
            "as the number of RSMs.")

    # Colour limits span every RSM and always include [-1, 1] (the range of
    # cosine similarity), so all panels are directly comparable.
    min_val = np.min([rsm.min() for rsm in rsms] + [-1])
    max_val = np.max([rsm.max() for rsm in rsms] + [1])

    ncols = len(rsms)
    wid = 5

    fig, axes = plt.subplots(
        ncols=ncols, figsize=[ncols * wid, wid], squeeze=False
        )
    fig.suptitle("Representational Similarity Matrices (RSMs)", y=1.05, fontsize=22)

    # Reserve a narrow strip on the right for a single shared colour bar.
    cm_w = 0.05 / ncols
    fig.subplots_adjust(right=1-cm_w*2)
    cbar_ax = fig.add_axes([1, 0.15, cm_w, 0.7])

    for ax, rsm, title in zip(axes.flatten(), rsms, titles):
        im = ax.imshow(rsm, vmin=min_val, vmax=max_val, interpolation="none")
        ax.set_title(title, y=1.02, fontsize=22)
        ax.tick_params(labelsize=14)

    # Every panel uses the same vmin/vmax, so the last image stands for all.
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label(label="Similarity", size=22)
    cbar_ax.yaxis.set_label_position("left")

    return fig, axes


# In[8]:


def generate_indices(data):
    """Extract per-ROI voxel index lists from a loaded Pereira ``.mat`` struct.

    Reads the fields of ``data['meta'][0][0][11][0]``. Each field is an
    iterable of entries whose element ``[0]`` holds index values. Every
    selected field is flattened into a list of ``int``.

    Returns a 9-tuple of lists, in this order (field position in brackets):
    Task [5], DMN [6], Visual body [9], Visual face [10], Visual object [11],
    Visual scene [12], Visual [13], Language LH [7], Language RH [8].
    """
    roi_fields = data['meta'][0][0][11][0]

    def flatten_field(pos):
        # Concatenate the index arrays of every entry in field `pos`.
        return [int(idx) for entry in roi_fields[pos] for idx in entry[0]]

    field_order = (5, 6, 9, 10, 11, 12, 13, 7, 8)
    return tuple(flatten_field(pos) for pos in field_order)


# In[9]:


# CNN image features.
img_feat_efficient = np.load('efficientnetb5/effcientnetb5_fc_periera.npy')
img_feat_inv2res = np.load('inceptionv2resnet_conv7b_periera.npy')

# allow_pickle=True lets np.load unpickle object arrays; only use it on
# trusted files.
img_feat_vggnet = np.load('vgg_pereira_feat.npy', allow_pickle=True)
img_feat_resnet = np.load('resnet_pereira_feat.npy', allow_pickle=True)
img_feat_roberta = np.load('periera_roberta.npy', allow_pickle=True)
print(img_feat_roberta.shape)

# The ResNet / VGG files hold a dict inside an object array; its first
# element is the dict (keyed by item, later indexed as [key]['fc'/'fc8']).
img_feat_resnet = img_feat_resnet.flatten()[0]
img_feat_vggnet = img_feat_vggnet.flatten()[0]


# In[13]:


# Average ResNet 'fc' features over consecutive groups of 6 entries in sorted
# key order (presumably the 6 images per concept). Any trailing partial
# group is dropped.
resnet_fc = [img_feat_resnet[key]['fc'] for key in sorted(img_feat_resnet)]
n_grouped = len(resnet_fc) - len(resnet_fc) % 6
img_feat_resnet1 = np.array([
    np.mean(np.array(resnet_fc[start:start + 6]), axis=0)
    for start in range(0, n_grouped, 6)
])
print(img_feat_resnet1.shape)


# In[14]:


# Average VGG 'fc8' features over consecutive groups of 6 entries in sorted
# key order (presumably the 6 images per concept). Any trailing partial
# group is dropped.
vgg_fc8 = [img_feat_vggnet[key]['fc8'] for key in sorted(img_feat_vggnet)]
n_grouped = len(vgg_fc8) - len(vgg_fc8) % 6
img_feat_vggnet1 = np.array([
    np.mean(np.array(vgg_fc8[start:start + 6]), axis=0)
    for start in range(0, n_grouped, 6)
])
print(img_feat_vggnet1.shape)


# In[15]:


# Average every 6 consecutive rows (presumably the 6 images per concept):
# (1080, D) -> (180, 6, D) -> (180, D).
img_feat_efficient = img_feat_efficient.reshape(
    (180, 6, img_feat_efficient.shape[1])
).mean(axis=1)
print(img_feat_efficient.shape)


# In[16]:


# Average every 6 consecutive rows (presumably the 6 images per concept):
# (1080, D) -> (180, 6, D) -> (180, D).
img_feat_inv2res = img_feat_inv2res.reshape(
    (180, 6, img_feat_inv2res.shape[1])
).mean(axis=1)
print(img_feat_inv2res.shape)


# In[25]:


def _cosine_rsm(feats):
    """Return the RSM of a 2D feature matrix: 1 - pairwise cosine distance between rows."""
    return 1 - spatial.distance.cdist(feats, feats, 'cosine')


img_rsm_clip = _cosine_rsm(final_image_feat)
img_rsm_lxmert = _cosine_rsm(img_feat_lxmert)
img_rsm_lxmert_layer = _cosine_rsm(img_feat_lxmert_layer)
# Required by the per-subject RSA loop, which correlates every ROI RSM with it;
# leaving it undefined makes a top-to-bottom run fail with NameError.
img_rsm_lxmert_text = _cosine_rsm(img_feat_lxmert_text)
img_rsm_visbert = _cosine_rsm(img_feat_visual)
img_rsm_visbert_layer = _cosine_rsm(img_feat_visual_layer)
img_rsm_vilbert = _cosine_rsm(img_feat_vilbert)
img_rsm_vit = _cosine_rsm(img_feat_vit)
img_rsm_vit_layer = _cosine_rsm(img_feat_vit_layers)
img_rsm_beit = _cosine_rsm(img_feat_beit)
img_rsm_beit_layer = _cosine_rsm(img_feat_beit_layers)
img_rsm_deit = _cosine_rsm(img_feat_deit)
img_rsm_deit_layer = _cosine_rsm(img_feat_deit_layers)
img_rsm_vgg = _cosine_rsm(img_feat_vggnet1)
img_rsm_resnet = _cosine_rsm(img_feat_resnet1)
img_rsm_efficient = _cosine_rsm(img_feat_efficient)
img_rsm_inv2res = _cosine_rsm(img_feat_inv2res)
img_rsm_roberta = _cosine_rsm(img_feat_roberta)


# In[22]:


subjects = ['P01','M01','M02','M03','M04','M05','M06','M07','M08','M09',
           'M10','M13','M15','M16','M17']
ROIS = ['Language_LH', 'Language_RH', 'Vision_Body', 'Vision_Face', 'Vision_Object', 'Vision_Scene', 'Vision', 'DMN', 'TP']

# Model RSMs in the column order of each subject's (n_rois, n_models) result.
model_rsms = [
    img_rsm_clip,
    img_rsm_lxmert,
    img_rsm_lxmert_text,
    img_rsm_visbert_layer,
    img_rsm_visbert,
    img_rsm_vilbert,
    img_rsm_vit,
    img_rsm_vit_layer,
    img_rsm_deit,
    img_rsm_deit_layer,
    img_rsm_beit,
    img_rsm_beit_layer,
    img_rsm_inv2res,
    img_rsm_efficient,
    img_rsm_resnet,
    img_rsm_vgg,
    img_rsm_roberta,
]
# Flatten once; these are reused for every subject and ROI.
model_rsms_flat = [m.flatten() for m in model_rsms]


def _load_subject_pictures(subject):
    """Load ``data_180concepts_pictures.mat`` for one subject.

    The file sits in a different directory depending on the subject. Subject
    codes are compared exactly (not as substrings).
    """
    if subject == 'P01':
        path = 'D:/Periera_Dataset/data_180concepts_pictures.mat'
    elif subject in ('M01', 'M06'):
        path = 'D:/Periera_Dataset/' + subject + '/data_180concepts_pictures.mat'
    elif subject == 'M02':
        path = 'D:/Periera_Dataset/Experiment/Experiment/data/' + subject + '/data_180concepts_pictures.mat'
    else:
        path = 'D:/Periera_Dataset/exp/exp/data/' + subject + '/data_180concepts_pictures.mat'
    return loadmat(path)


sub_rsm = []
for eachsub in subjects:
    data_pic = _load_subject_pictures(eachsub)
    (Taskindices, DMNindices, Visualindices_body, Visualindices_face,
     Visualindices_object, Visualindices_scene, Visualindices,
     Languageindices_lh, Languageindices_rh) = generate_indices(data_pic)
    rois = {'language_lh': Languageindices_lh, 'language_rh': Languageindices_rh,
        'vision_body': Visualindices_body, 'vision_face': Visualindices_face,
        'vision_object': Visualindices_object, 'vision_scene': Visualindices_scene,
        'vision': Visualindices, 'dmn': DMNindices, 'task': Taskindices}
    rsms = []
    for roi, indices in rois.items():
        # Voxel indices from the .mat file are 1-based (MATLAB), hence the -1.
        roi_fmri = data_pic['examples'][:, np.array(indices) - 1]
        roi_rsm = 1 - spatial.distance.cdist(roi_fmri, roi_fmri, 'cosine')
        roi_flat = roi_rsm.flatten()
        # NOTE(review): correlates the full matrices (diagonal and both
        # triangles); RSA often uses only the upper triangle. Kept as-is so
        # results match previously saved output.
        rsms.append([np.round(pearsonr(roi_flat, model_flat)[0], 3)
                     for model_flat in model_rsms_flat])
    rsms = np.array(rsms)
    sub_rsm.append(rsms)


# In[24]:


# Stack the per-subject results into one (n_subjects, n_rois, n_models) array.
sub_rsm = np.asarray(sub_rsm)
print(sub_rsm.shape)


# In[25]:


# np.save appends the ".npy" extension to this file name.
np.save('rsa_pereira', sub_rsm)


# In[24]:


ROIS = ['Language_LH', 'Language_RH', 'Vision_Body', 'Vision_Face', 'Vision_Object', 'Vision_Scene', 'Vision', 'DMN', 'TP']

# Mean RSA correlation across subjects, shape (n_rois, n_models).
mean_rsa = np.round(np.mean(sub_rsm, axis=0), 3)
# Hide model columns 2 and 4 (the LXMERT text-embedding RSM and the VisualBERT
# last-layer RSM, in the RSA loop's column order).
plot_rsa = np.delete(mean_rsa, [2, 4], axis=1)
model_labels = ['CLIP', 'LXMERT', 'VB', 'ViLBERT', 'ViT', 'ViT+Patch', 'DEiT', 'DEiT+Patch',
                'BEiT', 'BEiT+Patch', 'IncV2Res', 'EfficentNetB5', 'ResNet50', 'VGGNet19', 'RoBERTa']

fig = ff.create_annotated_heatmap(
    z=plot_rsa,
    x=model_labels,
    y=ROIS,
    annotation_text=plot_rsa,
    showscale=True,
    zmin=-0.01,
    zmax=0.4,
)
fig['layout']['xaxis']['side'] = 'bottom'
# Axis title fonts are set via title.font; the old `titlefont` property is
# deprecated in plotly.
fig.layout.update(
    width=1050,
    height=500,
    xaxis=dict(
        title=dict(text='', font=dict(size=14)),
        showgrid=False,
        showticklabels=True,
        tickangle=25,
        tickfont=dict(size=14, color='black'),
    ),
    yaxis=dict(
        title=dict(text=''),
        showgrid=False,
        showticklabels=True,
        tickfont=dict(size=14, color='black'),
    ),
)
fig.show()

