#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
from scipy.io import loadmat
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from scipy.stats import pearsonr
from scipy import spatial


# In[2]:


def pairwise_accuracy(actual, predicted):
    """Return the 2-vs-2 pairwise accuracy of ``predicted`` against ``actual``.

    For every unordered pair of samples (i, j), the matched assignment
    (actual[i] <-> predicted[i], actual[j] <-> predicted[j]) is compared with
    the swapped assignment using cosine distance. The pair counts as correct
    when the matched total distance is strictly smaller, so ties count as
    incorrect.

    Parameters
    ----------
    actual, predicted : array-like
        One vector per sample, in the same sample order. A 1-D input is
        treated as one scalar feature per sample.

    Returns
    -------
    float
        Fraction of the n*(n-1)/2 sample pairs scored as correct.

    Raises
    ------
    ValueError
        If fewer than two samples are given, because there are no pairs to score.
    """
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    n_samples = len(actual)
    if n_samples < 2:
        raise ValueError(
            'pairwise_accuracy needs at least two samples, got %d' % n_samples)
    actual = actual.reshape(n_samples, -1)
    predicted = predicted.reshape(len(predicted), -1)

    # dist[i, j] is the cosine distance between actual[i] and predicted[j].
    # Building the matrix once replaces four scipy calls per pair inside a
    # Python double loop; the diagonal terms were recomputed for every pair.
    dist = spatial.distance.cdist(actual, predicted, metric='cosine')

    # All pairs with i < j.
    first, second = np.triu_indices(n_samples, k=1)
    matched = dist[first, first] + dist[second, second]
    swapped = dist[first, second] + dist[second, first]
    return float(np.mean(matched < swapped))


# In[3]:


def pearcorr(actual, predicted):
    """Average the Pearson correlation between matching rows of two collections.

    Row k of ``actual`` is correlated with row k of ``predicted`` using
    ``np.corrcoef``, and the mean of those coefficients is returned. Only the
    first ``len(actual)`` rows of ``predicted`` are used.
    """
    coefficients = [
        np.corrcoef(actual[k], predicted[k])[0, 1]
        for k in range(len(actual))
    ]
    return np.mean(coefficients)


# In[6]:


# Ten contiguous folds. shuffle=False is scikit-learn's default and is written
# out here only for clarity. train() uses this splitter for cross-validation.
kf = KFold(n_splits=10, shuffle=False)


# In[7]:


def _score_model(model, features, targets):
    """Return (2v2 pairwise accuracy, mean per-sample Pearson r) of ``model`` on one set."""
    predicted = model.predict(features)
    return pairwise_accuracy(targets, predicted), pearcorr(targets, predicted)


def train(vectors, voxels, scene_indices, coco_indices, imagenet_indices,
          alpha=1.0, cv=None):
    """Cross-validate a ridge model on one stimulus group and test it on two others.

    The model maps rows of ``voxels`` (inputs) to rows of ``vectors``
    (targets). NOTE: the parameter names do not match their roles. The
    notebook passes the ROI fMRI response as ``vectors`` and the image
    features as ``voxels``. The names are kept so existing calls still work.

    The three index lists are roles, not fixed categories:

    * ``coco_indices``: rows split by ``cv`` for fitting and in-group testing.
    * ``imagenet_indices``: first held-out group, scored by every fold's model.
    * ``scene_indices``: second held-out group, scored by every fold's model.

    Parameters
    ----------
    vectors, voxels : array-like
        Targets and inputs with matching rows (one row per stimulus).
    scene_indices, coco_indices, imagenet_indices : sequence of int
        Row positions, with the roles described above.
    alpha : float, optional
        Ridge regularisation strength (default 1.0, the previous hard-coded value).
    cv : splitter, optional
        Object with a scikit-learn style ``split(X)`` method. Defaults to the
        module-level ``kf``.

    Returns
    -------
    tuple of six floats
        Fold-averaged (2v2, Pearson r) for the in-group test folds, then for
        the ``imagenet_indices`` group, then for the ``scene_indices`` group.
    """
    if cv is None:
        cv = kf
    # The inputs are only read. The fancy indexing below already produces
    # fresh arrays, so the old double copy (np.array(x.copy())) was redundant.
    inputs = np.asarray(voxels)
    targets = np.asarray(vectors)

    fit_X, fit_Y = inputs[coco_indices], targets[coco_indices]
    held_out = [
        (inputs[imagenet_indices], targets[imagenet_indices]),
        (inputs[scene_indices], targets[scene_indices]),
    ]

    # One list of per-fold (2v2, r) pairs per evaluation set, in this order:
    # in-group test fold, imagenet_indices group, scene_indices group.
    fold_scores = [[], [], []]
    for train_index, test_index in cv.split(fit_X):
        model = Ridge(alpha=alpha)
        model.fit(fit_X[train_index], fit_Y[train_index])
        eval_sets = [(fit_X[test_index], fit_Y[test_index])] + held_out
        for scores, (X_eval, Y_eval) in zip(fold_scores, eval_sets):
            scores.append(_score_model(model, X_eval, Y_eval))

    summary = []
    for scores in fold_scores:
        summary.append(np.mean([two_vs_two for two_vs_two, _ in scores]))
        summary.append(np.mean([corr for _, corr in scores]))
    return tuple(summary)


# In[10]:


# Image-feature matrix. It is indexed below with row positions taken from the
# CSI01 stimulus list, so it is expected to hold one row per line of that list.
# Feature files tried previously (point np.load at one of them to switch):
#   visualbert_coco_vision_avg.npy, visualbert_coco_layer12.npy,
#   lxmert_coco_common.npy, ../vilbert-multi-task/vilbert_imgfeat.npy,
#   img_feat.npy, ./lxmert_img_feat_bold_layer1.npy,
#   ./lxmert_img_feat_bold_layer4.npy
# ImageNet companions that were previously loaded as img_feat1:
#   visualbert_imagenet_vision_pooled.npy, lxmert_imagenet_common.npy,
#   ../vilbert-multi-task/vilbert_imgfeat_imagenet.npy
img_feat = np.load(file='vit_img_feat_bold_layer12.npy')


# In[12]:


# Stimulus list for subject CSI01, one entry per line (trailing newlines kept).
# Positions in this list are the row indices used for img_feat and for the
# category index lists built below.
# Read inside `with` so the file handle is closed (the old handle leaked).
with open('stim_list/stim_lists/CSI01_stim_lists.txt', 'r') as stim_file:
    lines = stim_file.readlines()


# In[13]:


def _stim_category(stim_line):
    """Return 'coco', 'imagenet' or 'scene' for one stimulus-list line.

    The 'rep_' marker is removed before matching:

    * Lines containing 'COCO_train' are COCO.
    * Lines containing 'n0', or containing 'n1' but neither 'n1.' nor 'n11.',
      are treated as ImageNet (presumably synset-prefixed file names).
    * Everything else is a scene.

    NOTE(review): a scene name containing e.g. 'n12.' or 'n0' would be labelled
    ImageNet by this heuristic. Confirm against the stimulus lists.
    """
    name = stim_line.replace('rep_', '')
    if 'COCO_train' in name:
        return 'coco'
    looks_imagenet = 'n0' in name or (
        'n1' in name and not ('n1.' in name or 'n11.' in name))
    return 'imagenet' if looks_imagenet else 'scene'


# Row positions in the CSI01 list `lines`, grouped by stimulus source.
scene_indices = []
coco_indices = []
imagenet_indices = []
_indices_by_category = {
    'coco': coco_indices,
    'imagenet': imagenet_indices,
    'scene': scene_indices,
}
for position in np.arange(len(lines)):
    _indices_by_category[_stim_category(lines[position])].append(position)

# Tallies kept from the original loop: COCO vs. all other stimuli.
count = len(coco_indices)
count1 = len(imagenet_indices) + len(scene_indices)


# In[14]:


# ROI names. Each one is used as a key into every subject's ROI .mat files.
ROIS = [
    'LHPPA', 'RHPPA',
    'LHLOC', 'RHLOC',
    'LHEarlyVis', 'RHEarlyVis',
    'LHOPA', 'RHOPA',
    'LHRSC', 'RHRSC',
]


# # COCO to ImageNet and Scenes

# In[15]:


# Helpers shared by the three transfer experiments in this section.

# TR files whose ROI responses are averaged (same six files, same order as before).
# NOTE(review): 'TR34' is averaged together with 'TR3' and 'TR4'. Confirm
# this weighting is intended.
_TR_FILES = ('TR1', 'TR2', 'TR3', 'TR34', 'TR4', 'TR5')


def _load_roi_response(subject, roi):
    """Return ``roi`` for subject CSI<subject>, averaged over the _TR_FILES .mat files.

    Only the requested ROI array is kept from each loaded file. The old code
    held all six full .mat dicts at once.
    """
    prefix = 's' + str(subject) + 'mat/mat/CSI' + str(subject) + '_ROIs_'
    return np.mean([loadmat(prefix + tr + '.mat')[roi] for tr in _TR_FILES], axis=0)


# Maps each stimulus name to the CSI01 row of its first occurrence. img_feat and
# the module-level category index lists are expressed in these rows.
# A dict replaces the old O(n) lines.index() lookup for every line.
_CSI01_ROW = {}
for _row, _name in enumerate(lines):
    _CSI01_ROW.setdefault(_name.strip(), _row)

_CSI01_GROUPS = {
    'scene': scene_indices,
    'coco': coco_indices,
    'imagenet': imagenet_indices,
}


def _align_subject(subject):
    """Match subject CSI0<subject>'s stimulus list against the CSI01 list.

    Returns ``(feature_rows, response_rows, groups)``:

    * ``feature_rows``: img_feat rows, one per kept stimulus.
    * ``response_rows``: the subject's ROI-response rows for the same stimuli.
    * ``groups``: the scene/coco/imagenet index lists re-expressed in this
      aligned order.

    The old code applied CSI01's category indices directly to data in the
    other subject's presentation order. Stimuli not found in the CSI01 list
    are now dropped from both sides, so features and responses stay row-aligned.
    """
    path = 'stim_list/stim_lists/CSI0' + str(subject) + '_stim_lists.txt'
    with open(path, 'r') as stim_file:
        subject_lines = stim_file.readlines()

    feature_rows = []
    response_rows = []
    for subject_row, name in enumerate(subject_lines):
        csi01_row = _CSI01_ROW.get(name.strip())
        if csi01_row is None:
            continue
        feature_rows.append(csi01_row)
        response_rows.append(subject_row)

    groups = {}
    for key, csi01_members in _CSI01_GROUPS.items():
        members = set(csi01_members)
        groups[key] = [k for k, row in enumerate(feature_rows) if row in members]
    return feature_rows, response_rows, groups


# The alignment depends only on the subject, so compute it once for CSI02 and CSI03.
_ALIGNMENT = {subject: _align_subject(subject) for subject in (2, 3)}


def _run_transfer(slot_order):
    """Run one transfer experiment for every ROI and subject 1-3 and print the scores.

    ``slot_order`` holds three keys of the group dicts. They are passed
    positionally as train()'s (scene_indices, coco_indices, imagenet_indices)
    arguments, i.e. (second held-out group, training group, first held-out
    group).

    For each ROI this prints its name and a blank line, then one row per
    subject with these six values, rounded to 3 decimals:

    1. CV 2v2
    2. CV r
    3. first held-out 2v2
    4. first held-out r
    5. second held-out 2v2
    6. second held-out r
    """
    for roi in ROIS:
        print(roi)
        print()
        for subject in range(1, 4):
            roi_fmri = _load_roi_response(subject, roi)
            if subject == 1:
                # CSI01 data is already in CSI01 order.
                features, responses, groups = img_feat, roi_fmri, _CSI01_GROUPS
            else:
                feature_rows, response_rows, groups = _ALIGNMENT[subject]
                features = img_feat[feature_rows]
                responses = roi_fmri[response_rows]
            scores = train(responses, features, *(groups[key] for key in slot_order))
            print(*(np.round(score, 3) for score in scores))


# COCO -> ImageNet (columns 3-4) and Scenes (columns 5-6).
_run_transfer(('scene', 'coco', 'imagenet'))


# In[16]:

# ImageNet -> COCO (columns 3-4) and Scenes (columns 5-6).
_run_transfer(('scene', 'imagenet', 'coco'))


# In[17]:

# Scenes -> COCO (columns 3-4) and ImageNet (columns 5-6).
_run_transfer(('imagenet', 'scene', 'coco'))


# In[ ]:




