import argparse
import random
import numpy as np
import os


# Command-line configuration for this analysis script. Several of these options
# (save_loc, save_string and the run settings) are combined into the names of
# the saved result files that the script loads.
parser = argparse.ArgumentParser(description='NAS Without Training')

# Input / output locations.
parser.add_argument('--data_loc', type=str, help='dataset folder',
                    default='../fishersearch_randomwirenetworks/cifardata/')
parser.add_argument('--api_loc', type=str, help='path to API',
                    default='../fimflam/NAS-Bench-201-v1_0-e61699.pth')
parser.add_argument('--save_loc', type=str, help='folder to save results',
                    default='results')
parser.add_argument('--save_string', type=str, help='prefix of results file',
                    default='naswot')

# Search space and scoring-run settings.
parser.add_argument('--nasspace', type=str, help='the nas search space to use',
                    default='nasbench201')
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--repeat', type=int, default=256,
                    help='how often to repeat a single image with a batch')
parser.add_argument('--augtype', type=str, default='gaussnoise',
                    help='which perturbations to use')
parser.add_argument('--sigma', type=float, default=0.01,
                    help='noise level if augtype is "gaussnoise"')
parser.add_argument('--GPU', type=str, default='0')
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--activations', action='store_true')
parser.add_argument('--dataset', type=str, default='cifar10')
parser.add_argument('--n_samples', type=int, default=100)
parser.add_argument('--n_runs', type=int, default=500)

# NAS-Bench-101 network-construction options.
parser.add_argument('--stem_out_channels', type=int, default=16,
                    help='output channels of stem convolution (nasbench101)')
parser.add_argument('--num_stacks', type=int, default=3,
                    help='#stacks of modules (nasbench101)')
parser.add_argument('--num_modules_per_stack', type=int, default=3,
                    help='#modules per stack (nasbench101)')
parser.add_argument('--num_labels', type=int, default=1,
                    help='#classes (nasbench101)')

# Boolean switch; when set it appends "_cosine" to the results file names.
parser.add_argument('--cosine', action='store_true')

args = parser.parse_args()



# Load the correlation matrix and accuracies saved by a scoring run, then
# simulate N small searches. Each search draws M architectures at random and
# keeps the one with the highest "interval" score. The script prints the mean
# and a mean +/- one standard deviation band of the kept accuracies. It prints
# the same statistics for a purely random pick from the same draws, as a
# baseline.

# Both result files share every run-configuration component of their name.
run_tag = (
    f'{args.activations}_{args.nasspace}_{args.dataset}_{args.augtype}'
    f'_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.seed}'
    f'{"_cosine" if args.cosine else ""}'
)
filename = f'{args.save_loc}/{args.save_string}_correlationmatrix_{run_tag}.npy'
accsfilename = f'{args.save_loc}/{args.save_string}_correlationmatrixaccs_{run_tag}.npy'

corrs = np.load(filename)
accs = np.load(accsfilename)

N = args.n_runs  # number of simulated searches (--n_runs, default 500)
M = 10  # architectures drawn per simulated search

# Row i of `corrs` is scored and paired with accs[i]. Presumably there is one
# row per architecture -- TODO confirm against the script that writes these
# files. Fail early with a readable message instead of an IndexError or a
# "sample larger than population" error deep inside the loop.
if corrs.shape[0] < accs.size:
    raise ValueError(
        f'{filename} has {corrs.shape[0]} rows but {accsfilename} holds '
        f'{accs.size} accuracies'
    )
if accs.size < M:
    raise ValueError(
        f'need at least {M} architectures to sample from, found {accs.size}'
    )

# Interval score: per row, the number of entries strictly inside (0, 0.25).
scores = np.logical_and(corrs > 0, corrs < 0.25).sum(axis=1).astype(np.float64)

# Private RNG seeded from --seed, so repeated invocations report the same
# numbers. The global `random` state is left untouched.
rng = random.Random(args.seed)

score_found = []
rand_found = []
for _ in range(N):
    sample = np.array(rng.sample(range(accs.size), k=M))
    # np.argmax returns the first maximum, so ties go to the earliest drawn.
    score_found.append(accs[sample[np.argmax(scores[sample])]])
    # The draw is in random order, so its first element is a uniform random pick.
    rand_found.append(accs[sample[0]])

rand_mean = np.mean(rand_found)
rand_std = np.std(rand_found)
score_mean = np.mean(score_found)
score_std = np.std(score_found)
# Report each mean together with a mean +/- one (population) std band.
print(f'Random: {rand_mean} [{rand_mean-rand_std}, {rand_mean+rand_std}]')
print(f'Interval score: {score_mean} [{score_mean-score_std}, {score_mean+score_std}]')
