import argparse
import nasspace
import datasets
import random
import numpy as np
import torch
import os
from scores import get_score_func
import matplotlib as mp
import matplotlib
import matplotlib.pyplot as plt

# Global matplotlib font settings, applied through the 'font' rc group.
font = dict(size=18)
matplotlib.rc('font', **font)



# Command-line interface, declared as (flag, add_argument keyword arguments)
# pairs and registered in this order.
_CLI_OPTIONS = [
    ('--data_loc', dict(default='../fishersearch_randomwirenetworks/cifardata/', type=str,
                        help='dataset folder')),
    ('--api_loc', dict(default='../fimflam/NAS-Bench-201-v1_0-e61699.pth', type=str,
                       help='path to API')),
    ('--save_loc', dict(default='results', type=str, help='folder to save results')),
    ('--save_string', dict(default='naswot', type=str, help='prefix of results file')),
    ('--score', dict(default='corrdistintegral0_025', type=str, help='the score to evaluate')),
    ('--nasspace', dict(default='nasbench201', type=str, help='the nas search space to use')),
    ('--batch_size', dict(default=256, type=int)),
    ('--repeat', dict(default=256, type=int,
                      help='how often to repeat a single image with a batch')),
    ('--augtype', dict(default='gaussnoise', type=str, help='which perturbations to use')),
    ('--sigma', dict(default=0.01, type=float, help='noise level if augtype is "gaussnoise"')),
    ('--GPU', dict(default='0', type=str)),
    ('--seed', dict(default=1, type=int)),
    ('--trainval', dict(action='store_true')),
    ('--dataset', dict(default='cifar10', type=str)),
    ('--maxofn', dict(default=10, type=int,
                      help='score is the max of this many evaluations of the network')),
    ('--n_samples', dict(default=100, type=int)),
    ('--n_runs', dict(default=500, type=int)),
    ('--stem_out_channels', dict(default=16, type=int,
                                 help='output channels of stem convolution (nasbench101)')),
    ('--num_stacks', dict(default=3, type=int, help='#stacks of modules (nasbench101)')),
    ('--num_modules_per_stack', dict(default=3, type=int,
                                     help='#modules per stack (nasbench101)')),
    ('--num_labels', dict(default=1, type=int, help='#classes (nasbench101)')),
]

parser = argparse.ArgumentParser(description='NAS Without Training')
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()
# Expose only the device(s) named by --GPU to CUDA for this process.
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU

# Reproducibility: turn off cuDNN benchmark mode, request deterministic cuDNN
# behaviour, then seed Python's, NumPy's and PyTorch's RNGs with --seed.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(args.seed)



# Two five-colour palettes, converted from hex strings to RGBA tuples.
colours = ['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B']
colours = [mp.colors.to_rgba(c) for c in colours]
colours2 = ['#190C30', '#241147', '#34208C', '#4882FA', '#81BAFC']
colours2 = [mp.colors.to_rgba(c) for c in colours2]


def _colour_stops(dataset, n_colours):
    """Return the (position, palette index) anchors used for `dataset`.

    Positions lie in [0, 1] and are increasing. Index -1 means the last
    palette colour.
    """
    if dataset == 'cifar10':
        # Palette occupies [0.7, 0.94]; values below 0.7 keep the first colour.
        return ([(0., 0)]
                + [((0.3/5.)*i + 0.7, i) for i in range(n_colours)]
                + [(1., -1)])
    if dataset == 'cifar100':
        # Palette occupies [0.3, 0.7]; values below 0.3 keep the first colour.
        return ([(0., 0)]
                + [(0.1*i + 0.3, i) for i in range(n_colours)]
                + [(1., -1)])
    # Any other dataset: palette occupies [0.0, 0.4].
    return [(0.1*i, i) for i in range(n_colours)] + [(1., -1)]


def _segment_data(palette, stops):
    """Build the `segmentdata` dict for LinearSegmentedColormap.

    Each channel gets one `[position, value, value]` row per stop, i.e. the
    colour is continuous at every anchor.
    """
    return {
        channel: [[pos, palette[idx][ch], palette[idx][ch]] for pos, idx in stops]
        for ch, channel in enumerate(('red', 'green', 'blue'))
    }


# Both maps share the dataset's stop layout. Previously `cdict2` was only
# assigned for cifar10, so building `newcmp2` raised NameError for every other
# dataset.
cdict = _segment_data(colours, _colour_stops(args.dataset, len(colours)))
cdict2 = _segment_data(colours2, _colour_stops(args.dataset, len(colours2)))
newcmp = mp.colors.LinearSegmentedColormap('testCmap', segmentdata=cdict, N=256)
newcmp2 = mp.colors.LinearSegmentedColormap('testCmap', segmentdata=cdict2, N=256)



def get_batch_jacobian(net, x, target, device, args=None):
    """Return the gradient of the summed network output with respect to `x`.

    Args:
        net: callable with a `zero_grad()` method; `net(x)` must return a
            2-tuple whose first element is the output tensor.
        x: input tensor. It is switched to `requires_grad=True` in place and
            its `.grad` is overwritten by this call.
        target: tensor returned detached alongside the gradient.
        device: unused; kept for interface compatibility.
        args: unused; kept for interface compatibility.

    Returns:
        `(jacob, target.detach())`, where `jacob` is `x.grad` after
        back-propagating a tensor of ones through the first network output.
    """
    net.zero_grad()
    x.requires_grad_(True)
    # `.grad` on a leaf tensor accumulates across backward passes and is not
    # touched by net.zero_grad(). Clear it so that a reused `x` yields only
    # this call's gradient and earlier returned results are not modified.
    x.grad = None
    y, _ = net(x)
    y.backward(torch.ones_like(y))
    jacob = x.grad.detach()
    return jacob, target.detach()

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
searchspace = nasspace.get_search_space(args)
train_loader = datasets.get_data(
    args.dataset,
    args.data_loc,
    args.trainval,
    args.batch_size,
    args.augtype,
    args.repeat,
    args,
)
os.makedirs(args.save_loc, exist_ok=True)


# Accuracy keys passed to the search space; only the test key depends on the dataset.
acc_type = 'ori-test' if args.dataset == 'cifar10' else 'x-test'
val_acc_type = 'x-valid'



# Collect the accuracy of the architectures in the search space, stopping early
# once the index passes 5000.
n_archs = len(searchspace)
accs = np.zeros(n_archs)
for arch_idx in range(n_archs):
    accs[arch_idx] = searchspace.get_accuracy(searchspace[arch_idx], acc_type, args.trainval)
    if arch_idx <= 5000:
        continue
    # Cap reached: keep entries [0, arch_idx) and stop. The value just written
    # at arch_idx is dropped by this slice.
    accs = accs[:arch_idx]
    break

# Draw one random architecture from each accuracy band delimited by
# consecutive percentiles (10-20, 20-30, ..., 90-100).
boundaries = list(range(10, 101, 10))
networks = []
netaccs = []
for lower_pct, upper_pct in zip(boundaries, boundaries[1:]):
    lower_acc = np.percentile(accs, lower_pct)
    upper_acc = np.percentile(accs, upper_pct)
    print(f'a1: {lower_acc}, a2: {upper_acc}')
    # Band is half-open: strictly above the lower bound, up to and including the upper.
    in_band = (accs > lower_acc) & (accs <= upper_acc)
    inds = np.nonzero(in_band)[0]
    print(f'inds: {inds}')
    chosen = np.random.choice(inds)
    networks.append(chosen)
    netaccs.append(accs[chosen])



# Full 2x2 ablation figure plus a compact 1x2 figure; both use tight layout.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
figsmall, axessmall = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
for _figure in (fig, figsmall):
    _figure.set_tight_layout(True)

# ablate images: each selected network is built once, then scored `maxofn`
# times, each time on the first batch of a newly created loader iterator.
ss = []
for k, i in enumerate(networks):
    uid = searchspace[i]
    network = searchspace.get_network(uid)
    # Reproducibility
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(args.seed)
    s = []
    for j in range(args.maxofn):
        data_iterator = iter(train_loader)
        x, target = next(data_iterator)
        x = x.to(device)
        target = target.to(device)
        network = network.to(device)
        jacobs, labels = get_batch_jacobian(network, x, target, device, args)
        score_fn = get_score_func(args.score)
        s.append(score_fn(jacobs, labels))
    ss.append(s)
    # Colour the box by the network's accuracy (rescaled to [0, 1] for nasbench201).
    c = newcmp(netaccs[k]/100. if args.nasspace == 'nasbench201' else netaccs[k])
    # The same box goes on the main figure and the compact figure.
    for ax in (axes[0, 0], axessmall[0]):
        ax.boxplot(
            s,
            positions=[k],
            notch=True,
            patch_artist=True,
            boxprops=dict(facecolor=c, color=c),
            capprops=dict(color=c),
            whiskerprops=dict(color=c),
            flierprops=dict(color=c, markeredgecolor=c),
            medianprops=dict(color=c),
            widths=0.7,
        )

for ax in (axes[0, 0], axessmall[0]):
    ax.set_title('CIFAR-10 images')
    ax.set_ylabel('score')
    ax.set_xlabel('Uninitialised network')
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

# ablate initialisation: one batch is drawn per architecture, then `maxofn`
# networks are obtained from get_network() for that architecture and scored
# on that batch.
ss = []
for k, i in enumerate(networks):
    uid = searchspace[i]
    # Reproducibility
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(args.seed)
    data_iterator = iter(train_loader)
    x, target = next(data_iterator)
    x = x.to(device)
    target = target.to(device)
    s = []
    # The same `x` tensor object is passed to every call below.
    for j in range(args.maxofn):
        network = searchspace.get_network(uid).to(device)
        jacobs, labels = get_batch_jacobian(network, x, target, device, args)
        score_fn = get_score_func(args.score)
        s.append(score_fn(jacobs, labels))
    ss.append(s)
    # Colour the box by the network's accuracy (rescaled to [0, 1] for nasbench201).
    c = newcmp(netaccs[k]/100. if args.nasspace == 'nasbench201' else netaccs[k])
    axes[1, 0].boxplot(
        s,
        positions=[k],
        notch=True,
        patch_artist=True,
        boxprops=dict(facecolor=c, color=c),
        capprops=dict(color=c),
        whiskerprops=dict(color=c),
        flierprops=dict(color=c, markeredgecolor=c),
        medianprops=dict(color=c),
        widths=0.7,
    )

init_ax = axes[1, 0]
init_ax.set_title('Initialisations')
init_ax.set_ylabel('score')
init_ax.set_xlabel('Uninitialised network')
for side in ("top", "right"):
    init_ax.spines[side].set_visible(False)

# ablate random images: same procedure as the image ablation, but the loader
# is rebuilt for the 'fake' dataset.
train_loader = datasets.get_data(
    'fake',
    args.data_loc,
    args.trainval,
    args.batch_size,
    args.augtype,
    args.repeat,
    args,
)
ss = []
for k, i in enumerate(networks):
    uid = searchspace[i]
    network = searchspace.get_network(uid)
    # Reproducibility
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(args.seed)
    s = []
    for j in range(args.maxofn):
        data_iterator = iter(train_loader)
        x, target = next(data_iterator)
        x = x.to(device)
        target = target.to(device)
        network = network.to(device)
        jacobs, labels = get_batch_jacobian(network, x, target, device, args)
        score_fn = get_score_func(args.score)
        s.append(score_fn(jacobs, labels))
    ss.append(s)
    # Colour the box by the network's accuracy (rescaled to [0, 1] for nasbench201).
    c = newcmp(netaccs[k]/100. if args.nasspace == 'nasbench201' else netaccs[k])
    # The same box goes on the main figure and the compact figure.
    for ax in (axes[0, 1], axessmall[1]):
        ax.boxplot(
            s,
            positions=[k],
            notch=True,
            patch_artist=True,
            boxprops=dict(facecolor=c, color=c),
            capprops=dict(color=c),
            whiskerprops=dict(color=c),
            flierprops=dict(color=c, markeredgecolor=c),
            medianprops=dict(color=c),
            widths=0.7,
        )

for ax in (axes[0, 1], axessmall[1]):
    ax.set_title('Random images')
    ax.set_ylabel('score')
    ax.set_xlabel('Uninitialised network')
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)


# Mini-batch-size ablation, read from previously saved result files: one score
# file per batch size and a single accuracy file, all restricted to the
# networks selected above.
inds = np.array(networks)
accfilename = f'{args.save_loc}/{args.save_string}_accs_{args.nasspace}_{args.dataset}_{args.trainval}.npy'
accs = np.load(accfilename)[inds]

# allscores[n] collects, for selected network n, one normalised score per batch size.
allscores = [[] for _ in inds]
for batch_size in [32, 64, 128, 256, 512]:
    filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{batch_size}_{args.trainval}_{batch_size}_1_{args.seed}.npy'
    scores = np.load(filename)
    for per_network, raw_score in zip(allscores, scores[inds]):
        # Dividing by batch_size**2 puts scores from different batch sizes on one scale.
        per_network.append(float(raw_score)/float(batch_size**2))


batch_ax = axes[1, 1]
for k, per_network in enumerate(allscores):
    # Colour by the accuracy loaded from file (rescaled to [0, 1] for nasbench201).
    c = newcmp2(accs[k]/100. if args.nasspace == 'nasbench201' else accs[k])
    batch_ax.boxplot(
        per_network,
        positions=[k],
        notch=True,
        patch_artist=True,
        boxprops=dict(facecolor=c, color=c),
        capprops=dict(color=c),
        whiskerprops=dict(color=c),
        flierprops=dict(color=c, markeredgecolor=c),
        medianprops=dict(color=c),
        widths=0.7,
    )
batch_ax.set_title('Mini-batch size')
batch_ax.set_ylabel('Normalised score')
batch_ax.set_xlabel('Uninitialised network')
for side in ("top", "right"):
    batch_ax.spines[side].set_visible(False)

# The two output PDFs share a prefix and a run-configuration suffix; only the
# 'ablation' / 'ablationsmall' tag differs.
_out_prefix = f'{args.save_loc}/{args.save_string}'
_out_suffix = f'{args.score}_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}.pdf'
filename = f'{_out_prefix}_ablation_{_out_suffix}'
filenamesmall = f'{_out_prefix}_ablationsmall_{_out_suffix}'

# pyplot-level call: affects whichever figure is current at this point.
plt.tight_layout()
fig.savefig(filename)
figsmall.savefig(filenamesmall)
plt.show()

