#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
from scipy.io import loadmat
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from scipy.stats import pearsonr
from scipy import spatial
from sklearn.metrics import auc, r2_score, roc_auc_score


# In[2]:


def pairwise_accuracy(actual, predicted):
    """Return the 2-vs-2 (pairwise) matching accuracy under cosine distance.

    For every pair of sample positions ``i < j`` the summed cosine distance
    of the matched assignment,
    ``d(actual[i], predicted[i]) + d(actual[j], predicted[j])``,
    is compared with the summed distance of the swapped assignment,
    ``d(actual[i], predicted[j]) + d(actual[j], predicted[i])``.
    A pair counts as correct when the matched sum is strictly smaller.

    Parameters
    ----------
    actual : sequence of 1-D vectors
        Reference vectors, one per sample.
    predicted : sequence of 1-D vectors
        Predicted vectors; ``predicted[i]`` is compared against ``actual[i]``.

    Returns
    -------
    float
        Number of correct pairs divided by the number of pairs.

    Raises
    ------
    ValueError
        If ``actual`` holds fewer than two samples, because no pair can be
        formed (the ratio would otherwise be a division by zero).
    """
    n_samples = len(actual)
    if n_samples < 2:
        raise ValueError(
            "pairwise_accuracy needs at least two samples, got %d" % n_samples
        )

    cosine = spatial.distance.cosine

    # The matched distance of a sample depends on that sample only, so it is
    # computed once here instead of once per pair inside the double loop.
    matched = [cosine(actual[i], predicted[i]) for i in range(n_samples)]

    correct = 0
    for i in range(n_samples):
        for j in range(i + 1, n_samples):
            swapped = (cosine(actual[i], predicted[j])
                       + cosine(actual[j], predicted[i]))
            if matched[i] + matched[j] < swapped:
                correct += 1

    # Number of unordered pairs i < j.
    total = n_samples * (n_samples - 1) // 2
    return correct / total


# In[3]:


def pearcorr(actual, predicted):
    """Return the mean row-wise Pearson correlation.

    Each row ``actual[i]`` is correlated with ``predicted[i]`` using
    ``np.corrcoef``; the off-diagonal entry of the resulting 2x2 matrix is
    the correlation for that row. The per-row values are averaged with
    ``np.mean``.

    Parameters
    ----------
    actual : sequence of 1-D vectors
        Reference rows.
    predicted : sequence of 1-D vectors
        Rows compared position by position against ``actual``.

    Returns
    -------
    float
        Mean of the per-row correlation coefficients.
    """
    row_correlations = [
        np.corrcoef(reference_row, predicted[position])[0, 1]
        for position, reference_row in enumerate(actual)
    ]
    return np.mean(row_correlations)


# In[7]:


# Cross-validation splitter; train() reads `kf` as a module-level global.
# NOTE(review): `kf` is rebound further down in this file to
# KFold(n_splits=18) before train() is first called, so when the file runs
# top to bottom this shuffled 10-fold splitter is never the one used.
# Confirm which configuration is intended. No random_state is set here, so
# the shuffled split differs between runs.
kf = KFold(n_splits=10, shuffle=True)


# In[8]:


def train(vectors, voxels):
    """Cross-validate a ridge regression and score its held-out predictions.

    NOTE: despite the parameter names, ``voxels`` is used as the regression
    input (X) and ``vectors`` as the regression target (Y).

    The folds come from the module-level ``kf`` splitter. For each fold a
    fresh ``Ridge(alpha=1.0)`` is fitted on the training rows and used to
    predict the held-out rows.

    Parameters
    ----------
    vectors : array-like with a ``copy()`` method
        Regression targets, one row per sample.
    voxels : array-like with a ``copy()`` method
        Regression inputs, one row per sample.

    Returns
    -------
    tuple
        ``(pairwise, correlation)`` where ``pairwise`` is
        ``pairwise_accuracy`` computed once over the held-out rows of all
        folds pooled together, and ``correlation`` is the mean over folds of
        ``pearcorr`` computed on each fold's held-out rows.
    """
    inputs = np.array(voxels.copy())
    targets = np.array(vectors.copy())

    pooled_truth = []
    pooled_predictions = []
    fold_correlations = []

    for fit_rows, heldout_rows in kf.split(inputs):
        ridge = Ridge(alpha=1.0)
        ridge.fit(inputs[fit_rows], targets[fit_rows])

        fold_truth = targets[heldout_rows]
        fold_predictions = ridge.predict(inputs[heldout_rows])

        # Correlation is scored per fold; the pairwise score is computed
        # after the loop on the predictions pooled across folds.
        fold_correlations.append(pearcorr(fold_truth, fold_predictions))
        pooled_truth.extend(fold_truth)
        pooled_predictions.extend(fold_predictions)

    pooled_accuracy = pairwise_accuracy(pooled_truth, pooled_predictions)
    return pooled_accuracy, np.mean(fold_correlations)


# In[10]:


# Load the image-feature matrix from a .npy file in the working directory.
# The commented-out lines are alternative feature files; exactly one load
# should be active.
# NOTE(review): `img_feat` is loaded again from a different file further
# down in this file before it is used, so when the file runs top to bottom
# this value is overwritten. Confirm which file is intended.
img_feat = np.load('beit_img_feat_periera.npy')
#img_feat = np.load('img_feat_periera_clip.npy')
#img_feat = np.load('vit_img_feat_periera.npy')
#img_feat = np.load('visualbert_coco_layer12.npy')
#img_feat = np.load('inceptionv2resnet_conv2d150_periera.npy')
#img_feat = np.load('lxmert_periera_visualembeddings.npy')
#img_feat = np.load('img_feat_periera_vilbert.npy')


# In[9]:


def generate_indices(data):
    """Extract nine flat integer index lists from ``data['meta']``.

    All lists are read from ``data['meta'][0][0][11][0][slot]``. Each slot
    is iterated; for every element ``j`` the values in ``j[0]`` are
    converted with ``int`` and appended in order.

    Slot numbers and the names the result is unpacked into:

    ====  ====================
    slot  list
    ====  ====================
    5     Taskindices
    6     DMNindices
    9     Visualindices_body
    10    Visualindices_face
    11    Visualindices_object
    12    Visualindices_scene
    13    Visualindices
    7     Languageindices_lh
    8     Languageindices_rh
    ====  ====================

    NOTE(review): the meaning of each slot is taken from the variable names
    only; verify it against the dataset's ``meta`` layout. The caller in this
    file subtracts 1 from the indices before using them as column positions.

    Parameters
    ----------
    data : mapping
        Object with a ``'meta'`` entry laid out as described above
        (presumably the dict returned by ``scipy.io.loadmat``; confirm).

    Returns
    -------
    tuple of list of int
        The nine lists, in the order of the table above.
    """
    roi_slots = data['meta'][0][0][11][0]

    def flatten_slot(slot):
        # One level of wrapping (`group[0]`) sits around the actual values.
        return [int(value) for group in roi_slots[slot] for value in group[0]]

    # Order matters: it is the positional order callers unpack.
    slot_order = (5, 6, 9, 10, 11, 12, 13, 7, 8)
    return tuple(flatten_slot(slot) for slot in slot_order)


# In[14]:


# ROI name -> index-list mapping.
#
# The index lists (Languageindices_lh, Languageindices_rh, Visualindices_body,
# Visualindices_face, Visualindices_object, Visualindices_scene,
# Visualindices, DMNindices, Taskindices) are only bound once
# generate_indices() has been called for a subject, which happens inside the
# per-subject loop below. Building the mapping here from those names raised
# NameError when the file was run top to bottom, so only an empty placeholder
# is created at this point; the loop below rebuilds `rois` for every subject.
#
# Keys this mapping previously listed: 'language_lh', 'language_rh',
# 'vision_body', 'vision_face', 'vision_object', 'vision_scene', 'vision',
# 'dmn', 'task'.
rois = {}


# In[16]:



# Subject identifiers to evaluate; each has its own .mat recording.
subjects = ['P01','M01','M02','M03','M04','M05','M06','M07','M08','M09',
           'M10','M13','M15','M16','M17']
# train() reads this module-level splitter: 18 folds, no shuffling.
kf = KFold(n_splits=18)

# Image features used as the regression input. The commented-out lines are
# alternative feature sources / preprocessing variants.
#img_feat = np.load('vit_img_feat_final_periera12.npy')
img_feat = np.load('vit_img_layer7_periera.npy')
#img_feat = np.load('beit_img_feat_periera.npy')
#img_feat = np.load('inceptionv2resnet_conv7b_periera.npy')
#img_feat = np.reshape(img_feat, (180, 6, img_feat.shape[1]))
#img_feat = np.mean(img_feat, axis=1)
#text_feat = np.load('periera_roberta.npy')
#img_feat = np.concatenate([img_feat, text_feat], axis=1)
#img_feat = text_feat
print(img_feat.shape)
for eachsub in subjects:
    #print(eachsub)
    # The recordings live in different directories depending on the subject.
    # Subjects are matched with == / tuple membership: the former
    # `eachsub in 'P01'` was a substring test on the literal, which only
    # behaved correctly because every id happens to be a full 3-char string.
    if eachsub == 'P01':
        mat_path = 'D:/Periera_Dataset/data_180concepts_pictures.mat'
    elif eachsub in ('M06', 'M01'):
        mat_path = 'D:/Periera_Dataset/'+eachsub+'/data_180concepts_pictures.mat'
    elif eachsub == 'M02':
        mat_path = 'D:/Periera_Dataset/Experiment/Experiment/data/'+eachsub+'/data_180concepts_pictures.mat'
    else:
        mat_path = 'D:/Periera_Dataset/exp/exp/data/'+eachsub+'/data_180concepts_pictures.mat'
    data_pic = loadmat(mat_path)

    (Taskindices, DMNindices, Visualindices_body, Visualindices_face,
     Visualindices_object, Visualindices_scene, Visualindices,
     Languageindices_lh, Languageindices_rh) = generate_indices(data_pic)

    # NOTE(review): 'language_lh' is not part of this mapping although
    # Languageindices_lh is extracted above; confirm this is intentional.
    rois = { 'language_rh':Languageindices_rh,
        'vision_body': Visualindices_body, 'vision_face':Visualindices_face,
        'vision_object': Visualindices_object, 'vision_scene': Visualindices_scene,
        'vision': Visualindices,  'dmn': DMNindices, 'task': Taskindices}
    for roi, indices in rois.items():
        #print(roi)
        # Select this ROI's columns; the stored indices are shifted by -1
        # before being used as column positions.
        roi_fmri = data_pic['examples'][0:, np.array(indices) - 1]
        # train(vectors, voxels) fits its first argument as the target and
        # its second as the input, i.e. ROI responses are predicted from the
        # image features.
        d, c = train(roi_fmri, img_feat)
        print(np.round(d, 3), np.round(c, 3))
        # NOTE(review): this break stops after the first entry of `rois`
        # ('language_rh'), so the remaining ROIs are never evaluated.
        # Kept as-is; remove it to score every ROI.
        break

