import random
import copy
from collections import defaultdict
from nasbench_space.robust_nasbench201 import RobustnessDataset

# Seed the module-level RNG so searches are reproducible. This must be a call:
# assigning to ``random.seed`` would replace the function with an int.
random.seed(42)

def random_spec(idx_to_spec):
    """Return one spec drawn uniformly at random from ``idx_to_spec``'s values."""
    candidates = list(idx_to_spec.values())
    return random.choice(candidates)


def mutate_spec(old_spec):
    """Return a shallow copy of ``old_spec`` with one random entry replaced.

    A position is chosen uniformly, and its value is replaced by a value drawn
    uniformly from ``range(5)`` excluding the current value. ``old_spec``
    itself is left untouched.
    """
    position = random.randrange(len(old_spec))
    current = old_spec[position]
    alternatives = [op for op in range(5) if op != current]
    replacement = random.choice(alternatives)
    mutated = copy.copy(old_spec)
    mutated[position] = replacement
    return mutated


def mutate_spec_zero_cost(old_spec, spec_to_idx, proxy):
    """Pick a one-entry mutation of ``old_spec`` guided by a zero-cost proxy.

    Every neighbour of ``old_spec`` (one position changed to another value in
    ``range(5)``) is scored as ``proxy[spec_to_idx[str(neighbour)]]``. The
    highest-scoring neighbour is returned (the last one in enumeration order
    among ties). With probability 0.25 (``random.random() > 0.75``) a
    uniformly random neighbour is returned instead.
    """
    neighbours = []
    for position, current in enumerate(old_spec):
        for op in range(5):
            if op != current:
                candidate = copy.copy(old_spec)
                candidate[position] = op
                score = proxy[spec_to_idx[str(candidate)]]
                neighbours.append((score, candidate))
    # Stable sort + [-1] keeps the last neighbour among equal top scores.
    ranked = sorted(neighbours, key=lambda pair: pair[0])
    chosen = ranked[-1][1]
    explore = random.random() > 0.75
    if explore:
        chosen = random.choice(neighbours)[1]
    return chosen


def random_combination(iterable, sample_size):
    """Return ``sample_size`` distinct elements of ``iterable`` as a tuple.

    Elements are chosen uniformly without replacement and keep their original
    relative order. ``random.sample`` raises ``ValueError`` when
    ``sample_size`` exceeds the number of elements.
    """
    items = tuple(iterable)
    positions = random.sample(range(len(items)), sample_size)
    positions.sort()
    return tuple(items[pos] for pos in positions)


def run_evolution_search(
    max_trained_models=1000,
    pool_size=64,
    tournament_size=10,
    zero_cost_warmup=0,
    zero_cost_move=False,
    idx_to_spec=None,
    spec_to_idx=None,
    proxies=None,
    api=None,
    dataset='cifar10',
    is_adv=False
):
    """Run aging (regularized) evolution once per zero-cost proxy.

    For every ``(name, proxy)`` in ``proxies`` a population is seeded and then
    evolved by tournament selection plus mutation. Each step the newest child
    replaces the oldest individual. This continues until
    ``max_trained_models`` architectures have been evaluated.

    Args:
        max_trained_models: evaluation budget per proxy, including the initial
            population. It is never exceeded.
        pool_size: population size.
        tournament_size: individuals sampled per tournament. It must not
            exceed the population size, otherwise ``random_combination``
            raises ``ValueError``.
        zero_cost_warmup: number of random specs scored with ``proxy``. The
            highest-scoring ones seed the population, and any remaining slots
            are filled with random specs. 0 disables the warm-up.
        zero_cost_move: if True, mutate with ``mutate_spec_zero_cost``
            (proxy-guided) instead of ``mutate_spec``.
        idx_to_spec: mapping whose values are the specs sampled from.
        spec_to_idx: mapping ``str(spec)`` -> architecture index.
        proxies: mapping proxy name -> proxy scores indexable by architecture
            index.
        api: benchmark API providing ``get_more_info`` (used when not
            ``is_adv``).
        dataset: dataset name. ``'cifar10'`` is mapped to ``'cifar10-valid'``
            for the clean benchmark.
        is_adv: if True, score architectures by FGSM (``fgsm@Linf``, epsilon
            8.0) accuracy from ``RobustnessDataset``. The same value is then
            used as both validation and test accuracy.

    Returns:
        ``(best_valids, best_tests)``: dicts mapping each proxy name to the
        running-best validation / test accuracy after each evaluation.
    """
    if not is_adv and dataset == 'cifar10':
        dataset = 'cifar10-valid'

    best_valids = defaultdict(list)
    best_tests = defaultdict(list)

    robust_data = fgsm_acc = fgsm_eps_idx = None
    if is_adv:
        robust_data = RobustnessDataset(path='./data/robust_nasbench201')
        results = robust_data.query(
            data=[dataset],
            measure=['accuracy'],
            key=RobustnessDataset.keys_clean + RobustnessDataset.keys_adv + RobustnessDataset.keys_cc
        )
        # Loop-invariant lookups: resolve them once instead of per evaluation.
        fgsm_acc = results[dataset]['fgsm@Linf']['accuracy']
        fgsm_eps_idx = robust_data.meta["epsilons"]["fgsm@Linf"].index(8.0)

    def _evaluate(net_index):
        """Return ``(val_acc, test_acc)`` for architecture ``net_index``."""
        if is_adv:
            # Scaled by 100, presumably fraction -> percent to match the
            # clean benchmark's accuracies (TODO confirm).
            acc = fgsm_acc[robust_data.get_uid(net_index)][fgsm_eps_idx] * 100
            return acc, acc
        info = api.get_more_info(net_index, dataset, iepoch=None, hp='200', is_random=False)
        return info['valid-accuracy'], info['test-accuracy']

    def _record(curve, acc):
        """Append the running best of ``curve`` after observing ``acc``."""
        curve.append(acc if acc > curve[-1] else curve[-1])

    for name, proxy in proxies.items():
        # 0.0 is a sentinel giving the first comparison something to beat;
        # it is removed before returning.
        valids, tests = [0.0], [0.0]
        best_valids[name], best_tests[name] = valids, tests
        pool = []   # (validation, spec) tuples, oldest first
        num_trained_models = 0

        # Rank warm-up candidates by proxy score, highest first. A single
        # stable sort gives the same order as re-sorting after every append.
        ranked_specs = []
        if zero_cost_warmup > 0:
            scored = []
            for _ in range(zero_cost_warmup):
                spec = random_spec(idx_to_spec)
                scored.append((proxy[spec_to_idx[str(spec)]], spec))
            scored.sort(key=lambda item: item[0], reverse=True)
            ranked_specs = [spec for _, spec in scored]

        # Seed the population without exceeding the evaluation budget.
        for i in range(min(pool_size, max_trained_models)):
            if i < len(ranked_specs):
                spec = ranked_specs[i]
            else:
                # Fewer warm-up candidates than pool slots: fill with random specs.
                spec = random_spec(idx_to_spec)

            val_acc, test_acc = _evaluate(spec_to_idx[str(spec)])
            num_trained_models += 1
            pool.append((val_acc, spec))
            _record(valids, val_acc)
            _record(tests, test_acc)

        # After the pool is seeded, proceed with evolving the population.
        while num_trained_models < max_trained_models:
            sample = random_combination(pool, tournament_size)
            # Stable sort + [-1]: among equal accuracies the last sampled wins.
            best_spec = sorted(sample, key=lambda item: item[0])[-1][1]
            if zero_cost_move:
                new_spec = mutate_spec_zero_cost(best_spec, spec_to_idx, proxy)
            else:
                new_spec = mutate_spec(best_spec)

            val_acc, test_acc = _evaluate(spec_to_idx[str(new_spec)])
            num_trained_models += 1

            # Kill the oldest individual in the population.
            pool.append((val_acc, new_spec))
            pool.pop(0)

            _record(valids, val_acc)
            _record(tests, test_acc)

        tests.pop(0)
        valids.pop(0)

    return best_valids, best_tests


def run_random_search(
    max_trained_models=1000,
    zero_cost_warmup=0,
    idx_to_spec=None,
    spec_to_idx=None,
    proxies=None,
    api=None,
    dataset='cifar10',
    is_adv=False
):
    """Run random search once per zero-cost proxy.

    For every ``(name, proxy)`` in ``proxies``, ``max_trained_models``
    architectures are evaluated. If ``zero_cost_warmup > 0``, that many random
    specs are first scored with ``proxy``. The first evaluations then visit
    them in descending proxy order, and the rest use uniformly random specs.

    Args:
        max_trained_models: number of architectures evaluated per proxy.
        zero_cost_warmup: number of proxy-ranked random specs evaluated first.
            0 disables the warm-up.
        idx_to_spec: mapping whose values are the specs sampled from.
        spec_to_idx: mapping ``str(spec)`` -> architecture index.
        proxies: mapping proxy name -> proxy scores indexable by architecture
            index.
        api: benchmark API providing ``get_more_info`` (used when not
            ``is_adv``).
        dataset: dataset name. ``'cifar10'`` is mapped to ``'cifar10-valid'``
            for the clean benchmark.
        is_adv: if True, score architectures by FGSM (``fgsm@Linf``, epsilon
            8.0) accuracy from ``RobustnessDataset``. The same value is then
            used as both validation and test accuracy.

    Returns:
        ``(best_valids, best_tests)``: dicts mapping each proxy name to the
        running-best validation / test accuracy after each evaluation.
    """
    if not is_adv and dataset == 'cifar10':
        dataset = 'cifar10-valid'

    best_valids = defaultdict(list)
    best_tests = defaultdict(list)

    robust_data = fgsm_acc = fgsm_eps_idx = None
    if is_adv:
        robust_data = RobustnessDataset(path='./data/robust_nasbench201')
        results = robust_data.query(
            data=[dataset],
            measure=['accuracy'],
            key=RobustnessDataset.keys_clean + RobustnessDataset.keys_adv + RobustnessDataset.keys_cc
        )
        # Loop-invariant lookups: resolve them once instead of per evaluation.
        fgsm_acc = results[dataset]['fgsm@Linf']['accuracy']
        fgsm_eps_idx = robust_data.meta["epsilons"]["fgsm@Linf"].index(8.0)

    def _evaluate(net_index):
        """Return ``(val_acc, test_acc)`` for architecture ``net_index``."""
        if is_adv:
            # Scaled by 100, presumably fraction -> percent to match the
            # clean benchmark's accuracies (TODO confirm).
            acc = fgsm_acc[robust_data.get_uid(net_index)][fgsm_eps_idx] * 100
            return acc, acc
        info = api.get_more_info(net_index, dataset, iepoch=None, hp='200', is_random=False)
        return info['valid-accuracy'], info['test-accuracy']

    def _record(curve, acc):
        """Append the running best of ``curve`` after observing ``acc``."""
        curve.append(acc if acc > curve[-1] else curve[-1])

    for name, proxy in proxies.items():
        # 0.0 is a sentinel giving the first comparison something to beat;
        # it is removed before returning.
        valids, tests = [0.0], [0.0]
        best_valids[name], best_tests[name] = valids, tests

        # Rank warm-up candidates by proxy score, highest first. A single
        # stable sort gives the same order as re-sorting after every append.
        ranked_specs = []
        if zero_cost_warmup > 0:
            scored = []
            for _ in range(zero_cost_warmup):
                spec = random_spec(idx_to_spec)
                scored.append((proxy[spec_to_idx[str(spec)]], spec))
            scored.sort(key=lambda item: item[0], reverse=True)
            ranked_specs = [spec for _, spec in scored]

        for i in range(max_trained_models):
            if i < len(ranked_specs):
                spec = ranked_specs[i]
            else:
                spec = random_spec(idx_to_spec)

            val_acc, test_acc = _evaluate(spec_to_idx[str(spec)])
            _record(valids, val_acc)
            _record(tests, test_acc)

        tests.pop(0)
        valids.pop(0)

    return best_valids, best_tests