import pickle
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss
#import CKA_funcs as ckaf
import rsatoolbox
from rsatoolbox import rdm
import sklearn.linear_model as skLM
import sklearn.decomposition as skPC
print('act')


# Directories (relative to the current working directory) for saved figures,
# saved numeric results, and the pickled network activity that is analysed.
fig_dir, data_dir, act_dir = './figs/', './saved_data/', './activity/'

def RSA(net_list,layer,plot_dissim=True,name=''):
    """Compare networks by correlating their representational dissimilarity matrices.

    For every network in ``net_list`` the activity of ``layer`` is loaded from
    ``act_dir + netID + "_Act.pickle"``, flattened to one row per image, and
    turned into an image-by-image dissimilarity matrix ``1 - corrcoef``.  The
    upper triangles (diagonal excluded) of these matrices are then Pearson
    correlated across networks, giving a network-by-network RSA matrix that is
    plotted to ``fig_dir`` and saved (together with the network names) to
    ``data_dir`` under ``<name>RSAcr_<layer>``.

    Args:
        net_list: network IDs; each needs a matching activity pickle.
        layer: index into ``loaded[0][0][0]`` selecting the layer.
        plot_dissim: if True, also save each network's dissimilarity matrix as
            ``TaskDissim_<netID>_<layer>``.  This only affects plotting.
        name: prefix for the output figure/data file names.
    """
    all_dissims = []

    for netID in net_list:
        with open(act_dir + netID + "_Act.pickle", "rb") as f:
            loaded = pickle.load(f)

        act = loaded[0][0][0][layer]
        IMs = act.shape[0]
        act = np.reshape(act, [IMs, -1])  # images x (all units of the layer)
        dissim = 1 - np.corrcoef(act)
        if plot_dissim:
            plt.imshow(dissim, vmin=0, vmax=1.2); plt.colorbar(); plt.xticks([]); plt.yticks([])
            plt.savefig(fig_dir + 'TaskDissim_' + netID + '_' + str(layer)); plt.close()
        # Keep the original image order: every network's vector must index the
        # same image pairs for the cross-network correlation to be meaningful,
        # so no per-network re-sorting (e.g. by reward) may happen here.
        all_dissims.append(dissim[np.triu_indices(IMs, 1)])
    all_dissims = np.array(all_dissims)
    rsa_mat = np.corrcoef(all_dissims)
    print(rsa_mat.shape)

    plt.imshow(rsa_mat, vmin=-.1, vmax=1.1); plt.colorbar(); plt.title('RSAcr ' + str(layer))
    plt.xticks(np.arange(len(net_list)), net_list, rotation=90); plt.yticks(np.arange(len(net_list)), net_list)
    plt.savefig(fig_dir + name + 'RSAcr_' + str(layer)); plt.close()
    np.savez(data_dir + name + 'RSAcr_' + str(layer), rsa_mat, np.array(net_list))



def RSA_use_toolbox(net_list,layer,method='cosine',name=''):
    """Network-by-network RSA computed with rsatoolbox's RDM comparison.

    Each network's ``layer`` activity (from ``act_dir + netID + "_Act.pickle"``)
    is flattened to images x units and turned into a ``1 - corrcoef`` RDM
    wrapped in an ``rdm.RDMs`` object.  Every pair of networks is compared with
    ``rdm.compare(..., method=method)``; the symmetric result is plotted and
    saved as ``RSAtb_<method>_<layer><name>`` in ``fig_dir`` / ``data_dir``.

    Args:
        net_list: network IDs with activity pickles in ``act_dir``.
        layer: index into ``loaded[0][0][0]`` selecting the layer.
        method: comparison method passed to ``rdm.compare``
            (e.g. 'cosine', 'cosine_cov').
        name: suffix for the output file names.
    """
    print('rsa')

    models = []
    for netID in net_list:
        with open(act_dir + netID + "_Act.pickle", "rb") as f:
            loaded = pickle.load(f)

        act = loaded[0][0][0][layer]
        IMs = act.shape[0]
        print(act.shape)
        act = np.reshape(act, [IMs, -1])
        dissim = 1 - np.corrcoef(act)
        # The leading axis of size 1 makes this a stack holding a single RDM.
        models.append(rdm.RDMs(np.expand_dims(dissim, 0),
                               rdm_descriptors={'brain_computational_model': netID, 'measurement_model': 'complete'},
                               dissimilarity_measure='Correlation'))

    n_nets = len(net_list)
    rsa_mat = np.full((n_nets, n_nets), np.nan)
    # The comparison is symmetric: compute the lower triangle (including the
    # diagonal) and mirror it.
    for m1 in range(n_nets):
        for m2 in range(m1 + 1):
            # compare() yields an (n_rdm1 x n_rdm2) array -- here 1 x 1.
            score = np.asarray(rdm.compare(models[m1], models[m2], method=method)).item()
            rsa_mat[m1, m2] = rsa_mat[m2, m1] = score

    strg = 'RSAtb_' + method
    plt.imshow(rsa_mat, vmin=0, vmax=1); plt.colorbar(); plt.title(strg + str(layer))
    plt.xticks(np.arange(n_nets), net_list, rotation=90); plt.yticks(np.arange(n_nets), net_list)
    plt.savefig(fig_dir + strg + '_' + str(layer) + name); plt.close()
    np.savez(data_dir + strg + '_' + str(layer) + name, rsa_mat, np.array(net_list))



def activity_histos(net_list, n_layers=4):
    """Save, per network, one figure with a histogram of each layer's activity.

    Args:
        net_list: network IDs with activity pickles in ``act_dir``.
        n_layers: how many layers of ``loaded[0][0][0]`` to plot (one subplot
            each, side by side).  Defaults to 4, as before.

    Figures are written to ``fig_dir + 'ActHist_' + netID``.
    """
    for netID in net_list:
        with open(act_dir + netID + "_Act.pickle", "rb") as f:
            loaded = pickle.load(f)
        for l in range(n_layers):
            act = np.asarray(loaded[0][0][0][l])
            plt.subplot(1, n_layers, l + 1)
            plt.hist(act.flatten())
        plt.savefig(fig_dir + 'ActHist_' + netID); plt.close()


def run_RepSims(nets, name=''):
    """Run the rsatoolbox RSA on layers 0-3 using the 'cosine_cov' comparison."""
    print('rep')
    comparison_methods = ('cosine_cov',)  # other options tried: 'spearman', 'rho-a', 'corr'
    for layer_idx in range(4):
        for comparison in comparison_methods:
            print(layer_idx)
            RSA_use_toolbox(nets, layer_idx, method=comparison, name=name)

def _pca_stats(act, n_components):
    """Fit a PCA to ``act`` (samples x features) and summarise its spectrum.

    Returns:
        cum_ve: cumulative explained-variance ratio of components 1..n_components.
        pr: participation ratio, (sum of eigenvalues)^2 / sum of squared
            eigenvalues, over the fitted components.
        ve85: 0-based index of the first component at which the cumulative
            explained-variance ratio exceeds 0.85, or NaN if it never does.
    """
    pca = skPC.PCA(n_components=n_components)
    pca.fit(act)
    cum_ve = np.cumsum(pca.explained_variance_ratio_)
    varexp = pca.explained_variance_
    pr = np.sum(varexp) ** 2 / np.sum(varexp ** 2)
    above = np.flatnonzero(cum_ve > .85)
    # NOTE(review): this is an index, i.e. one less than the number of PCs
    # needed to pass 85%.  Kept unchanged so results stay consistent with
    # earlier output and with the hard-coded reference lines in PCAplots.
    ve85 = above[0] if above.size else np.nan
    return cum_ve, pr, ve85


def PCAplots(net_list,layer,name=''):
    """PCA of one layer's activity for several networks.

    Saves two figures in ``fig_dir``:
      * ``PCA_<layer><name>``: cumulative explained variance vs. number of
        PCs, one line per network, coloured by network type.
      * ``PCA_PR85VE_<layer><name>``: bar charts of the participation ratio
        and of the 85%-variance index per network, with dotted reference
        levels for layers 0 and 3.

    Args:
        net_list: network IDs with activity pickles in ``act_dir``; characters
            3-5 of the ID select the line colour.
        layer: index into ``loaded[0][0][0]`` selecting the layer.
        name: suffix for the output file names.
    """
    print('pca')
    plt.rcParams['axes.labelsize'] = 15
    plt.rcParams['xtick.labelsize'] = 12
    plt.rcParams['ytick.labelsize'] = 12
    colors = {'Tsk':'red','Zax':'darkred','Rwd':'lightcoral','Al3':'darkorange','Prp':'magenta','Vim':'cadetblue','Img':'lightblue','Cpc':'blue','Vir':'k','Rnd':'gold'}
    PR = []
    VE85 = []
    for netID in net_list:
        with open(act_dir + netID + "_Act.pickle", "rb") as f:
            loaded = pickle.load(f)
        col = colors.get(netID[3:6], 'gray')  # unknown types are drawn in gray

        act = loaded[0][0][0][layer]
        IMs = act.shape[0]
        act = np.reshape(act, [IMs, -1])

        # Up to 850 PCs for layers 0-2 and 128 for later layers; PCA cannot
        # return more components than min(n_samples, n_features).
        PCs = min(850 if layer < 3 else 128, *act.shape)
        cum_ve, pr, ve85 = _pca_stats(act, PCs)
        PR.append(pr)
        VE85.append(ve85)
        plt.plot(np.arange(1, PCs + 1), cum_ve, color=col)
    plt.xlabel('PCs'); plt.ylabel('Variance explained'); plt.ylim([0, 1.05])

    plt.savefig(fig_dir + 'PCA_' + str(layer) + name); plt.close()

    print(layer, PR, VE85)
    x = np.arange(len(net_list))
    ref_x = [-.1, len(net_list) - .9]
    # Dotted reference levels for layers 0 and 3 (hard-coded values whose
    # source is not recorded in this file).
    pr_ref = {0: 9.1, 3: 14.7}
    ve85_ref = {0: 78, 3: 24}

    plt.subplot(2, 1, 1)
    plt.bar(x, PR, color='gray')
    plt.xticks([], []); plt.ylim([0, 64])
    if layer in pr_ref:
        plt.plot(ref_x, [pr_ref[layer]] * 2, 'k:')
    plt.subplot(2, 1, 2)
    plt.bar(x, VE85, color='gray'); plt.xticks(x, net_list, rotation=90)
    if layer in ve85_ref:
        plt.plot(ref_x, [ve85_ref[layer]] * 2, 'k:')
    plt.savefig(fig_dir + 'PCA_PR85VE_' + str(layer) + name); plt.close()


def run_PCA(nets, name=''):
    """Produce the PCA figures; only layer 3 is currently enabled."""
    print('pca')
    selected_layers = [3]  # use range(4) to cover every layer
    for layer_idx in selected_layers:
        PCAplots(nets, layer_idx, name=name)




def get_acc(act,labels,label_type,test_folds = 20,test_perc=.2):
    """Cross-validated linear decoding performance of ``labels`` from ``act``.

    Each fold draws a random permutation of the samples (seeded with
    ``100 * fold`` so the splits are reproducible).  It holds out the first
    ``int(n_samples * test_perc)`` permuted samples for testing, trains on the
    rest, and records the estimator's held-out ``score``.

    Args:
        act: 2-D array, samples x features.
        labels: targets.  For 'reward' and 'task' they are indexed per sample.
            For 'zaxis' they form a 2-D array whose columns 0, 1 and 2 are
            decoded separately.
        label_type: 'reward' (LinearRegression), 'task' (LogisticRegression)
            or 'zaxis' (one LinearRegression per column).
        test_folds: number of random train/test splits.
        test_perc: fraction of samples held out for testing.

    Returns:
        1-D array of scores: one per fold, or for 'zaxis' three per fold
        (columns 0, 1, 2 in that order).

    Raises:
        ValueError: if ``label_type`` is not one of the supported values.
    """
    if label_type not in ('reward', 'task', 'zaxis'):
        raise ValueError("label_type must be 'reward', 'task' or 'zaxis', got %r" % (label_type,))
    IMs = np.shape(act)[0]
    test_size = int(IMs*test_perc)
    perfs = []
    for tf in range(test_folds):
        # A private RandomState seeded with 100*tf draws exactly the same
        # permutation as np.random.seed(100*tf) followed by np.random.choice,
        # but without resetting NumPy's global RNG state for the caller.
        rand_inds = np.random.RandomState(100*tf).choice(IMs, IMs, replace=False)
        test_inds = rand_inds[:test_size]
        train_inds = rand_inds[test_size:]
        if label_type == 'reward':
            problems = [(skLM.LinearRegression(), labels)]
        elif label_type == 'task':
            problems = [(skLM.LogisticRegression(), labels)]
        else:  # 'zaxis': decode columns 0, 1 and 2 independently
            problems = [(skLM.LinearRegression(), labels[:, col]) for col in range(3)]
        for clf, target in problems:
            clf.fit(act[train_inds, :], target[train_inds])
            perfs.append(clf.score(act[test_inds, :], target[test_inds]))

    return np.array(perfs)


def classify_perf(net_list, layer=3):
    """Decode orientation (z-axis), reward and task from one layer's activity.

    For every network, the layer's activity is flattened to images x units and
    passed to ``get_acc`` with three targets read from the activity pickle:
    ``loaded[2]`` ('reward'), ``loaded[1]`` ('task') and ``loaded[4]``
    ('zaxis').  The mean +/- std of the per-fold scores is drawn as three
    stacked bar charts and saved as ``Classification.png`` in ``fig_dir``.

    Args:
        net_list: network IDs with activity pickles in ``act_dir``.
        layer: layer index, default 3.  Per the original note, layers other
            than 3 should be PCA-reduced before decoding.
    """
    reward_perf = []
    task_perf = []
    zax_perf = []
    for netID in net_list:
        with open(act_dir + netID + "_Act.pickle", "rb") as f:
            loaded = pickle.load(f)
        rew = np.array(loaded[2]).flatten()
        task = np.array(loaded[1]).flatten()
        zax = np.array(loaded[4]); print(zax.shape)
        act = loaded[0][0][0][layer]  # ims x X x Y x F
        IMs = act.shape[0]
        print(act.shape)
        act = np.reshape(act, [IMs, -1])

        reward_perf.append(get_acc(act, rew, 'reward'))
        task_perf.append(get_acc(act, task, 'task'))
        zax_perf.append(get_acc(act, zax, 'zaxis'))

    x = np.arange(len(net_list))
    panels = [('Orientation Prediction', zax_perf),
              ('Reward Prediction', reward_perf),
              ('Task Classification', task_perf)]
    for i, (title, perf) in enumerate(panels):
        perf = np.array(perf)  # networks x scores
        plt.subplot(len(panels), 1, i + 1)
        plt.bar(x, np.mean(perf, 1), color='k')
        plt.errorbar(x, np.mean(perf, 1), np.std(perf, 1), ls='', color='k')
        plt.ylim([-.1, 1.])
        plt.title(title)
        # Only the bottom panel carries the network names.
        if i == len(panels) - 1:
            plt.xticks(x, net_list, rotation=90)
        else:
            plt.xticks([])
    plt.tight_layout()
    plt.savefig(fig_dir + 'Classification.png'); plt.close()

def lifetime_sparse(act):
    """Lifetime sparseness of every unit (column) of ``act``.

    Computes, per column with N = number of rows (stimuli),
        S = (1 - (sum r)^2 / (N * sum r^2)) / (1 - 1/N)
    (the Vinje & Gallant 2000 form).  S is 0 for a unit that responds equally
    to every stimulus and 1 for a unit that responds to only one.  Columns
    whose squared activity sums to zero (silent units), or that otherwise give
    NaN, are dropped.

    Args:
        act: 2-D array-like, stimuli x units.  Integer input is accepted.

    Returns:
        1-D array with one sparseness value per kept unit.
    """
    act = np.asarray(act)
    if not np.issubdtype(act.dtype, np.floating):
        # Integer/bool sums can't hold the NaN used below to mask silent units.
        act = act.astype(float)
    N = act.shape[0]
    sumsq = np.sum(act, 0)**2
    sqsum = np.sum(act**2, 0)
    sqsum[sqsum == 0] = np.nan
    sprs = (1-(1/N)*(sumsq/sqsum))/(1-(1/N))

    return sprs[~np.isnan(sprs)]


def sparse_violins(net_list,layer,name=''):
    """Violin plot of per-unit lifetime sparseness, one violin per network.

    Each network's ``layer`` activity is flattened to images x units and
    scored with ``lifetime_sparse``.  The figure is saved as
    ``Sparseness_<layer><name>`` in ``fig_dir``.  X-tick labels are the first
    six characters of each network ID.
    """
    sparseness = []
    for netID in net_list:
        with open(act_dir + netID + "_Act.pickle", "rb") as f:
            loaded = pickle.load(f)
        act1 = loaded[0][0][0][layer]
        act1 = np.reshape(act1, [act1.shape[0], -1])  # images x units
        sparseness.append(lifetime_sparse(act1))
    plt.violinplot(sparseness, showmedians=True)
    # violinplot places the k-th dataset at x = k + 1 by default.
    net_names = [netID[:6] for netID in net_list]
    plt.xticks(np.arange(1, len(net_list) + 1), net_names, rotation=90)
    plt.ylim([-.1, 1.1])
    plt.savefig(fig_dir + 'Sparseness_' + str(layer) + name); plt.close()

def sparseness_plots(nets, name=''):
    """Save one sparseness violin figure per layer (layers 0-3) for ``nets``."""
    for layer_idx in (0, 1, 2, 3):
        sparse_violins(nets, layer_idx, name)



def pixelaction_comp(net_list):
    """For each network, correlate RDMs of pixels, layers 0-3 and actions.

    Builds ``1 - corrcoef`` image-by-image dissimilarity matrices from:
      * the raw pixels (``loaded[5]``, scaled by 1/255),
      * each of the four layers in ``loaded[0][0][0]``,
      * the actions (``loaded[3]``).
    Their upper triangles are Pearson-correlated into a 6 x 6 matrix, saved as
    ``RSAcr_<netID>`` (an .npy file in ``data_dir`` and a figure in
    ``fig_dir``).  The pixel and action RDMs are also plotted as
    ``TaskDissim_Pixel_hl`` / ``TaskDissim_Action_hl``.  Those names don't
    include the network ID, so each network overwrites them.
    """
    rdm_labels = ['Pixels', 'L0', 'L1', 'L2', 'L3', 'Actions']
    for netID in net_list:
        with open(act_dir + netID + "_Act.pickle", "rb") as f:
            loaded = pickle.load(f)
        actions = np.array(loaded[3])
        IMs = actions.shape[0]
        upper = np.triu_indices(IMs, 1)
        # One row of flattened, [0, 1]-scaled pixel values per image.
        pixels = np.reshape(np.array(loaded[5]), [IMs, -1]) / 255.

        all_dissims = []
        dissim = 1 - np.corrcoef(pixels)
        all_dissims.append(dissim[upper])
        plt.imshow(dissim, vmin=0, vmax=1.5); plt.colorbar(); plt.xticks([]); plt.yticks([])
        plt.savefig(fig_dir + 'TaskDissim_' + 'Pixel_hl'); plt.close()
        for layer in range(4):
            act1 = np.reshape(loaded[0][0][0][layer], [IMs, -1])
            all_dissims.append((1 - np.corrcoef(act1))[upper])
        dissim = 1 - np.corrcoef(actions)
        all_dissims.append(dissim[upper])
        plt.imshow(dissim, vmin=0, vmax=1.5); plt.colorbar(); plt.xticks([]); plt.yticks([])
        plt.savefig(fig_dir + 'TaskDissim_' + 'Action_hl'); plt.close()

        rsa_mat = np.corrcoef(np.array(all_dissims))
        plt.imshow(rsa_mat, vmin=-.1, vmax=1); plt.colorbar(); plt.title('RSAcr ' + netID)
        plt.xticks(np.arange(len(rdm_labels)), rdm_labels, rotation=90)
        plt.yticks(np.arange(len(rdm_labels)), rdm_labels)
        np.save(data_dir + 'RSAcr_' + netID, rsa_mat)
        plt.savefig(fig_dir + 'RSAcr_' + netID); plt.close()


         




