import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
from sklearn.metrics.pairwise import cosine_similarity


##preds = np.load('results/naswot_testpredictions_nasbench101_cifar10_gaussnoise_0.001_256_True_256_1.npy')
##accs = np.load('results/naswot_testaccs_nasbench101_cifar10_gaussnoise_0.001_256_True_256_1.npy')
#corrs = np.load('results/naswot_correlationmatrix_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#accs = np.load('results/naswot_correlationmatrixaccs_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#
#score1a = np.logical_and(corrs < 0.25, corrs > 0.).sum(axis=1)
#score2a = np.logical_and(corrs < 0.5, corrs > 0.).sum(axis=1)
#score3a = np.logical_and(corrs < 0.75, corrs > 0.).sum(axis=1)
#score4a = np.logical_and(corrs < 1., corrs > 0.).sum(axis=1)
#
#
#score1b = np.logical_and(corrs < 0.25, corrs > 0.25).sum(axis=1)
#score2b = np.logical_and(corrs < 0.5, corrs > 0.25).sum(axis=1)
#score3b = np.logical_and(corrs < 0.75, corrs > 0.25).sum(axis=1)
#score4b = np.logical_and(corrs < 1., corrs > 0.25).sum(axis=1)
#
#
#score1c = np.logical_and(corrs < 0.25, corrs > 0.5).sum(axis=1)
#score2c = np.logical_and(corrs < 0.5, corrs > 0.5).sum(axis=1)
#score3c = np.logical_and(corrs < 0.75, corrs > 0.5).sum(axis=1)
#score4c = np.logical_and(corrs < 1., corrs > 0.5).sum(axis=1)
#
#
#score1d = np.logical_and(corrs < 0.25, corrs > 0.75).sum(axis=1)
#score2d = np.logical_and(corrs < 0.5, corrs > 0.75).sum(axis=1)
#score3d = np.logical_and(corrs < 0.75, corrs > 0.75).sum(axis=1)
#score4d = np.logical_and(corrs < 1., corrs > 0.75).sum(axis=1)
#
#
##plt.scatter(accs, preds)
#fig, axes = plt.subplots(4, 4)
#
#for ax in axes.flatten():
#    ax.set_xlabel('test accuracy')
#axes[0, 0].scatter(accs, score1a)
#axes[0, 0].set_ylabel('sumcorr [0, 0.25]')
#axes[0, 1].scatter(accs, score2a)
#axes[0, 1].set_ylabel('sumcorr [0, 0.5]')
#axes[0, 2].scatter(accs, score3a)
#axes[0, 2].set_ylabel('sumcorr [0, 0.75]')
#axes[0, 3].scatter(accs, score4a)
#axes[0, 3].set_ylabel('sumcorr [0, 1.0]')
#
##axes[1, 0].scatter(accs, score1b)
#axes[1, 0].axis('off')
#axes[1, 1].scatter(accs, score2b)
#axes[1, 1].set_ylabel('sumcorr [0.25, 0.5]')
#axes[1, 2].scatter(accs, score3b)
#axes[1, 2].set_ylabel('sumcorr [0.25, 0.75]')
#axes[1, 3].scatter(accs, score4b)
#axes[1, 3].set_ylabel('sumcorr [0.25, 1.0]')
#axes[2, 0].axis('off')
#axes[2, 1].axis('off')
##axes[2, 0].scatter(accs, score1c)
##axes[2, 1].scatter(accs, score2c)
#axes[2, 2].scatter(accs, score3c)
#axes[2, 2].set_ylabel('sumcorr [0.5, 0.75]')
#axes[2, 3].scatter(accs, score4c)
#axes[2, 3].set_ylabel('sumcorr [0.5, 1.0]')
#axes[3, 0].axis('off')
#axes[3, 1].axis('off')
#axes[3, 2].axis('off')
##axes[3, 0].scatter(accs, score1d)
##axes[3, 1].scatter(accs, score2d)
##axes[3, 2].scatter(accs, score3d)
#axes[3, 3].scatter(accs, score4d)
#axes[3, 3].set_ylabel('sumcorr [0.75, 1.0]')




#preds = np.load('results/naswot_trainpredictions_nasbench101_cifar10_gaussnoise_0.001_256_True_256_1.npy')
#accs = np.load('results/naswot_trainaccs_nasbench101_cifar10_gaussnoise_0.001_256_True_256_1.npy')
#plt.scatter(accs, preds)


#outs = np.load('results/naswot_correlationmatrix_True_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#accs = np.load('results/naswot_correlationmatrixaccs_True_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')

# Each row of `outs` is unpacked below as the strictly-lower triangle
# (np.tril_indices(n, -1) order, diagonal excluded) of one architecture's
# pairwise similarity matrix; `accs[i]` is paired with row i.
# NOTE(review): layout inferred from how rows are unpacked here -- confirm
# against the script that produced these .npy files.
outs = np.load('results/naswot_correlationmatrix_False_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1_cosine.npy')
accs = np.load('results/naswot_correlationmatrixaccs_False_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1_cosine.npy')

# Alternative inputs kept for reference:
#outs = np.load('results/naswot_correlationmatrix_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#accs = np.load('results/naswot_correlationmatrixaccs_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#outs = outs.reshape(-1, 256, 64)
# (with raw activations one would instead build the matrix per row via
#  np.corrcoef(outs[i]) or cosine_similarity(outs[i]))

print(outs.shape)

MAX_ARCHS = 1000   # only the first MAX_ARCHS rows are scored
PRINT_EVERY = 100  # progress-report interval; the old `% 1000` fired only at 0

# Recover the matrix size n from the packed length m = n * (n - 1) / 2
# (m = 32640 gives n = 256) instead of hard-coding 256.
n_packed = outs.shape[1]
n = int(round((1 + np.sqrt(1 + 8 * n_packed)) / 2))
if n * (n - 1) // 2 != n_packed:
    raise ValueError(
        f'row length {n_packed} is not n*(n-1)/2 for any integer n'
    )
# Loop-invariant: index pairs of the strictly-lower triangle.
rows, cols = np.tril_indices(n, -1)

n_archs = min(MAX_ARCHS, outs.shape[0])
scores = []
min_eigs = []
max_eigs = []
for i in range(n_archs):
    # Rebuild the full symmetric matrix: unit diagonal, packed values
    # mirrored into both triangles.
    corrs = np.eye(n)
    corrs[rows, cols] = outs[i]
    corrs[cols, rows] = outs[i]

    # eigvalsh is the symmetric solver: real eigenvalues, ascending order
    # (eigvals could return complex values with spurious imaginary parts).
    eigs = np.linalg.eigvalsh(corrs)
    min_eigs.append(eigs[0])
    max_eigs.append(eigs[-1])

    # Score = log|det|. The sign is discarded, as before; for a positive
    # semi-definite similarity matrix det >= 0 anyway (a zero determinant
    # gives -inf).
    _, logdet = np.linalg.slogdet(corrs)
    scores.append(logdet)
    if i % PRINT_EVERY == 0:
        print(i)

# Smallest (blue) and largest (red) eigenvalue vs. test accuracy, drawn in
# one call per series instead of two scatter artists per architecture.
fig2, ax2 = plt.subplots(1, 1)
ax2.scatter(min_eigs, accs[:n_archs], c='b')
ax2.scatter(max_eigs, accs[:n_archs], c='r')
ax2.set_xlabel('eigenvalue (blue: min, red: max)')
ax2.set_ylabel('Test accuracy')
# Kendall rank correlation between test accuracy and the logdet score,
# computed over ALL scored architectures (before the accuracy filter below).
kc, _ = stats.kendalltau(accs[:len(scores)], scores)
fig, ax = plt.subplots(1, 1)
#ax.set_ylabel('logdet(J)')
ax.set_ylabel('logdet(activation correlation)')
ax.set_xlabel('Test accuracy')

scores = np.array(scores)
accs = accs[:len(scores)]
# Only architectures with accuracy > 0.5 are plotted; note the tau shown in
# the annotation is NOT restricted to this subset.
inds = accs > 0.5
scores = scores[inds]
accs = accs[inds]
ax.scatter(accs, scores, c='b', alpha=0.05)
# Bug fix: the statistic is stored in `kc`; `tau` was never defined and
# raised NameError before the figure could be shown.
ax.text(0.1, 0.9, f'kendall-tau: {kc:.3f}', ha='center', va='center', transform=ax.transAxes)

plt.show()
