import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
import random


##preds = np.load('results/naswot_testpredictions_nasbench101_cifar10_gaussnoise_0.001_256_True_256_1.npy')
##accs = np.load('results/naswot_testaccs_nasbench101_cifar10_gaussnoise_0.001_256_True_256_1.npy')
#corrs = np.load('results/naswot_correlationmatrix_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#accs = np.load('results/naswot_correlationmatrixaccs_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#
#score1a = np.logical_and(corrs < 0.25, corrs > 0.).sum(axis=1)
#score2a = np.logical_and(corrs < 0.5, corrs > 0.).sum(axis=1)
#score3a = np.logical_and(corrs < 0.75, corrs > 0.).sum(axis=1)
#score4a = np.logical_and(corrs < 1., corrs > 0.).sum(axis=1)
#
#
#score1b = np.logical_and(corrs < 0.25, corrs > 0.25).sum(axis=1)
#score2b = np.logical_and(corrs < 0.5, corrs > 0.25).sum(axis=1)
#score3b = np.logical_and(corrs < 0.75, corrs > 0.25).sum(axis=1)
#score4b = np.logical_and(corrs < 1., corrs > 0.25).sum(axis=1)
#
#
#score1c = np.logical_and(corrs < 0.25, corrs > 0.5).sum(axis=1)
#score2c = np.logical_and(corrs < 0.5, corrs > 0.5).sum(axis=1)
#score3c = np.logical_and(corrs < 0.75, corrs > 0.5).sum(axis=1)
#score4c = np.logical_and(corrs < 1., corrs > 0.5).sum(axis=1)
#
#
#score1d = np.logical_and(corrs < 0.25, corrs > 0.75).sum(axis=1)
#score2d = np.logical_and(corrs < 0.5, corrs > 0.75).sum(axis=1)
#score3d = np.logical_and(corrs < 0.75, corrs > 0.75).sum(axis=1)
#score4d = np.logical_and(corrs < 1., corrs > 0.75).sum(axis=1)
#
#
##plt.scatter(accs, preds)
#fig, axes = plt.subplots(4, 4)
#
#for ax in axes.flatten():
#    ax.set_xlabel('test accuracy')
#axes[0, 0].scatter(accs, score1a)
#axes[0, 0].set_ylabel('sumcorr [0, 0.25]')
#axes[0, 1].scatter(accs, score2a)
#axes[0, 1].set_ylabel('sumcorr [0, 0.5]')
#axes[0, 2].scatter(accs, score3a)
#axes[0, 2].set_ylabel('sumcorr [0, 0.75]')
#axes[0, 3].scatter(accs, score4a)
#axes[0, 3].set_ylabel('sumcorr [0, 1.0]')
#
##axes[1, 0].scatter(accs, score1b)
#axes[1, 0].axis('off')
#axes[1, 1].scatter(accs, score2b)
#axes[1, 1].set_ylabel('sumcorr [0.25, 0.5]')
#axes[1, 2].scatter(accs, score3b)
#axes[1, 2].set_ylabel('sumcorr [0.25, 0.75]')
#axes[1, 3].scatter(accs, score4b)
#axes[1, 3].set_ylabel('sumcorr [0.25, 1.0]')
#axes[2, 0].axis('off')
#axes[2, 1].axis('off')
##axes[2, 0].scatter(accs, score1c)
##axes[2, 1].scatter(accs, score2c)
#axes[2, 2].scatter(accs, score3c)
#axes[2, 2].set_ylabel('sumcorr [0.5, 0.75]')
#axes[2, 3].scatter(accs, score4c)
#axes[2, 3].set_ylabel('sumcorr [0.5, 1.0]')
#axes[3, 0].axis('off')
#axes[3, 1].axis('off')
#axes[3, 2].axis('off')
##axes[3, 0].scatter(accs, score1d)
##axes[3, 1].scatter(accs, score2d)
##axes[3, 2].scatter(accs, score3d)
#axes[3, 3].scatter(accs, score4d)
#axes[3, 3].set_ylabel('sumcorr [0.75, 1.0]')



def _eigenspectra(rows, limit):
    """Return the ascending eigenvalue spectrum of each matrix encoded in *rows*.

    Every row of the 2-D array *rows* holds the strictly-lower triangle, in
    ``np.tril_indices(n, -1)`` order, of an ``n x n`` symmetric matrix whose
    diagonal is 1. ``n`` is recovered from the row length ``n*(n-1)/2``.
    Only the first *limit* rows are processed. The row index is printed every
    1000 rows as a progress indicator.

    Rows containing non-finite values, or whose eigen-decomposition fails,
    produce a row of NaNs so that callers can filter them out.

    Returns:
        Array of shape ``(min(limit, len(rows)), n)``.

    Raises:
        ValueError: if the row length is not of the form ``n*(n-1)/2``.
    """
    rows = np.asarray(rows)
    n_tril = rows.shape[1]
    # Invert n_tril = n*(n-1)/2 to recover the matrix size n.
    n = int(round((1. + np.sqrt(1. + 8. * n_tril)) / 2.))
    if n * (n - 1) // 2 != n_tril:
        raise ValueError(f'row length {n_tril} is not of the form n*(n-1)/2')
    tril = np.tril_indices(n, -1)

    spectra = []
    for i in range(min(limit, rows.shape[0])):
        row = rows[i]
        if np.isfinite(row).all():
            corrs = np.eye(n)
            corrs[tril] = row
            corrs.T[tril] = row  # mirror into the upper triangle
            try:
                # The matrix is symmetric, so eigvalsh gives a real spectrum,
                # already in ascending order. The general solver eigvals
                # is not guaranteed to return a real array.
                spectrum = np.linalg.eigvalsh(corrs)
            except np.linalg.LinAlgError:
                spectrum = np.full(n, np.nan)
        else:
            spectrum = np.full(n, np.nan)
        spectra.append(spectrum)
        if i % 1000 == 0:
            print(i)
    # reshape keeps the result 2-D even when no rows were processed.
    return np.array(spectra).reshape(-1, n)


def plot_jacobcorr_prediction(name, out101, accs101, outs201, accs201, ax, testsplit='both'):
    """Fit a random forest from eigenvalue spectra to test accuracy and plot it.

    Each row of ``out101`` / ``outs201`` holds the strictly-lower triangle
    (``np.tril_indices(n, -1)`` order) of one architecture's correlation
    matrix. The full symmetric matrix is rebuilt with a unit diagonal, and its
    ascending eigenvalues, divided by the smallest one, are the features of a
    ``RandomForestRegressor(min_samples_split=50)``.

    Three prediction sets are scattered on *ax*:
    * the training set itself ("Train");
    * a random half of the training data held out before fitting ("Test");
    * the *testsplit* test set ("Test <testsplit>").

    Every fitted tree is exported to ``trees/tree_{name}_{g}.dot``.

    Parameters:
        name: Tag used in the exported tree file names.
        out101: 2-D array, one row per NAS-Bench-101 architecture. At most
            20000 rows are used.
        accs101: Accuracies aligned with ``out101``. They are multiplied by 100
            (presumably fraction -> percent; TODO confirm), and only
            architectures whose scaled accuracy exceeds 50 are kept.
        outs201: 2-D array, one row per NAS-Bench-201 architecture. At most
            10000 rows are used.
        accs201: Accuracies aligned with ``outs201``, used unscaled.
        ax: matplotlib Axes to draw on.
        testsplit: One of the following:
            * ``'both'``: pool both benchmarks and split 50/50 at random,
              using the full spectrum.
            * ``'nasbench101'``: train on 201, test on 101.
            * ``'nasbench201'``: train on 101, test on 201.
            The two cross-benchmark modes keep only eigenvalues 0, 50, 100, ...

    Raises:
        ValueError: if *testsplit* is not one of the values above.
    """
    # Validate before the expensive eigen-decompositions.
    if testsplit not in ('both', 'nasbench101', 'nasbench201'):
        raise ValueError(
            f"testsplit must be 'both', 'nasbench101' or 'nasbench201', got {testsplit!r}")

    import os
    from sklearn.tree import export_graphviz

    # NAS-Bench-101 is sampled with twice the NAS-Bench-201 limit.
    max_archs = 10000
    scores101 = _eigenspectra(out101, 2 * max_archs)
    scores201 = _eigenspectra(outs201, max_archs)

    ax.set_ylabel('Prediction')
    ax.set_xlabel('Test accuracy')

    accs101 = 100. * np.asarray(accs101)[:len(scores101)]
    keep = accs101 > 50.
    scores101, accs101 = scores101[keep], accs101[keep]
    accs201 = np.asarray(accs201)[:len(scores201)]

    if testsplit == 'both':
        scores = np.concatenate([scores201, scores101])
        print(scores.shape)
        accs = np.concatenate([accs201, accs101])
        valid = ~np.isnan(scores).any(axis=1)
        scores, accs = scores[valid], accs[valid]
        order = list(range(accs.size))
        random.shuffle(order)
        print(order[:10])
        half = len(order) // 2
        train_scores, train_accs = scores[order[:half]], accs[order[:half]]
        test_scores, test_accs = scores[order[half:]], accs[order[half:]]
    else:
        # Drop architectures whose spectrum could not be computed (NaN rows).
        valid101 = ~np.isnan(scores101).any(axis=1)
        scores101, accs101 = scores101[valid101], accs101[valid101]
        valid201 = ~np.isnan(scores201).any(axis=1)
        scores201, accs201 = scores201[valid201], accs201[valid201]
        # Keep a sparse subset of the spectrum: eigenvalues 0, 50, 100, ...
        cols = list(range(0, scores201.shape[1], 50))
        scores101, scores201 = scores101[:, cols], scores201[:, cols]
        if testsplit == 'nasbench101':
            train_scores, train_accs = scores201, accs201
            test_scores, test_accs = scores101, accs101
        else:
            train_scores, train_accs = scores101, accs101
            test_scores, test_accs = scores201, accs201

    # Divide every spectrum by its smallest eigenvalue (column 0, since the
    # spectra are ascending), then drop that column.
    # NOTE(review): a near-zero smallest eigenvalue makes these ratios explode.
    train_scores = train_scores[:, 1:] / train_scores[:, :1]
    test_scores = test_scores[:, 1:] / test_scores[:, :1]

    # Hold out a random half of the training data as an in-distribution test
    # set. A boolean mask is required here: ``~`` on an integer index array
    # is bitwise NOT (-i - 1), not a set complement, and would select rows
    # overlapping the held-out half.
    n_train = train_scores.shape[0]
    held_out = np.zeros(n_train, dtype=bool)
    held_out[random.sample(range(n_train), k=n_train // 2)] = True
    test_scores2, test_accs2 = train_scores[held_out], train_accs[held_out]
    train_scores, train_accs = train_scores[~held_out], train_accs[~held_out]

    model = RandomForestRegressor(min_samples_split=50)
    model.fit(train_scores, train_accs)

    # Export every tree. The feature names follow the actual number of
    # features ('both' keeps the full spectrum, the other modes keep 5 ratios).
    os.makedirs('trees', exist_ok=True)
    feature_names = [f'eig{f}' for f in range(train_scores.shape[1])]
    for g, estimator in enumerate(model.estimators_):
        export_graphviz(estimator, out_file=f'trees/tree_{name}_{g}.dot',
                        feature_names=feature_names,
                        rounded=True, proportion=False,
                        precision=2, filled=True)

    train_preds = model.predict(train_scores)
    test_preds = model.predict(test_scores)
    test_preds2 = model.predict(test_scores2)

    ax.scatter(train_accs, train_preds, label='Train')
    ax.scatter(test_accs2, test_preds2, label='Test')
    ax.scatter(test_accs, test_preds, label=f'Test {testsplit}')
    ax.legend()


#preds = np.load('results/naswot_trainpredictions_nasbench101_cifar10_gaussnoise_0.001_256_True_256_1.npy')
#accs = np.load('results/naswot_trainaccs_nasbench101_cifar10_gaussnoise_0.001_256_True_256_1.npy')
#plt.scatter(accs, preds)


#outs101 = np.load('results/naswot_correlationmatrix_True_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#accs101 = np.load('results/naswot_correlationmatrixaccs_True_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#outs101 = outs101.reshape(-1, 256, 64)
#
#
#outs201 = np.load('results/naswot_correlationmatrix_True_nasbench201_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#accs201 = np.load('results/naswot_correlationmatrixaccs_True_nasbench201_cifar10_gaussnoise_0.01_256_True_256_1.npy')
#outs201 = outs201.reshape(-1, 256, 64)

fig, axes = plt.subplots(1, 2)


# Disabled variant: the same regression on Jacobian correlations computed
# from a standard minibatch, drawn on the right-hand axis (axes[1]).
#outs101 = np.load('results/naswot_correlationmatrix_nasbench101_cifar10_gaussnoise_0.01_1_True_256_1.npy')
#accs101 = np.load('results/naswot_correlationmatrixaccs_nasbench101_cifar10_gaussnoise_0.01_1_True_256_1.npy')
#outs201 = np.load('results/naswot_correlationmatrix_False_nasbench201_cifar10_gaussnoise_0.01_1_True_256_1.npy')
#accs201 = np.load('results/naswot_correlationmatrixaccs_False_nasbench201_cifar10_gaussnoise_0.01_1_True_256_1.npy')
#ax = axes[1]
#ax.set_title('RandomForest regression using\nscaled eigenvalues of\nJacobian correlation matrix\nwith standard minibatch')
#plot_jacobcorr_prediction('nasbench101train201test_normalbatch', outs101, accs101, outs201, accs201, ax, testsplit='nasbench201')

# Active variant ("single image + noise" data): fit on NAS-Bench-101 and
# evaluate on NAS-Bench-201 (testsplit='nasbench201'), drawn on axes[0].
outs101 = np.load('results/naswot_correlationmatrix_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
accs101 = np.load('results/naswot_correlationmatrixaccs_nasbench101_cifar10_gaussnoise_0.01_256_True_256_1.npy')
outs201 = np.load('results/naswot_correlationmatrix_False_nasbench201_cifar10_gaussnoise_0.01_256_True_256_1.npy')
accs201 = np.load('results/naswot_correlationmatrixaccs_False_nasbench201_cifar10_gaussnoise_0.01_256_True_256_1.npy')

ax = axes[0]
ax.set_title('RandomForest regression using\nscaled eigenvalues of\nJacobian correlation matrix\nwith single image + noise')
plot_jacobcorr_prediction(
    'nasbench101train201test_noisebatch',
    outs101, accs101,
    outs201, accs201,
    ax,
    testsplit='nasbench201',
)


plt.show()
