import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mp
import matplotlib
from decimal import Decimal
from scipy.special import logit, expit
from scipy import stats

# Global matplotlib font sizes applied to every figure this script draws.
# (SMALL_SIZE is not referenced anywhere below.)
SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 14

plt.rc('font', size=MEDIUM_SIZE)          # default text size
plt.rc('axes', titlesize=BIGGER_SIZE)     # axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)     # x and y axis labels
plt.rc('xtick', labelsize=MEDIUM_SIZE)    # x tick labels
plt.rc('ytick', labelsize=MEDIUM_SIZE)    # y tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)    # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)   # figure title

# Command-line options. Each entry is (flag, keyword arguments forwarded
# verbatim to ArgumentParser.add_argument); options are registered in the
# order listed. Several options (e.g. the nasbench101 model settings) are not
# read by this script itself.
_CLI_OPTIONS = [
    ('--data_loc', dict(default='../fishersearch_randomwirenetworks/cifardata/', type=str, help='dataset folder')),
    ('--api_loc', dict(default='../fimflam/NAS-Bench-201-v1_0-e61699.pth', type=str, help='path to API')),
    ('--save_loc', dict(default='results', type=str, help='folder to save results')),
    ('--save_string', dict(default='naswot', type=str, help='prefix of results file')),
    ('--score', dict(default='corrdistintegral0_025', type=str, help='the score to evaluate')),
    ('--nasspace', dict(default='nasbench201', type=str, help='the nas search space to use')),
    ('--batch_size', dict(default=256, type=int)),
    ('--repeat', dict(default=256, type=int, help='how often to repeat a single image with a batch')),
    ('--augtype', dict(default='gaussnoise', type=str, help='which perturbations to use')),
    ('--sigma', dict(default=0.01, type=float, help='noise level if augtype is "gaussnoise"')),
    ('--GPU', dict(default='0', type=str)),
    ('--seed', dict(default=1, type=int)),
    ('--trainval', dict(action='store_true')),
    ('--dataset', dict(default='cifar10', type=str)),
    ('--maxofn', dict(default=10, type=int, help='score is the max of this many evaluations of the network')),
    ('--n_samples', dict(default=100, type=int)),
    ('--n_runs', dict(default=500, type=int)),
    ('--stem_out_channels', dict(default=16, type=int, help='output channels of stem convolution (nasbench101)')),
    ('--num_stacks', dict(default=3, type=int, help='#stacks of modules (nasbench101)')),
    ('--num_modules_per_stack', dict(default=3, type=int, help='#modules per stack (nasbench101)')),
    ('--num_labels', dict(default=1, type=int, help='#classes (nasbench101)')),
]

parser = argparse.ArgumentParser(description='NAS Without Training')
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()

# Seed Python's and NumPy's global RNGs from --seed; NumPy's generator
# drives the random subsample of architectures taken further down.
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(args.seed)

# Precomputed results: the per-architecture scores for this exact run
# configuration, and the accuracies, whose file name depends only on
# save_loc/save_string, the search space, the dataset and --trainval.
filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}.npy'
accfilename = f'{args.save_loc}/{args.save_string}_accs_{args.nasspace}_{args.dataset}_{args.trainval}.npy'

scores, accs = (np.load(path) for path in (filename, accfilename))

# Five-colour palette: reds/oranges for NAS-Bench-201, blues for any other
# search space. Converted to RGBA tuples so individual channels can be read.
if args.nasspace == 'nasbench201':
    colours = ['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B']
else:
    colours = ['#190C30', '#241147', '#34208C', '#4882FA', '#81BAFC']
colours = [mp.colors.to_rgba(c) for c in colours]

# Palette colour i is anchored at x = 0.1*i + offset on the colormap's [0, 1]
# axis. The offset depends on the dataset (presumably so the ramp lines up
# with that dataset's accuracy range -- TODO confirm).
if args.dataset == 'cifar10':
    offset = 0.6
elif args.dataset == 'cifar100':
    offset = 0.3
else:
    offset = 0.
anchors = [(0.1 * i + offset, c) for i, c in enumerate(colours)]

# LinearSegmentedColormap requires the anchors to start at exactly x=0 and
# end at exactly x=1. Pad with flat segments in the end colours wherever the
# ramp falls short, rather than relying on float arithmetic landing on 1.0.
if anchors[0][0] > 0.:
    anchors.insert(0, (0., colours[0]))
if anchors[-1][0] < 1.:
    anchors.append((1., colours[-1]))

# segmentdata rows are [x, value_below_x, value_above_x]; equal values give a
# continuous ramp with no jumps.
cdict = {
    channel: [[x, rgba[k], rgba[k]] for x, rgba in anchors]
    for k, channel in enumerate(('red', 'green', 'blue'))
}
newcmp = mp.colors.LinearSegmentedColormap('testCmap', segmentdata=cdict, N=256)



if args.nasspace == 'nasbench101':
    # Keep only architectures with accuracy above 0.5 (accuracies for this
    # space appear to be fractions rather than percentages -- TODO confirm),
    # and report how many remain.
    inds = accs > 0.5
    accs = accs[inds]
    scores = scores[inds]
    print(accs.shape)

# Plot a random subsample of at most 1000 architectures. np.random.choice
# with replace=False raises ValueError if asked for more items than exist
# (possible after the filter above), so clamp to what is available.
n_plot = min(1000, accs.size)
inds = np.random.choice(accs.size, n_plot, replace=False)
accs = accs[inds]
scores = scores[inds]

# Kendall rank correlation between accuracy and score on the plotted subset;
# tau is shown in the figure title.
tau, p = stats.kendalltau(accs, scores)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))

def scale(x):
    """Map an accuracy fraction onto the exponential x-axis used for plotting.

    Computes ``2 ** (10 * x) - 1``, so 0 maps to 0 and 1 maps to 1023. This
    stretches the high-accuracy end of the axis. Works elementwise on NumPy
    arrays as well as on scalars.
    """
    exponent = 10 * x
    return pow(2., exponent) - 1.

# NAS-Bench-201 accuracies are divided by 100 before plotting (presumably
# stored as percentages); other spaces' accuracies are used as-is. The same
# values drive both the x position (via scale) and the point colour.
plot_accs = accs / 100. if args.nasspace == 'nasbench201' else accs
ax.scatter(scale(plot_accs), scores, c=newcmp(plot_accs))

# X tick positions are given in percent and mapped through the same scale().
if args.dataset == 'cifar100':
    tick_accs = [40, 60, 70]
elif args.dataset == 'ImageNet16-120':
    tick_accs = [20, 30, 40, 45]
elif args.nasspace == 'nasbench101' and args.dataset == 'cifar10':
    tick_accs = [50, 80, 90, 95]
else:
    tick_accs = [50, 80, 90]
ax.set_xticks([scale(float(a) / 100.) for a in tick_accs])
ax.set_xticklabels([f'{a}' for a in tick_accs])
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

nasspacenames = {
    'nasbench101': 'NAS-Bench-101',
    'nasbench201': 'NAS-Bench-201'
}
# Search spaces without a display name fall back to the raw --nasspace value
# instead of raising KeyError.
space_name = nasspacenames.get(args.nasspace, args.nasspace)

ax.set_ylabel('Score')
ax.set_xlabel(f'{"Test" if not args.trainval else "Validation"} accuracy')
ax.set_title(f'{space_name} {args.dataset} \n $\\tau=${tau:.3f}')

# Fixed y ticks every 10000 below 256*256, labelled in scientific notation.
# Stripping '000' shortens e.g. '1.0000E+4' to '1.0E+4'.
# NOTE(review): assumes scores lie in [0, 256*256) -- confirm for other
# batch_size/repeat settings.
y_ticks = list(range(0, 256 * 256, 10000))
ax.set_yticks(y_ticks)
ax.set_yticklabels([f'{Decimal(r):1E}'.replace('000', '') for r in y_ticks])

# The figure is saved next to the scores file, with the same stem and a .pdf
# extension.
filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}.pdf'
plt.tight_layout()
plt.savefig(filename)

plt.show()
