#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os
import random

import numpy as np
from scipy import spatial
from scipy.io import loadmat,savemat
from scipy.stats import pearsonr
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split


# In[2]:


# Load the .mat file whose 'abs_indices' and 'conc_indices' entries are read below.
# NOTE(review): the path is relative, so it resolves against the current working
# directory - confirm the script is always launched from the expected folder.
abs_conc_indices = loadmat('../data/periera/abs_conc_indices.mat')


# In[3]:


# Pull the two index vectors out of the loaded dict; [0] selects the first
# element along the leading axis of each stored array.
# NOTE(review): get_abs_conc_data compares these values against 0-based row
# positions - confirm the .mat file does not hold MATLAB-style 1-based indices.
# NOTE(review): conc_indices is not referenced by any other code visible in this file.
abs_indices = abs_conc_indices['abs_indices'][0]
conc_indices = abs_conc_indices['conc_indices'][0]


# In[4]:


def get_abs_conc_data(voxels, vectors):
    """Split row-aligned voxels and vectors into abstract and concrete groups.

    Row position ``i`` goes to the abstract group when ``i`` is one of the
    values in the module-level ``abs_indices``; every other position goes to
    the concrete group (``conc_indices`` is not consulted).

    Args:
        voxels: Sequence of rows; its length drives the iteration.
        vectors: Sequence indexed at the same positions as ``voxels``.

    Returns:
        Tuple ``(abs_voxels, abs_vectors, conc_voxels, conc_vectors)`` of
        Python lists, each preserving the original row order.
    """
    # Build the lookup once: ``i in <ndarray>`` rescans the whole array on
    # every iteration, whereas a set answers each membership test directly.
    abstract_positions = set(np.asarray(abs_indices).ravel().tolist())

    abs_voxels, abs_vectors = [], []
    conc_voxels, conc_vectors = [], []

    for i, voxel_row in enumerate(voxels):
        if i in abstract_positions:
            abs_voxels.append(voxel_row)
            abs_vectors.append(vectors[i])
        else:
            conc_voxels.append(voxel_row)
            conc_vectors.append(vectors[i])

    return abs_voxels, abs_vectors, conc_voxels, conc_vectors


# In[5]:


def pairwise_accuracy(actual, predicted):
    """Return the 2-vs-2 (pairwise) matching accuracy under cosine distance.

    For every unordered pair of rows ``(i, j)`` the pair counts as correct
    when pairing each actual row with its own prediction gives a strictly
    smaller summed cosine distance than pairing it with the other row's
    prediction::

        d(a_i, p_i) + d(a_j, p_j) < d(a_i, p_j) + d(a_j, p_i)

    Args:
        actual: 2-D array-like, one row per item.
        predicted: 2-D array-like with the same number of rows as ``actual``
            and the same number of columns.

    Returns:
        Fraction of correct pairs, as a Python float.

    Raises:
        ValueError: If the two inputs have different row counts, or if there
            are fewer than two rows (no pair can be formed, so the ratio
            would otherwise be a division by zero).
    """
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)

    n_items = len(actual)
    if len(predicted) != n_items:
        raise ValueError(
            "actual and predicted must have the same number of rows "
            "(got %d and %d)" % (n_items, len(predicted))
        )
    if n_items < 2:
        raise ValueError("pairwise accuracy needs at least two rows, got %d" % n_items)

    # One call yields every actual/predicted cosine distance; the pairwise
    # loop used to recompute the same distances for each pair it visited.
    dist = spatial.distance.cdist(actual, predicted, metric='cosine')

    own = np.diag(dist)                      # d(a_i, p_i)
    matched = own[:, None] + own[None, :]    # d(a_i, p_i) + d(a_j, p_j)
    crossed = dist + dist.T                  # d(a_i, p_j) + d(a_j, p_i)

    # Only pairs with i < j are scored, as before.
    rows, cols = np.triu_indices(n_items, k=1)
    correct = int(np.count_nonzero(matched[rows, cols] < crossed[rows, cols]))

    return correct / len(rows)


# In[6]:


def pearcorr(actual, predicted):
    """Average the row-wise Pearson correlation of two row-aligned collections.

    Row ``i`` of ``actual`` is correlated with row ``i`` of ``predicted``
    via ``np.corrcoef``; the mean of those coefficients is returned.
    """
    per_row = [
        np.corrcoef(actual[row], predicted[row])[0][1]
        for row in range(len(actual))
    ]
    return np.mean(per_row)


# In[7]:


def rdm_gen(data, use_abs=True):
    """Build a representational dissimilarity matrix (RDM) from row patterns.

    Entry ``[i, j]`` is ``1 - |r|`` when ``use_abs`` is true (the default,
    matching the previous hard-coded behavior) and ``1 - r`` otherwise,
    where ``r`` is the Pearson correlation of ``data[i]`` and ``data[j]``.

    Args:
        data: Sequence of equal-length 1-D patterns, one per item.
        use_abs: Whether to take the absolute value of the correlation
            before converting it to a dissimilarity.

    Returns:
        Square ``numpy.ndarray`` of shape ``(len(data), len(data))``.
    """
    n_items = len(data)
    rdm = np.zeros([n_items, n_items])

    for i in range(n_items):
        # Pearson correlation is symmetric, so each pair is evaluated once
        # and written to both [i, j] and [j, i].
        for j in range(i, n_items):
            r = pearsonr(data[i], data[j])[0]
            dissimilarity = 1 - np.abs(r) if use_abs else 1 - r
            rdm[i, j] = dissimilarity
            rdm[j, i] = dissimilarity

    return rdm


# In[8]:


def rdm_sim(rdm1, rdm2):
    """Compare two RDMs by the Pearson correlation of their flattened entries.

    Returns a length-2 array holding the two values ``pearsonr`` reports
    (correlation coefficient first, p-value second).
    """
    flat_first = rdm1.flatten()
    flat_second = rdm2.flatten()
    coefficient, p_value = pearsonr(flat_first, flat_second)
    return np.array([coefficient, p_value])


# In[9]:


def generate_indices(data):
    """Collect the per-ROI index lists stored under ``data['meta']``.

    Every list is read from ``data['meta'][0][0][11][0][slot]``: for each
    element ``j`` of that slot, the values of ``j[0]`` are converted with
    ``int`` and appended in order. Slots 5 and 6 give the task and DMN lists,
    7 and 8 the left/right language lists, 9-12 the body/face/object/scene
    vision lists and 13 the combined vision list.

    NOTE(review): the values are returned as stored; get_subject_data in this
    file subtracts 1 before using them as column positions, which suggests
    they are 1-based - confirm against the data files.

    Returns:
        Tuple of nine lists in this order: task, DMN, vision-body,
        vision-face, vision-object, vision-scene, vision, language-LH,
        language-RH.
    """
    roi_slots = data['meta'][0][0][11][0]

    def collect(slot):
        # Flatten one slot: each entry wraps its values in a one-element container.
        return [int(value) for entry in roi_slots[slot] for value in entry[0]]

    task = collect(5)
    dmn = collect(6)
    body, face, objects, scene = (collect(slot) for slot in (9, 10, 11, 12))
    visual = collect(13)
    language_lh = collect(7)
    language_rh = collect(8)

    return task, dmn, body, face, objects, scene, visual, language_lh, language_rh


# In[10]:


def train(train_vectors,train_voxels,test_vectors,test_voxels):
    """Fit a ridge regression from vectors to voxels and score it on the test set.

    A ``Ridge(alpha=1.0)`` model is fitted with the training vectors as inputs
    and the training voxels as targets, then used to predict voxels for the
    test vectors. Each argument must provide a ``.copy()`` method (lists and
    arrays do); the copies are converted to arrays before use.

    Returns:
        Tuple ``(accuracy, correlation)``: the ``pairwise_accuracy`` and
        ``pearcorr`` scores of the predictions against the test voxels, each
        rounded to three decimals.
    """
    # Work on array snapshots so the caller's containers are left untouched.
    fit_targets = np.array(train_voxels.copy())
    fit_inputs = np.array(train_vectors.copy())
    eval_targets = np.array(test_voxels.copy())
    eval_inputs = np.array(test_vectors.copy())

    regressor = Ridge(alpha=1.0)
    regressor.fit(fit_inputs, fit_targets)
    predicted = regressor.predict(eval_inputs)

    scores = (
        pairwise_accuracy(eval_targets, predicted),
        pearcorr(eval_targets, predicted),
    )
    return tuple(round(score, 3) for score in scores)


# In[11]:


# ROI names; these match the keys of the per-ROI fMRI dict built in get_subject_data.
ROIS = [
    'language_lh', 'language_rh',
    'vision_body', 'vision_face', 'vision_object', 'vision_scene', 'vision',
    'dmn', 'task',
]

# Layer tags per model family. Only layers_beit is used by the code in this file.
# NOTE(review): 'layer2' is absent from the lxmert list - confirm that is intentional.
layers_lxmert = ['layer' + str(n) for n in (0, 1, 3, 4, 5, 6, 7, 8, 9)]
layers_deit = ['layer' + str(n) for n in (1, 3, 5, 7, 10, 12)]
layers_beit = ['layer' + str(n) for n in (1, 3, 5, 7, 10, 12)]
layers_vit = ['layer' + str(n) for n in (1, 3, 5, 7, 10)]
layers_visualbert = ['layer' + str(n) for n in (1, 3, 5, 7, 10)]
layers_eff = ['add' + str(n) for n in (2, 8, 16, 24)]


# In[12]:


def get_subject_data(subject,layers):
    """Load one subject's ROI-restricted fMRI data and the per-layer image features.

    Args:
        subject: Name of the subject folder under ``../data/periera/``.
        layers: Iterable of layer tags; each one selects the file
            ``../multi_modal/beit_img_<layer>_periera.npy``.

    Returns:
        Tuple ``(fmri, vis_feats)``. ``fmri`` maps each ROI name to the
        columns of the subject's ``'examples'`` matrix selected by that ROI's
        indices; ``vis_feats`` maps each layer tag to the array loaded from
        its ``.npy`` file.
    """
    # The context manager closes the handle; it was previously left open.
    with open('../data/periera/stimuli_180concepts.txt','r') as stimuli_file:
        # NOTE(review): the stimulus names are read but not used further here.
        images = [line.strip() for line in stimuli_file]

    data_pic = loadmat('../data/periera/'+subject+'/data_180concepts_pictures.mat')
    (Taskindices, DMNindices,
     Visualindices_body, Visualindices_face, Visualindices_object, Visualindices_scene,
     Visualindices, Languageindices_lh, Languageindices_rh) = generate_indices(data_pic)

    roi_indices = {'language_lh':Languageindices_lh, 'language_rh':Languageindices_rh,
        'vision_body': Visualindices_body, 'vision_face':Visualindices_face,
        'vision_object': Visualindices_object, 'vision_scene': Visualindices_scene,
        'vision': Visualindices,  'dmn': DMNindices, 'task': Taskindices}

    # The stored indices are shifted by -1 before being used as column
    # positions (presumably 1-based MATLAB indices - confirm).
    fmri = {
        roi: data_pic['examples'][0:,np.array(indices)-1]
        for roi, indices in roi_indices.items()
    }

    vis_feats = {
        layer: np.load('../multi_modal/beit_img_'+ layer + '_periera.npy')
        for layer in layers
    }

    return fmri, vis_feats


# In[13]:


# Subject folder names to process: P01, then M01-M10 and M13-M17 (M11 and M12 are not listed).
subjects = ['P01'] + ['M%02d' % number for number in list(range(1, 11)) + list(range(13, 18))]


# In[14]:


# Driver: for every subject, ROI and BEiT layer, fit the encoding model in both
# directions (abstract -> concrete and concrete -> abstract) and store the scores.
#
# Each ROI gets two result dicts shaped {'2v2': {subject: [per-layer scores]},
# 'pear': {subject: [per-layer scores]}}:
#   * output_abs  - model trained on the abstract rows, tested on the concrete rows
#   * output_conc - model trained on the concrete rows, tested on the abstract rows

# Create the output folder up front so a missing directory cannot make np.save
# fail only after the whole analysis has been computed.
os.makedirs('./results', exist_ok=True)

outputs = {}
for roi in ROIS:
    outputs[roi] = ({'2v2': {}, 'pear': {}}, {'2v2': {}, 'pear': {}})

for subject in subjects:
    print(subject)
    # Load once per subject: the loaded data does not depend on the ROI, so
    # loading inside the ROI loop re-read the same files for every ROI.
    fmri, vis_feats = get_subject_data(subject, layers_beit)

    for roi in ROIS:
        print(roi)
        output_abs, output_conc = outputs[roi]
        output_abs['2v2'][subject] = []
        output_abs['pear'][subject] = []
        output_conc['2v2'][subject] = []
        output_conc['pear'][subject] = []

        # The ROI's voxel matrix is the same for every layer.
        voxels = np.array(fmri[roi])
        for layer in layers_beit:
            print(layer)
            feats = np.array(vis_feats[layer])
            abs_voxels,abs_vectors,conc_voxels,conc_vectors = get_abs_conc_data(voxels, feats)
            acc_abs, corr_abs = train(abs_vectors,abs_voxels,conc_vectors,conc_voxels)
            acc_conc, corr_conc = train(conc_vectors,conc_voxels,abs_vectors,abs_voxels)

            output_abs['2v2'][subject].append(acc_abs)
            output_abs['pear'][subject].append(corr_abs)

            output_conc['2v2'][subject].append(acc_conc)
            output_conc['pear'][subject].append(corr_conc)

# Same file names and contents as before; the dicts are saved as pickled
# object arrays, so reading them back needs np.load(..., allow_pickle=True).
for roi in ROIS:
    output_abs, output_conc = outputs[roi]

    fn = './results/'+roi+'_abs_beit.npy'
    np.save(fn,output_abs)

    fn1 = './results/'+roi+'_conc_beit.npy'
    np.save(fn1,output_conc)


# In[ ]:




