#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
from scipy.io import loadmat
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from scipy.stats import pearsonr
from scipy import spatial


# In[2]:


def pairwise_accuracy(actual, predicted):
    """Return the pairwise (2-vs-2) matching accuracy under cosine distance.

    For every unordered pair of samples ``(i, j)`` the "correct" pairing
    (``actual[i]`` with ``predicted[i]`` and ``actual[j]`` with
    ``predicted[j]``) is compared against the "swapped" pairing
    (``actual[i]`` with ``predicted[j]`` and ``actual[j]`` with
    ``predicted[i]``).  The pair counts as a success when the summed cosine
    distance of the correct pairing is strictly smaller than that of the
    swapped pairing.

    All cosine distances are computed once with a single ``cdist`` call
    instead of four Python-level distance calls per pair.

    Args:
        actual: sequence of ``n`` ground-truth vectors; anything
            ``np.asarray`` converts to an ``(n, d)`` array.
        predicted: sequence of ``n`` predicted vectors with the same shape
            as ``actual``.

    Returns:
        The fraction of pairs (a float in ``[0, 1]``) for which the correct
        pairing wins.

    Raises:
        ValueError: if the inputs are not 2-D arrays of identical shape, or
            if fewer than two samples are supplied (no pair can be formed).
    """
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)

    if actual.ndim != 2 or actual.shape != predicted.shape:
        raise ValueError(
            "actual and predicted must be 2-D arrays of identical shape, got "
            "%s and %s" % (actual.shape, predicted.shape)
        )

    n_samples = actual.shape[0]
    if n_samples < 2:
        raise ValueError("pairwise_accuracy needs at least two samples")

    # dist[i, j] is the cosine distance between actual[i] and predicted[j].
    dist = spatial.distance.cdist(actual, predicted, metric='cosine')
    matched = np.diagonal(dist)

    wins = 0
    # Row-wise loop keeps peak memory at a single (n, n) matrix while the
    # comparison over all j > i stays vectorised.
    for i in range(n_samples - 1):
        correct = matched[i] + matched[i + 1:]
        swapped = dist[i, i + 1:] + dist[i + 1:, i]
        wins += int(np.count_nonzero(correct < swapped))

    total = n_samples * (n_samples - 1) // 2
    return wins / total


# In[3]:


def pearcorr(actual, predicted):
    """Return the mean per-sample Pearson correlation.

    Sample ``k`` contributes the correlation coefficient between
    ``actual[k]`` and ``predicted[k]`` (the off-diagonal entry of
    ``np.corrcoef``); the result is the average of those coefficients over
    all ``len(actual)`` samples.
    """
    per_sample = [
        np.corrcoef(actual[idx], predicted[idx])[0, 1]
        for idx in range(len(actual))
    ]
    return np.mean(per_sample)


# In[6]:


# Shared 10-fold cross-validation splitter. train() reads this module-level
# object to generate its train/test index splits. Only n_splits is set here,
# so every other KFold option is left at the library default.
kf = KFold(n_splits=10)


# In[7]:


def train(vectors, voxels, alpha=1.0, cv=None):
    """Cross-validated ridge regression predicting ``vectors`` from ``voxels``.

    For each fold a fresh ``Ridge`` model is fitted on the training rows of
    ``voxels`` (regressors) and ``vectors`` (targets) and used to predict the
    held-out rows.  The per-fold mean Pearson correlation is recorded, and
    the held-out predictions of all folds are pooled to compute a single
    pairwise accuracy.

    NOTE: the parameter names do not constrain the roles. The call sites in
    this file pass the fMRI responses as ``vectors`` (targets) and the image
    features as ``voxels`` (regressors).

    Args:
        vectors: regression targets, one row per sample.
        voxels: regression inputs, one row per sample (same row order as
            ``vectors``).
        alpha: regularisation strength forwarded to ``Ridge``.
        cv: object with a ``split(X)`` method yielding
            ``(train_index, test_index)`` pairs; defaults to the
            module-level ``kf`` splitter.

    Returns:
        Tuple ``(pairwise_acc, mean_corr)``: the pairwise accuracy over the
        pooled held-out predictions and the mean over folds of the per-fold
        Pearson correlation.
    """
    if cv is None:
        cv = kf

    # np.array copies its input, so the caller's data is never modified.
    inputs = np.array(voxels)
    targets = np.array(vectors)

    held_out_truth = []
    held_out_preds = []
    fold_corrs = []

    for train_index, test_index in cv.split(inputs):
        model = Ridge(alpha=alpha)
        model.fit(inputs[train_index], targets[train_index])

        y_test = targets[test_index]
        y_pred = model.predict(inputs[test_index])

        fold_corrs.append(pearcorr(y_test, y_pred))
        # Pool the rows of every fold so the pairwise accuracy is computed
        # once across all held-out samples rather than per fold.
        held_out_truth.extend(y_test)
        held_out_preds.extend(y_pred)

    fin_acc = pairwise_accuracy(held_out_truth, held_out_preds)
    return fin_acc, np.mean(fold_corrs)


# In[13]:


# Image-feature matrix used for this run. The commented-out np.load lines
# below are alternative feature files; only the last, uncommented load is
# active. Rows of img_feat are later indexed by positions in the CSI01
# stimulus list (see the ROI loop at the bottom of the file).
#img_feat = np.load('visualbert_coco_vision_avg.npy')
#img_feat1 = np.load('visualbert_imagenet_vision_pooled.npy')
#img_feat = np.load('lxmert_coco_common.npy')
#img_feat1 = np.load('lxmert_imagenet_common.npy')
#img_feat = np.load('../vilbert-multi-task/vilbert_imgfeat.npy')
#img_feat1 = np.load('../vilbert-multi-task/vilbert_imgfeat_imagenet.npy')
#img_feat = np.load('visualbert_coco_layer12.npy')
#img_feat = np.load('img_feat.npy')
#final_image_feat = np.load('img_feat_bold5000_clip.npy')
#img_feat = np.load('./beit_img_feat.npy')
#img_feat = np.load('./effcientnetb5_fc_bold500.npy')
img_feat = np.load('./DEit_img_feat.npy')


# In[21]:


# Average the features over axis 2, then flatten the result to a 2-D
# (shape[0], shape[2]) matrix. The reshape only succeeds when the averaged
# array holds exactly shape[0] * shape[2] elements, i.e. every other
# remaining axis has length 1.
# NOTE(review): presumably the stored array is (images, 1, tokens, dim) - confirm.
img_feat = img_feat.mean(axis=2)
img_feat = img_feat.reshape(img_feat.shape[0], img_feat.shape[2])
print(img_feat.shape)


# In[9]:


# Reference stimulus list (subject CSI01). The other subjects' stimulus
# lists are matched against these lines to find the corresponding img_feat
# rows. The context manager guarantees the handle is closed after reading.
with open('stim_list/stim_lists/CSI01_stim_lists.txt', 'r') as file:
    lines = file.readlines()


# In[10]:


# Region-of-interest names; each one is used as a key into the loaded
# CSI*_ROIs_TR*.mat dictionaries in the loop below.
ROIS = ['LHPPA', 'RHPPA', 'LHLOC', 'RHLOC', 'LHEarlyVis', 'RHEarlyVis', 'LHOPA', 'RHOPA',  'LHRSC','RHRSC']


# In[12]:


# Suffixes of the per-subject .mat files whose ROI responses are averaged.
_TR_SUFFIXES = ('TR1', 'TR2', 'TR3', 'TR34', 'TR4', 'TR5')


def _align_to_reference(reference_lines, subject_lines):
    """Match a subject's stimulus list against the reference list.

    Returns two equally long lists: the positions in ``subject_lines`` whose
    entry also occurs in ``reference_lines``, and for each of them the index
    of its first occurrence in ``reference_lines``.  Entries missing from
    the reference list are skipped in both lists, so the two stay aligned.
    """
    # Dict lookup replaces a linear list.index() scan per stimulus;
    # setdefault keeps the first occurrence, exactly like list.index().
    first_seen = {}
    for position, name in enumerate(reference_lines):
        first_seen.setdefault(name, position)

    subject_rows = []
    reference_rows = []
    for row, name in enumerate(subject_lines):
        if name in first_seen:
            subject_rows.append(row)
            reference_rows.append(first_seen[name])
    return subject_rows, reference_rows


# For every ROI and every subject (1-4): average the ROI responses over the
# TR files, pair them with the matching image-feature rows, run the
# cross-validated regression and print (pairwise accuracy, mean correlation).
for roi in ROIS:
    print(roi)
    print()

    for i in np.arange(1, 5):
        mat_prefix = 's' + str(i) + 'mat/mat/CSI' + str(i) + '_ROIs_'
        roi_fmri = np.mean(
            [loadmat(mat_prefix + suffix + '.mat')[roi] for suffix in _TR_SUFFIXES],
            axis=0,
        )

        if i != 1:
            # Subjects other than CSI01 have their own stimulus list; map
            # each entry to its row in the CSI01-indexed feature matrix.
            with open('stim_list/stim_lists/CSI0' + str(i) + '_stim_lists.txt', 'r') as file1:
                lines_sub3 = file1.readlines()
            kept_rows, indices = _align_to_reference(lines, lines_sub3)
            # Drop the fMRI rows of stimuli that are missing from the
            # reference list as well; otherwise every later response would be
            # paired with the wrong feature row.
            # NOTE(review): assumes roi_fmri rows follow the order of the
            # subject's stimulus list - confirm against the data layout.
            d, c = train(roi_fmri[kept_rows], img_feat[indices])
        else:
            d, c = train(roi_fmri, img_feat)

        print(np.round(d, 3), np.round(c, 3))
            #print(len(indices))

