import argparse
import nasspace
import datasets
import random
import numpy as np
import torch
import os
import sys
import matplotlib.pyplot as plt
from scipy import stats
from scores import get_score_func
from measure_model import measure_model
from tqdm import tqdm

# Command-line configuration for the scoring run.
parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../fishersearch_randomwirenetworks/cifardata/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../fimflam/NAS-Bench-201-v1_0-e61699.pth',
                    type=str, help='path to API')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
parser.add_argument('--score', default='corrdistintegral0_025', type=str, help='the score to evaluate')
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
parser.add_argument('--batch_size', default=256, type=int, help='batch size passed to the data loader')
parser.add_argument('--repeat', default=256, type=int, help='how often to repeat a single image within a batch')
parser.add_argument('--augtype', default='cutout', type=str, help='which perturbations to use')
parser.add_argument('--sigma', default=0.01, type=float, help='noise level if augtype is "gaussnoise"')
parser.add_argument('--GPU', default='0', type=str, help='value assigned to CUDA_VISIBLE_DEVICES')
parser.add_argument('--seed', default=1, type=int, help='seed for random, numpy and torch')
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--dataset', default='cifar10', type=str, help='dataset name')
parser.add_argument('--n_samples', default=100, type=int)
parser.add_argument('--n_runs', default=500, type=int)
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
parser.add_argument('--justplot', action='store_true', help='only re-plot previously saved results, then exit')
args = parser.parse_args()
# Must be set before CUDA is first initialised to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU

# Reproducibility: deterministic cuDNN kernels and fixed seeds for every RNG
# this script uses.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)


def get_batch_jacobian(net, x, target, device, args=None):
    """Return the gradient of the summed network output w.r.t. the inputs.

    The network is run on ``x``. ``backward`` is then called on its first
    output with an all-ones seed, so the result is d(sum(y))/dx and has the
    same shape as ``x``.

    Args:
        net: module whose forward returns a 2-tuple; only the first element
            is differentiated.
        x: input batch. It is not modified: gradients are taken w.r.t. a
            detached copy.
        target: labels for ``x``; returned detached.
        device: unused; kept for call-site compatibility.
        args: unused; kept for call-site compatibility.

    Returns:
        ``(jacob, target.detach())``, where ``jacob`` has the shape of ``x``.
    """
    # backward() below also populates parameter gradients; clear leftovers
    # from earlier calls so they do not keep accumulating.
    net.zero_grad()
    # Use a fresh leaf so x.grad always starts empty. Repeated calls on the
    # same batch would otherwise accumulate into (and alias) earlier results,
    # and the caller's tensor would be flipped to requires_grad=True.
    x = x.detach().requires_grad_(True)
    y, _ = net(x)
    y.backward(torch.ones_like(y))
    jacob = x.grad.detach()
    return jacob, target.detach()

# All networks and batches run on the CPU. To use a GPU, replace this with
# torch.device("cuda:0" if torch.cuda.is_available() else "cpu"). While this
# is forced to CPU, --GPU / CUDA_VISIBLE_DEVICES has no effect on placement.
# NOTE(review): forcing CPU may be deliberate (e.g. measure_model expecting
# CPU tensors) — confirm before enabling CUDA.
device = torch.device("cpu")

# Only scoring needs the data loader. --justplot re-plots saved results and
# exits before any batch is drawn, so dataset loading is skipped in that mode.
train_loader = None if args.justplot else datasets.get_data(
    args.dataset, args.data_loc, args.trainval, args.batch_size,
    args.augtype, args.repeat, args)
os.makedirs(args.save_loc, exist_ok=True)

# Results file stem: np.save appends '.npy'; the plots add their own suffixes.
filename = os.path.join(
    args.save_loc,
    f'{args.save_string}_{args.nasspace}_{args.dataset}_params_accs')

# Row indices into the saved `params_accs` array
# (parameter counts, scores, test accuracies).
params = 0
scores = 1
accs   = 2

def _plot_against_params(xs, ys, ylabel, title_fmt, out_path):
    """Scatter *ys* against parameter counts *xs*, save to *out_path*, and close.

    *title_fmt* is a ``str.format`` template that receives ``tau``, the
    Kendall's tau between *xs* and *ys*. That tau is also returned.
    """
    tau, _ = stats.kendalltau(xs, ys)
    fig, ax = plt.subplots()
    ax.scatter(xs, ys, alpha=.8)
    ax.set_title(title_fmt.format(tau=tau))
    ax.set_xlabel('Parameters')
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    fig.savefig(out_path)
    # Release the figure; pyplot otherwise keeps every figure it created alive.
    plt.close(fig)
    return tau


if args.justplot:
    # Re-plot results saved by a previous run (rows: parameter counts, scores,
    # accuracies) and exit without scoring any network.
    params_accs = np.load(filename + '.npy')
    tau = _plot_against_params(params_accs[params], params_accs[accs],
                               'Accuracy', 'KTau: {tau}', filename + '.pdf')
    print(f"Kendall's Tau: {tau}")
    _plot_against_params(params_accs[params], params_accs[scores],
                         'Score', '$\\tau={tau}$', filename + 'VS_SCORES.pdf')
    sys.exit()


searchspace = nasspace.get_search_space(args)


def _cycle_batches(loader):
    """Yield batches from *loader* indefinitely, starting a new pass over it
    whenever the current one is exhausted.

    One batch is drawn per network. A search space with more networks than
    the loader has batches would otherwise stop with ``StopIteration`` and
    lose every result, because results are only saved after the loop. The
    generator ends if a full pass yields nothing (empty loader).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


data_iterator = _cycle_batches(train_loader)
# This first batch is only fed to measure_model for the first network. Each
# network's score is computed on the batch drawn inside the loop.
x, target = next(data_iterator)
x, target = x.to(device), target.to(device)

score_fn = get_score_func(args.score)
param_counts, score_values, test_accs = [], [], []

for uid, network in tqdm(searchspace):
    # Reset every RNG before each network so the RNG-dependent steps below
    # start from the same state for all networks.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    network = network.to(device)

    # Runs on the batch drawn in the previous iteration; only the parameter
    # count is kept.
    _ops, param_count = measure_model(network, x)
    test = searchspace.get_accuracy(uid, 'x-valid')

    x, target = next(data_iterator)
    x, target = x.to(device), target.to(device)
    jacobs, labels = get_batch_jacobian(network, x, target, device, args)
    score_values.append(score_fn(jacobs, labels))

    param_counts.append(param_count)
    test_accs.append(test)

# Row order matches the module-level `params`, `scores`, `accs` index constants.
params_accs = np.array([param_counts, score_values, test_accs])
np.save(filename, params_accs)

