import torch
import argparse
import pickle
import os
from tqdm import tqdm
from nas_201_api import NASBench201API as API
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
from nasbench_space.models import *
from search import *
import pandas as pd


def plot_experiment_all(exp_list, title, outdir, proxies, is_adv=False):
    """Plot the mean curve of every experiment on one figure and save it.

    For each proxy name in `proxies` and each `(results, label)` pair in
    `exp_list`, `results[name]` is converted to an array and reduced along
    axis 0: the mean is drawn as a line and the 25%-75% quantile range as a
    translucent band. The figure is saved into `outdir` as
    `adv_<title>.png` or `clean_<title>.png` and then cleared.

    Args:
        exp_list: Sequence of `(results, label)` pairs; `results` is indexed
            by proxy name and `label` becomes the legend entry.
        title: Figure title, also embedded in the output file name.
        outdir: Directory the PNG is written to.
        proxies: Mapping whose keys select which entries of each `results`
            are plotted.
        is_adv: Chooses the y-axis label, the y-limits and the file prefix.
    """
    # Global matplotlib text sizes; plt.rc changes stay in effect after
    # this function returns.
    rc_settings = (
        ('font', {'size': 20}),
        ('axes', {'labelsize': 20}),
        ('xtick', {'labelsize': 15}),
        ('ytick', {'labelsize': 15}),
        ('legend', {'fontsize': 15}),
        ('figure', {'titlesize': 30}),
    )
    for group, options in rc_settings:
        plt.rc(group, **options)

    def draw_band(runs, label):
        """Draw the axis-0 mean of `runs` plus its inter-quartile band."""
        stacked = np.array(runs)
        upper = np.quantile(stacked, .75, axis=0)
        lower = np.quantile(stacked, .25, axis=0)
        plt.plot(np.mean(stacked, axis=0), label=label)
        plt.fill_between(range(len(lower)), lower, upper, alpha=0.1)

    for proxy_name in proxies:
        for results, exp_label in exp_list:
            draw_band(results[proxy_name], f'{exp_label}')

    # Everything that depends on the clean/adversarial switch, decided once.
    if is_adv:
        y_label = 'Test Accuracy (FGSM-8)'
        y_limits = (30, 51)
        file_name = f'adv_{title}.png'
    else:
        y_label = 'Test Accuracy (Clean)'
        y_limits = (83, 92)
        file_name = f'clean_{title}.png'

    plt.grid()
    plt.xlabel('Trained Models')
    plt.ylabel(y_label)
    plt.legend()
    plt.title(title)
    plt.ylim(*y_limits)
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, file_name))
    # Clear the current figure so the next call starts from an empty canvas.
    plt.clf()
    

def read_proxies(path, proxy, proxy_dict):
    """Collect one proxy measurement from every pickled record in a file.

    The file at `path` is read as a stream of consecutive pickles. Each
    record is indexed as `record['logmeasures'][proxy]` and the value is
    appended to `proxy_dict[proxy]`, in file order.

    Args:
        path: Path of the pickle stream to read.
        proxy: Name of the measurement to extract from each record.
        proxy_dict: Mapping from proxy name to a list of values; updated in
            place. A missing `proxy` key is created on the first record.

    Returns:
        The same `proxy_dict` object that was passed in.

    Raises:
        OSError: If `path` cannot be opened.
        KeyError: If a record lacks `'logmeasures'` or the requested proxy.
    """
    # NOTE(security): unpickling can execute arbitrary code; only point this
    # at measurement files from a trusted source.
    with open(path, 'rb') as stream:
        while True:
            # Only the load itself signals end-of-stream, so keep the try
            # body limited to it.
            try:
                record = pickle.load(stream)
            except EOFError:
                break
            proxy_dict.setdefault(proxy, []).append(record['logmeasures'][proxy])
    return proxy_dict

    
def main(args):
    """Run every search variant for `args.num_rounds` rounds and plot them.

    Loads the benchmark API and the proxy measurements, builds the
    index <-> spec lookup tables, then in each round runs four evolution
    search variants and two random search variants. The second value
    returned by each search is indexed by proxy name and accumulated per
    variant; two summary figures are written to `args.outdir`.

    Args:
        args: Namespace providing `api_loc`, `proxy_fpath`, `proxy_types`,
            `num_rounds`, `length`, `dataset`, `is_adv` and `outdir`.
            `proxy_fpath` may be a single path (string or one-element list),
            which is then used for every proxy type, or one path per entry
            of `proxy_types`.

    Raises:
        ValueError: If several proxy files are given and their count does
            not match the number of proxy types.
    """
    api = API(args.api_loc)

    # argparse (nargs="+") yields a list; a bare string is accepted too.
    proxy_paths = args.proxy_fpath
    if isinstance(proxy_paths, str):
        proxy_paths = [proxy_paths]
    else:
        proxy_paths = list(proxy_paths)
    if len(proxy_paths) == 1:
        # One measurement file shared by every requested proxy type.
        proxy_paths = proxy_paths * len(args.proxy_types)
    elif len(proxy_paths) != len(args.proxy_types):
        raise ValueError(
            f'expected 1 or {len(args.proxy_types)} proxy files, '
            f'got {len(proxy_paths)}'
        )

    proxies = defaultdict(list)
    for path, proxy in zip(proxy_paths, args.proxy_types):
        read_proxies(path, proxy, proxies)

    idx_to_spec = {
        i: nasbench2.get_spec_from_arch_str(arch_str)
        for i, arch_str in enumerate(api)
    }
    spec_to_idx = {str(spec): idx for idx, spec in idx_to_spec.items()}

    # Keyword arguments shared by every search call.
    shared = dict(
        max_trained_models=args.length,
        idx_to_spec=idx_to_spec,
        spec_to_idx=spec_to_idx,
        proxies=proxies,
        api=api,
        dataset=args.dataset,
        is_adv=args.is_adv,
    )

    ae, ae_warmup, ae_move, ae_warmup_move = (defaultdict(list) for _ in range(4))
    rand, rand_warmup = defaultdict(list), defaultdict(list)

    # (accumulator, search function, variant-specific keyword arguments);
    # the variants are executed in this order within every round.
    variants = [
        (ae, run_evolution_search, {}),
        (ae_warmup, run_evolution_search, {'zero_cost_warmup': 3000}),
        (ae_move, run_evolution_search, {'zero_cost_move': True}),
        (ae_warmup_move, run_evolution_search,
         {'zero_cost_move': True, 'zero_cost_warmup': 3000}),
        (rand, run_random_search, {}),
        (rand_warmup, run_random_search, {'zero_cost_warmup': 3000}),
    ]

    for _ in tqdm(range(args.num_rounds)):
        for results, search, extra in variants:
            # Only the second return value is plotted; the first
            # (presumably validation curves - confirm in `search`) is unused.
            _, best_tests = search(**shared, **extra)
            for name in proxies:
                results[name].append(best_tests[name])

    plot_experiment_all(
        [
            (ae, 'AE'),
            (ae_warmup, 'AE + warmup (3000)'),
            (ae_move, 'AE + move'),
            (ae_warmup_move, 'AE + warmup (3000) + move'),
        ],
        'Aging Evolution Search', args.outdir, proxies, is_adv=args.is_adv,
    )
    plot_experiment_all(
        [(rand, 'RAND'), (rand_warmup, 'RAND + warmup (3000)')],
        'Random Search', args.outdir, proxies, is_adv=args.is_adv,
    )


if __name__ == "__main__":
    # Script entry point: parse the command line, pick a torch device, make
    # sure the output directory exists, then run the experiments.
    parser = argparse.ArgumentParser("warmup")
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to work on')
    parser.add_argument('--seed', type=int, default=42, help='pytorch manual seed')
    # `help` must be a string: argparse applies %-formatting to it when
    # rendering usage, which fails for a list.
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset name (cifar10 / cifar100 / ImageNet16-120)')
    parser.add_argument('--search_space', type=str, default='darts', help='search space (darts/nasbench201)')
    parser.add_argument('--api_loc', default='',
                        type=str, help='path to API')
    parser.add_argument('--proxy_types', default=['grad_norm'], type=str, nargs="+")
    parser.add_argument('--proxy_fpath', default=[''], type=str, nargs="+")
    parser.add_argument('--num_rounds', type=int, default=10)
    parser.add_argument('--length', type=int, default=300)
    parser.add_argument('--outdir', type=str, default='./fig')
    parser.add_argument('--is_adv', action='store_true')

    args = parser.parse_args()
    # NOTE(review): `main` in this file does not read `args.seed`,
    # `args.search_space` or `args.device` - confirm whether seeding and
    # device selection are meant to be applied somewhere.
    args.device = torch.device("cuda:"+str(args.gpu) if torch.cuda.is_available() else "cpu")

    os.makedirs(args.outdir, exist_ok=True)
    main(args)