import argparse
import nasspace
import datasets
import random
import numpy as np
import torch
import os
from scores import get_score_func
import matplotlib.pyplot as plt
import bisect


# Command-line options, declared as (flag, add_argument keyword arguments) pairs.
# They are registered in this order, which is also the order shown by --help.
_OPTION_SPECS = (
    ('--data_loc', dict(default='../fishersearch_randomwirenetworks/cifardata/', type=str,
                        help='dataset folder')),
    ('--api_loc', dict(default='../fimflam/NAS-Bench-201-v1_0-e61699.pth', type=str,
                       help='path to API')),
    ('--save_loc', dict(default='results', type=str, help='folder to save results')),
    ('--save_string', dict(default='naswot', type=str, help='prefix of results file')),
    ('--score', dict(default='corrdistintegral0_025', type=str, help='the score to evaluate')),
    ('--nasspace', dict(default='nasbench201', type=str, help='the nas search space to use')),
    ('--batch_size', dict(default=256, type=int)),
    ('--repeat', dict(default=256, type=int,
                      help='how often to repeat a single image with a batch')),
    ('--augtype', dict(default='gaussnoise', type=str, help='which perturbations to use')),
    ('--sigma', dict(default=0.01, type=float, help='noise level if augtype is "gaussnoise"')),
    ('--GPU', dict(default='0', type=str)),
    ('--seed', dict(default=1, type=int)),
    ('--trainval', dict(action='store_true')),
    ('--dataset', dict(default='cifar10', type=str)),
    ('--maxofn', dict(default=10, type=int,
                      help='score is the max of this many evaluations of the network')),
    ('--n_samples', dict(default=100, type=int)),
    ('--n_runs', dict(default=500, type=int)),
    ('--stem_out_channels', dict(default=16, type=int,
                                 help='output channels of stem convolution (nasbench101)')),
    ('--num_stacks', dict(default=3, type=int, help='#stacks of modules (nasbench101)')),
    ('--num_modules_per_stack', dict(default=3, type=int,
                                     help='#modules per stack (nasbench101)')),
    ('--num_labels', dict(default=1, type=int, help='#classes (nasbench101)')),
)

parser = argparse.ArgumentParser(description='NAS Without Training')
for _flag, _spec in _OPTION_SPECS:
    parser.add_argument(_flag, **_spec)

# Parse the process arguments once; the rest of the script reads them via `args`.
args = parser.parse_args()

# Export the --GPU value so that it is in the environment before `device` is chosen below.
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU

# Reproducibility: seed the Python, NumPy and PyTorch generators with --seed
# and ask cuDNN for deterministic behaviour without benchmark mode.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def get_batch_jacobian(net, x, target, device, args=None):
    """Return the gradient of the summed network output with respect to ``x``.

    The network is run once on ``x`` and its first output is back-propagated
    with an all-ones upstream gradient, so the result is the derivative of the
    sum of all output elements with respect to every input element.

    Args:
        net: Callable network exposing ``zero_grad()`` whose call returns a
            2-tuple; only the first element is differentiated.
        x: Input tensor. It is switched to ``requires_grad`` in place and any
            gradient it already carries is discarded.
        target: Tensor returned detached alongside the gradient.
        device: Unused; kept for interface compatibility.
        args: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(jacob, target)``: the detached gradient of the output sum
        with respect to ``x`` (same shape as ``x``), and ``target.detach()``.
    """
    net.zero_grad()
    x.requires_grad_(True)
    # backward() accumulates into x.grad. Without this reset, a second call on
    # the same tensor would return the sum of both gradients and would also
    # modify, in place, the tensor handed out by the earlier call.
    x.grad = None
    y, _ = net(x)
    y.backward(torch.ones_like(y))
    jacob = x.grad.detach()
    return jacob, target.detach()

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Search space and data loader are both built from the parsed arguments.
searchspace = nasspace.get_search_space(args)
train_loader = datasets.get_data(
    args.dataset,
    args.data_loc,
    args.trainval,
    args.batch_size,
    args.augtype,
    args.repeat,
    args,
)

# Make sure the results folder exists, then derive the result file names from the run settings.
os.makedirs(args.save_loc, exist_ok=True)
filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}'
accfilename = f'{args.save_loc}/{args.save_string}_accs_{args.nasspace}_{args.dataset}_{args.trainval}'

# Accuracy keys handed to searchspace.get_accuracy(); only the test key depends on the dataset.
acc_type = 'ori-test' if args.dataset == 'cifar10' else 'x-test'
val_acc_type = 'x-valid'

imgs = []
grads = []  # filled below with (accuracy, input-gradient array) pairs

# Upper edges of the accuracy bins, and one "already have a network" flag per bin.
boundaries = [50., 60., 70., 75., 80., 85.5, 90., 100.]
found = [0] * 8

# Five rows; one leading column plus two columns for each bin except one.
fig, ax = plt.subplots(nrows=5, ncols=2 * (len(found) - 1) + 1)
# Walk the search space and keep the input gradient of the first network that
# falls into each accuracy bin, stopping once every bin has one.
if args.maxofn < 1:
    # With no evaluation, `jacobs` below would be undefined or left over from
    # the previously processed network.
    raise ValueError('--maxofn must be at least 1')

for uid, network in searchspace:
    # Reproducibility: restart all random generators from --seed for every network.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    acc = searchspace.get_accuracy(uid, acc_type, args.trainval)
    if args.nasspace == 'nasbench101':
        # NOTE(review): presumably nasbench101 reports a fraction rather than a percentage - confirm.
        acc = 100.*acc
    # Bin k collects accuracies in (boundaries[k-1], boundaries[k]]. bisect_left
    # returns len(boundaries) for anything above the last edge, so clamp it to
    # the last bin instead of indexing past the end of `found`.
    ind = min(bisect.bisect_left(boundaries, acc), len(found) - 1)
    print('')
    print(acc)
    print(found)
    if found[ind] > 0:
        continue
    found[ind] = 1
    print(found)

    # Moving the network does not depend on the evaluation index, so do it once.
    network = network.to(device)
    for _ in range(args.maxofn):
        # A fresh iterator each time: every evaluation uses the loader's first batch.
        data_iterator = iter(train_loader)
        x, target = next(data_iterator)
        x, target = x.to(device), target.to(device)

        # Only the gradient from the last evaluation is kept for plotting.
        jacobs, _ = get_batch_jacobian(network, x, target, device, args)
        # Rescale the images for display: first by the largest absolute value
        # over dims 2 and 3, shifted around 0.5, then by the largest value over
        # dims 2 and 3.
        # NOTE(review): assumes x is laid out as (batch, channel, height, width) - confirm.
        x = 0.5*x/(x.abs().max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]) + 0.5
        x = x/(x.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0])

        print(x.max())
        print(x.min())
        print(jacobs.max())
        print(jacobs.min())
        print('')
    # Swap dim 1 with the last dim so the array can be handed to imshow later.
    grads.append((acc, jacobs.transpose(1, -1).cpu().detach().numpy()))

    if all(found):
        break

# Plot, for the first five images of the last batch, the image itself followed
# by the input gradient and its histogram for each collected network, ordered
# by accuracy.
grads.sort(key=lambda k: k[0])
if not grads:
    # Nothing was collected, so `x` was never assigned either.
    raise RuntimeError('no network was evaluated, nothing to plot')

# Not every accuracy bin is guaranteed to have been filled, so never index
# past the networks that were actually collected.
# NOTE(review): indexing starts at 1, so the lowest-accuracy entry grads[0] is
# never shown - confirm this is intended.
n_entries = min(len(found), len(grads))
for i in range(5):
    ax[i, 0].imshow(x.transpose(1, -1)[i, :, :, :].cpu().detach().numpy())
    ax[i, 0].axis('off')
    for j in range(1, n_entries):
        net_acc, net_grads = grads[j]
        grad_i = net_grads[i, :, :, :]
        if i == 0:
            ax[i, 2*j-1].set_title(f'{net_acc:.2f}')
        # Scale by the maximum over the first two axes and shift around 0.5.
        # NOTE(review): this divides by the signed maximum, not the maximum
        # absolute value, so results can fall outside [0, 1] - confirm.
        ax[i, 2*j-1].imshow(0.5*grad_i/(grad_i.max(axis=1, keepdims=False).max(axis=0, keepdims=False)) + 0.5)
        ax[i, 2*j-1].axis('off')
        ax[i, 2*j].hist(grad_i.flatten(), bins=100)

plt.show()
