import argparse
import nasspace
import datasets
import random
import numpy as np
import collections
import os
import time
import torch
from scores import get_score_func

class Model(object):
    """Mutable record describing one candidate architecture.

    All three attributes start out as ``None`` and are filled in by the
    search code:

    * ``arch`` -- the architecture identifier handed out by the search space.
    * ``accuracy`` -- the value the search ranks candidates by.
    * ``final_accuracy`` -- the value written out for the selected candidate.
    """

    def __init__(self):
        # Chained assignment binds the names left to right, so the attribute
        # order in ``__dict__`` is arch, accuracy, final_accuracy.
        self.arch = self.accuracy = self.final_accuracy = None

    def __str__(self):
        """Return the architecture rendered as a string."""
        return format(self.arch)

# Command-line interface. Arguments are parsed at module level, so this file
# is meant to be executed as a script rather than imported.
parser = argparse.ArgumentParser(description='NAS Without Training')

# Locations of the dataset, the benchmark API file and the results folder.
parser.add_argument('--data_loc', default='../fishersearch_randomwirenetworks/cifardata/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../fimflam/NAS-Bench-201-v1_0-e61699.pth',
                    type=str, help='path to API')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')

# Settings for the data loader used when scoring architectures; batch_size,
# augtype and repeat are forwarded to datasets.get_data by the search function.
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--repeat', default=256, type=int, help='how often to repeat a single image with a batch')
parser.add_argument('--augtype', default='gaussnoise', type=str, help='which perturbations to use')
parser.add_argument('--sigma', default=0.01, type=float, help='noise level if augtype is "gaussnoise"')

# --GPU is kept as a string because it is written verbatim into
# CUDA_VISIBLE_DEVICES once parsing is done.
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--trainval', action='store_true')
# NOTE(review): --activations, --n_samples and --n_runs are not read directly
# in this file; presumably they are consumed through `args` by the nasspace or
# datasets modules -- confirm before removing.
parser.add_argument('--activations', action='store_true')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--n_samples', default=100, type=int)
parser.add_argument('--n_runs', default=500, type=int)

# Network-shape options that, per their help text, apply to nasbench101.
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')

# Evolutionary-search options.
# NOTE(review): --ea_cycles is passed to regularized_evolution, but the loop
# there is bounded by --time_budget only.
parser.add_argument('--ea_cycles', default=200, type=int, help='# ea cyles')
parser.add_argument('--ea_population', default=10, type=int, help='# ea population')
parser.add_argument('--ea_sample_size', default=10, type=int, help='# ea sample size')
parser.add_argument('--time_budget', default=12000, type=int, help='# ea time budget')

# Training-free warm start: with --naswot, --naswot_population random
# architectures are scored with --score and the best ones seed the population.
parser.add_argument('--naswot', action='store_true')
parser.add_argument('--naswot_population', default=1000, type=int, help='# of archs to filter with naswot')
parser.add_argument('--score', default='corrdistintegral0_025', type=str, help='the score to evaluate')

args = parser.parse_args()
# Restrict the GPUs visible to CUDA to the ones requested on the command line.
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU

# Reproducibility: request deterministic cuDNN behaviour, turn off the cuDNN
# benchmark mode, and seed the Python, NumPy and PyTorch random generators
# with the same --seed value.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)


def get_batch_jacobian(net, x, target, device, args=None):
    """Return the input gradient of ``net`` for one batch.

    An all-ones gradient is back-propagated from the network output, so the
    tensor returned first is the gradient of ``output.sum()`` with respect to
    ``x`` (a vector-Jacobian product, not the full Jacobian matrix).

    Args:
        net: Callable module exposing ``zero_grad()`` whose forward pass
            returns a pair ``(output, extra)`` of tensors.
        x: Input batch. It is modified in place: gradient tracking is switched
            on and its ``.grad`` is populated by the backward pass.
        target: Labels for the batch; passed through detached.
        device: Unused; kept for call-site compatibility.
        args: Unused; kept for call-site compatibility.

    Returns:
        Tuple ``(input_grad, target, extra)``, each detached from the graph.
    """
    # Clear parameter gradients left over from any previous backward pass.
    net.zero_grad()
    x.requires_grad_(True)

    outputs, extras = net(x)
    seed_gradient = torch.ones_like(outputs)
    outputs.backward(seed_gradient)

    input_grad = x.grad.detach()
    return input_grad, target.detach(), extras.detach()


def _naswot_rank_key(model):
    """Sort key for training-free scores that ranks NaN below every real score."""
    score = model.accuracy
    # NaN is the only value that is not equal to itself. Comparisons against
    # NaN are always False, so leaving NaN in the key would make the sort
    # order (and therefore the selected "best" models) undefined.
    if score != score:
        return float('-inf')
    return score


def regularized_evolution(cycles, population_size, sample_size, time_budget, searchspace, dataname, args):
    """Run regularized (aging) evolution until the time budget is used up.

    The initial population is either drawn at random or, when ``args.naswot``
    is set, chosen as the highest-scoring architectures among
    ``args.naswot_population`` random ones according to the training-free
    score named by ``args.score``. Afterwards each step samples
    ``sample_size`` members with replacement, mutates the best of them,
    evaluates the child, appends it and drops the oldest member.

    This function relies on the module-level names ``Model``, ``device``,
    ``datasets`` and ``get_score_func``.

    Args:
        cycles: Unused; the loop is bounded by ``time_budget``. Kept so the
            signature stays compatible with existing callers.
        population_size: Number of models kept in the population.
        sample_size: Number of members drawn per tournament.
        time_budget: Upper bound on the accumulated cost (the ``time_cost``
            values returned by ``searchspace.train_and_eval`` plus measured
            wall-clock time, so presumably seconds).
        searchspace: Object providing ``random_arch``, ``mutate_arch``,
            ``train_and_eval``, ``get_network`` and item lookup by arch.
        dataname: Dataset name forwarded to ``searchspace.train_and_eval``.
        args: Parsed command-line namespace.

    Returns:
        Tuple ``(history, total_time_cost)``. ``history`` lists every model
        admitted to the population, in order. A child whose evaluation would
        exceed the budget is discarded and not recorded.
    """
    acc_type = 'ori-test' if args.dataset == 'cifar10' else 'x-test'
    population = collections.deque()
    history, total_time_cost = [], 0

    if args.naswot:
        train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
        # Snapshot the RNG states so every candidate below is scored under
        # identical randomness (same batch, same augmentation noise).
        npstate = np.random.get_state()
        ranstate = random.getstate()
        torchstate = torch.random.get_rng_state()
        naswot_population = []
        naswotarchs = [searchspace.random_arch() for _ in range(args.naswot_population)]

        for arch in naswotarchs:
            random.setstate(ranstate)
            np.random.set_state(npstate)
            torch.set_rng_state(torchstate)
            model = Model()
            model.arch = arch
            network = searchspace.get_network(searchspace[model.arch])
            timestart = time.time()
            x, target = next(iter(train_loader))
            x, target = x.to(device), target.to(device)
            network = network.to(device)
            jacobs, labels, _ = get_batch_jacobian(network, x, target, device, args)
            total_time_cost += time.time() - timestart

            # Scoring is best effort: a failure is reported and the candidate
            # is marked with NaN so that it ranks last instead of aborting.
            try:
                s = get_score_func(args.score)(jacobs, labels)
            except Exception as e:
                print(e)
                s = np.nan
            model.accuracy = s
            naswot_population.append(model)

        # Ascending sort, then keep the tail: the highest-scoring candidates,
        # still ordered worst to best, become the initial population.
        naswot_population.sort(key=_naswot_rank_key)
        naswot_population = naswot_population[-population_size:]
        for model in naswot_population:
            # The training-free score is replaced by the evaluated accuracy.
            model.accuracy, model.final_accuracy, time_cost = searchspace.train_and_eval(model.arch, dataname, acc_type, args.trainval)
            population.append(model)
            history.append(model)
            total_time_cost += time_cost
    else:
        while len(population) < population_size:
            model = Model()
            model.arch = searchspace.random_arch()
            model.accuracy, model.final_accuracy, time_cost = searchspace.train_and_eval(model.arch, dataname, acc_type, args.trainval)
            population.append(model)
            history.append(model)
            total_time_cost += time_cost

    while total_time_cost < time_budget:
        start_time = time.time()

        # Tournament selection: draw with replacement from the current
        # population. The population is copied once per step rather than once
        # per draw; the random draws themselves are unchanged.
        candidates = list(population)
        sample = [random.choice(candidates) for _ in range(sample_size)]

        # The parent is the best model in the sample.
        parent = max(sample, key=lambda m: m.accuracy)

        # Create the child model by mutating the parent.
        child = Model()
        child.arch = searchspace.mutate_arch(parent.arch)
        total_time_cost += time.time() - start_time
        child.accuracy, child.final_accuracy, time_cost = searchspace.train_and_eval(child.arch, dataname, acc_type, args.trainval)

        # Stop before admitting a child that would overrun the budget.
        if total_time_cost + time_cost > time_budget:
            return history, total_time_cost
        total_time_cost += time_cost

        population.append(child)
        history.append(child)

        # Remove the oldest model.
        population.popleft()
    return history, total_time_cost

# Device used when scoring networks. regularized_evolution reads this
# module-level name, so it has to be bound before the search is launched.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Results are stored under
#   <save_loc>/R_EA/<nasspace>/<dataset>/naswot/<score>/   (with --naswot)
#   <save_loc>/R_EA/<nasspace>/<dataset>/standard/         (otherwise)
if args.naswot:
    result_dir = os.path.join(args.save_loc, 'R_EA', args.nasspace, args.dataset, 'naswot', args.score)
    result_name = f'{args.save_string}_{args.augtype}_{args.sigma}_{args.trainval}_{args.repeat}_{args.batch_size}_{args.seed}.txt'
else:
    result_dir = os.path.join(args.save_loc, 'R_EA', args.nasspace, args.dataset, 'standard')
    result_name = f'{args.save_string}_{args.trainval}_{args.seed}.txt'

# os.makedirs creates every missing parent directory, so a single call on the
# leaf folder is sufficient.
os.makedirs(result_dir, exist_ok=True)
# Derive the file path from the directory that was just created so the two
# can never disagree.
filename = os.path.join(result_dir, result_name)


searchspace = nasspace.get_search_space(args)
history, total_cost = regularized_evolution(args.ea_cycles, args.ea_population, args.ea_sample_size, args.time_budget, searchspace, args.dataset, args)

# The reported model is the one with the highest search-time accuracy.
best_arch = max(history, key=lambda i: i.accuracy)

# Output format: the first line is the final accuracy of the selected model,
# followed by one "arch,accuracy,final_accuracy" line per recorded model.
with open(filename, 'w') as f:
    f.write(f'{best_arch.final_accuracy}\n')
    for arch in history:
        f.write(f'{arch.arch},{arch.accuracy},{arch.final_accuracy}\n')

