import argparse
import nasspace
import datasets
import random
import numpy as np
import torch
import os
from scores import get_score_func


def _build_parser():
    """Return the command-line parser for this script.

    Options are declared below as ``(flag, keyword-arguments)`` pairs and are
    registered in table order, so ``--help`` lists them in that order.

    NOTE(review): several options (e.g. --data_loc, --api_loc, --GPU,
    --n_samples, --n_runs and the nasbench101 ones) are not read by the code
    visible in this file; presumably they are kept so the same command line as
    the scoring run can be reused -- TODO confirm.
    """
    option_table = [
        # Input / output locations.
        ('--data_loc', dict(default='../fishersearch_randomwirenetworks/cifardata/', type=str, help='dataset folder')),
        ('--api_loc', dict(default='../fimflam/NAS-Bench-201-v1_0-e61699.pth', type=str, help='path to API')),
        ('--save_loc', dict(default='results', type=str, help='folder to save results')),
        ('--save_string', dict(default='naswot', type=str, help='prefix of results file')),
        # Score / search-space configuration (these feed the result-file names).
        ('--score', dict(default='corrdistintegral0_025', type=str, help='the score to evaluate')),
        ('--nasspace', dict(default='nasbench201', type=str, help='the nas search space to use')),
        ('--batch_size', dict(default=256, type=int)),
        ('--repeat', dict(default=256, type=int, help='how often to repeat a single image with a batch')),
        ('--augtype', dict(default='gaussnoise', type=str, help='which perturbations to use')),
        ('--sigma', dict(default=0.01, type=float, help='noise level if augtype is "gaussnoise"')),
        ('--GPU', dict(default='0', type=str)),
        ('--seed', dict(default=1, type=int)),
        ('--trainval', dict(action='store_true')),
        ('--dataset', dict(default='cifar10', type=str)),
        ('--maxofn', dict(default=10, type=int, help='score is the max of this many evaluations of the network')),
        ('--n_samples', dict(default=100, type=int)),
        ('--n_runs', dict(default=500, type=int)),
        # Options whose help text marks them as nasbench101-specific.
        ('--stem_out_channels', dict(default=16, type=int, help='output channels of stem convolution (nasbench101)')),
        ('--num_stacks', dict(default=3, type=int, help='#stacks of modules (nasbench101)')),
        ('--num_modules_per_stack', dict(default=3, type=int, help='#modules per stack (nasbench101)')),
        ('--num_labels', dict(default=1, type=int, help='#classes (nasbench101)')),
    ]

    arg_parser = argparse.ArgumentParser(description='NAS Without Training')
    for flag, options in option_table:
        arg_parser.add_argument(flag, **options)
    return arg_parser


parser = _build_parser()
args = parser.parse_args()


# Result files to compare. The names are assembled from the CLI options, so
# they must match the names used when these arrays were saved (presumably by
# the scoring script -- TODO confirm).
filename = f'{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{args.dataset}_{args.augtype}_{args.sigma}_{args.repeat}_{args.trainval}_{args.batch_size}_{args.maxofn}_{args.seed}.npy'
accfilename = f'{args.save_loc}/{args.save_string}_accs_{args.nasspace}_{args.dataset}_{args.trainval}.npy'
paramfilename = f'{args.save_loc}/{args.save_string}_{args.nasspace}_{args.dataset}_params_accs.npy'


scores = np.load(filename)
accs = np.load(accfilename)
# Only row 0 of the "params_accs" array is used; judging by the file and
# variable names it holds per-network parameter counts -- TODO confirm layout.
# NOTE(review): scores, accs and params are indexed with the same network
# indices below, so they are assumed to be aligned and of equal length.
params = np.load(paramfilename)[0, :]

# Accuracy cut-off, in the same units as ``accs`` (presumably percent).
ACC_THRESHOLD = 90.

# Among the networks above the accuracy threshold, pick the one with the
# fewest parameters.
inds = np.flatnonzero(accs > ACC_THRESHOLD)
if inds.size == 0:
    # Without this check np.argmin would fail with an opaque
    # "attempt to get argmin of an empty sequence".
    raise SystemExit(
        f'No network has accuracy above {ACC_THRESHOLD}; '
        'cannot select the smallest accurate network.'
    )
ind = inds[np.argmin(params[inds])]

# Rank of the selected network among all networks, ordered by parameter count
# and by score: rank 1 = largest value, and ties share the same rank.
# NOTE(review): a NaN score compares False with everything, so it would
# report rank 1.
print(f'Rank by param: {(params > params[ind]).sum() + 1}')
print(f'Rank by score: {(scores > scores[ind]).sum() + 1}')

