import sys
import os
from os import path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
from collections import defaultdict
import torch
import argparse
from tqdm import tqdm
import pandas as pd
import scipy.stats as stats
import pickle
import numpy as np

from nasbench_space.robust_nasbench201 import RobustnessDataset
from zero_cost_methods.pruners import *
from zero_cost_methods.dataset import *
from nasbench_space.models import *
from zero_cost_methods.weight_initializers import init_net
from darts_space.genotypes import Genotype
from darts_space.model import NetworkCIFAR as Network


def get_num_classes(args):
    '''
    Return the number of classes for ``args.dataset``.

    'cifar100' maps to 100 and 'cifar10' maps to 10; every other dataset
    name falls through to 120.
    '''
    dataset = args.dataset
    if dataset == 'cifar100':
        return 100
    if dataset == 'cifar10':
        return 10
    return 120


def get_hrs(clean_acc, robust_acc):
    '''
    Harmonic mean of a clean accuracy C and a robust accuracy R.

    HRS = 2CR / (C+R)

    When C + R is zero the formula is undefined (division by zero); in that
    case 0.0 is returned instead of raising ZeroDivisionError.
    '''
    total = clean_acc + robust_acc
    if total == 0:
        # Both accuracies are zero (or cancel out): define the score as 0.
        return 0.0
    return (2 * clean_acc * robust_acc) / total


def _rounded_spearman(xs, ys):
    '''Spearman rank correlation coefficient of ``xs`` and ``ys``, rounded to 3 decimals.'''
    return round(stats.spearmanr(xs, ys)[0], 3)


def get_corr_zero_cost_nasbench(args, proxies, file_pickle):
    '''
    Correlate zero-cost proxy scores with clean and robust accuracies.

    Reads ``args.end`` consecutive pickled records from ``file_pickle`` and
    looks up, for each record, the clean / PGD / FGSM (and, except for
    'ImageNet16-120', common-corruption) accuracies in the RobustnessDataset
    stored under './data/robust_nasbench201'.  For every proxy named in
    ``args.ZERO_COST_PROXY_LIST`` the Spearman correlation with each accuracy
    list is computed, plus a "top-k" variant that correlates the index lists
    returned by ``get_topk_idx`` for the accuracy and for the proxy.

    Args:
        args: namespace providing ``dataset``, ``end``, ``corr_dir``,
            ``ZERO_COST_PROXY_LIST`` and at least one of ``topk`` / ``topp``.
            ``topp`` is read as a percentage of ``args.end`` and takes
            precedence over ``topk`` when both are set.
        proxies: mapping proxy name -> list of scores, or None to collect the
            scores from each record's 'logmeasures' entry.
        file_pickle: path of the file holding the pickled records.  Each
            record is indexed with the keys 'i', 'valacc' and 'testacc'
            (and 'logmeasures' when ``proxies`` is None).

    Returns:
        pandas.DataFrame with a 'name' column and one column per proxy.

    Side effects:
        Creates ``./{args.corr_dir}/{args.end}``, writes the per-corruption
        breakdown to 'corr.txt' and the table to
        '<proxies joined by _>_results.csv' in that directory, and prints one
        progress line per proxy.
    '''
    assert args.topk is not None or args.topp is not None
    if args.topp is not None:
        _k = int(args.topp*args.end / 100)
    else:
        _k = args.topk

    robust_data = RobustnessDataset(path='./data/robust_nasbench201')
    results = robust_data.query(
        data = [args.dataset],
        measure = ['accuracy'],
        key = RobustnessDataset.keys_clean + RobustnessDataset.keys_adv + RobustnessDataset.keys_cc
    )
    dataset_results = results[args.dataset]
    has_cc = args.dataset != 'ImageNet16-120'

    # Positions of the evaluated epsilons do not depend on the architecture,
    # so resolve them once instead of once per record.
    pgd_eps = robust_data.meta["epsilons"]["pgd@Linf"]
    fgsm_eps = robust_data.meta["epsilons"]["fgsm@Linf"]
    pgd_1_idx = pgd_eps.index(1.0)
    fgsm_8_idx = fgsm_eps.index(8.0)
    fgsm_1_idx = fgsm_eps.index(1.0)
    fgsm_2_idx = fgsm_eps.index(2.0)
    fgsm_4_idx = fgsm_eps.index(4.0)

    test_accs, test_nas_accs, val_nas_accs = [], [], []
    robust_pgd_1, robust_fgsm_1, robust_fgsm_2, robust_fgsm_4, robust_fgsm_8 = [], [], [], [], []
    hrs_pgd_1, hrs_fgsm_8 = [], []
    # One dict per corruption level; each maps corruption type -> accuracies.
    robust_cc = [defaultdict(list) for _ in range(5)]

    collect_proxies = proxies is None
    if collect_proxies:
        proxies_r = defaultdict(list)

    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted result files.
    with open(file_pickle, "rb") as records:
        for _ in range(args.end):
            data = pickle.load(records)
            uid = robust_data.get_uid(data['i'])

            if collect_proxies:
                for proxy, val in data['logmeasures'].items():
                    proxies_r[proxy].append(val)

            clean_acc = dataset_results["clean"]["accuracy"][uid]
            pgd_accs = dataset_results['pgd@Linf']['accuracy'][uid]
            fgsm_accs = dataset_results['fgsm@Linf']['accuracy'][uid]

            test_accs.append(clean_acc)
            val_nas_accs.append(data['valacc'])
            test_nas_accs.append(data['testacc'])

            robust_pgd_1.append(pgd_accs[pgd_1_idx])
            robust_fgsm_8.append(fgsm_accs[fgsm_8_idx])
            robust_fgsm_1.append(fgsm_accs[fgsm_1_idx])
            robust_fgsm_2.append(fgsm_accs[fgsm_2_idx])
            robust_fgsm_4.append(fgsm_accs[fgsm_4_idx])

            hrs_pgd_1.append(get_hrs(clean_acc, pgd_accs[pgd_1_idx]))
            hrs_fgsm_8.append(get_hrs(clean_acc, fgsm_accs[fgsm_8_idx]))

            if has_cc:
                for cc_type in RobustnessDataset.keys_cc:
                    cc_accs = dataset_results[cc_type]['accuracy'][uid]
                    for level, acc in enumerate(cc_accs):
                        robust_cc[level][cc_type].append(acc)

    if collect_proxies:
        proxies = proxies_r

    # (row label, values, whether a top-k row follows), in output row order.
    targets = [
        ('test_acc', test_accs, True),
        ('test_nas_acc', test_nas_accs, True),
        ('val_nas_acc', val_nas_accs, True),
        ('pgd1', robust_pgd_1, True),
        ('fgsm8', robust_fgsm_8, True),
        ('fgsm1', robust_fgsm_1, True),
        ('fgsm2', robust_fgsm_2, True),
        ('fgsm4', robust_fgsm_4, True),
        ('hrs_pgd1', hrs_pgd_1, False),
        ('hrs_fgsm8', hrs_fgsm_8, False),
    ]

    corr_dict = defaultdict(list)
    for label, _, with_topk in targets:
        corr_dict['name'].append(label)
        if with_topk:
            corr_dict['name'].append(f'{label} top_{_k}')
    if has_cc:
        corr_dict['name'].append('cc_total')

    save_dir = f'./{args.corr_dir}/{args.end}'
    os.makedirs(save_dir, exist_ok=True)

    # The context manager guarantees corr.txt is flushed and closed, even if
    # a correlation computation raises.
    with open(os.path.join(save_dir, 'corr.txt'), 'w') as f:
        for zero_cost_proxy in args.ZERO_COST_PROXY_LIST:
            scores = proxies[zero_cost_proxy]
            print(zero_cost_proxy, len(test_accs), len(scores))
            topk_proxy = get_topk_idx(scores, k=_k)
            column = corr_dict[zero_cost_proxy]

            for _, values, with_topk in targets:
                column.append(_rounded_spearman(values, scores))
                if with_topk:
                    column.append(_rounded_spearman(get_topk_idx(values, k=_k), topk_proxy))

            if has_cc:
                # Average over the 5 levels per corruption type, then over
                # all corruption types.
                corr_cc = 0
                for cc_type in RobustnessDataset.keys_cc:
                    f.write(f'CC [{cc_type}]\n')
                    corr_total = 0
                    corr_str = ''
                    for level in range(5):
                        corr = _rounded_spearman(robust_cc[level][cc_type], scores)
                        corr_total += corr
                        corr_str += f'\tLevel {level+1}: {corr} |'
                    f.write(f'\t\tTotal: {corr_total/5:.3f} |{corr_str}\n')
                    corr_cc += corr_total/5
                corr_cc = round(corr_cc/len(RobustnessDataset.keys_cc), 3)
                column.append(corr_cc)
                f.write(f'CC accs: {corr_cc}\n')

    corr_df = pd.DataFrame(corr_dict, columns=list(corr_dict.keys()))
    file_name = '_'.join(args.ZERO_COST_PROXY_LIST)+'_results.csv'
    corr_df.to_csv(os.path.join(save_dir, file_name))

    return corr_df


def get_corr_zero_cost_nasbench101(args, proxies, file_pickle):
    '''
    Correlate zero-cost proxy scores with the 'valacc' value of each record.

    Reads ``args.end`` consecutive pickled records from ``file_pickle`` and,
    for every proxy in ``args.ZERO_COST_PROXY_LIST``, computes the Spearman
    correlation with the collected 'valacc' values plus a "top-k" variant
    that correlates the index lists returned by ``get_topk_idx``.

    Args:
        args: namespace providing ``end``, ``corr_dir``,
            ``ZERO_COST_PROXY_LIST`` and at least one of ``topk`` / ``topp``.
            ``topp`` is read as a percentage of ``args.end`` and takes
            precedence over ``topk`` when both are set.
        proxies: mapping proxy name -> list of scores, or None to collect the
            scores from each record's 'logmeasures' entry.
        file_pickle: path of the file holding the pickled records.

    Returns:
        pandas.DataFrame with a 'name' column and one column per proxy.

    Side effects:
        Creates ``./{args.corr_dir}/{args.end}``, creates (or truncates) an
        empty 'corr.txt' there, writes '<proxies joined by _>_results.csv',
        and prints one progress line per proxy.
    '''
    assert args.topk is not None or args.topp is not None
    if args.topp is not None:
        _k = int(args.topp*args.end / 100)
    else:
        _k = args.topk

    val_accs = []
    collect_proxies = proxies is None
    if collect_proxies:
        proxies_r = defaultdict(list)

    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted result files.
    with open(file_pickle, "rb") as records:
        for _ in range(args.end):
            data = pickle.load(records)

            if collect_proxies:
                for proxy, val in data['logmeasures'].items():
                    proxies_r[proxy].append(val)

            val_accs.append(data['valacc'])

    if collect_proxies:
        proxies = proxies_r

    corr_dict = defaultdict(list)
    save_dir = f'./{args.corr_dir}/{args.end}'
    os.makedirs(save_dir, exist_ok=True)
    # Nothing is written to corr.txt in this function; it is only created
    # (or truncated).  The context manager makes sure the handle is closed.
    with open(os.path.join(save_dir, 'corr.txt'), 'w'):
        pass

    # NOTE(review): the rows are labelled 'test_acc' although the values
    # correlated below come from each record's 'valacc' entry; the labels are
    # left unchanged so existing consumers of the CSV keep working.
    corr_dict['name'].append('test_acc')
    corr_dict['name'].append(f'test_acc top_{_k}')

    for zero_cost_proxy in args.ZERO_COST_PROXY_LIST:
        scores = proxies[zero_cost_proxy]
        print(zero_cost_proxy, len(val_accs), len(scores))
        topk_proxy = get_topk_idx(scores, k=_k)
        corr_dict[zero_cost_proxy].append(round(stats.spearmanr(val_accs, scores)[0], 3))
        corr_dict[zero_cost_proxy].append(round(stats.spearmanr(get_topk_idx(val_accs, k=_k), topk_proxy)[0], 3))

    corr_df = pd.DataFrame(corr_dict, columns=list(corr_dict.keys()))
    file_name = '_'.join(args.ZERO_COST_PROXY_LIST)+'_results.csv'
    corr_df.to_csv(os.path.join(save_dir, file_name))

    return corr_df


def get_topk_idx(proxy_list, k=50):
    '''
    Return the indices of the ``k`` largest entries of ``proxy_list``.

    The values are converted to a tensor and passed to ``torch.topk``; only
    the index tensor of its result is returned.
    '''
    scores = torch.tensor(proxy_list)
    top = torch.topk(scores, k=k)
    return top.indices
        

if __name__ == '__main__':
    # Command-line entry point: parse the options, resolve the list of proxy
    # names and run the correlation analysis for the chosen search space.
    parser = argparse.ArgumentParser("cifar")
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to work on')
    parser.add_argument('--dataset', type=str, default='cifar10', help='name of dataset')
    parser.add_argument('--search_space', type=str, default='nasbench201', help='search space (nasbench101/nasbench201)')
    parser.add_argument('--api_loc', default='', type=str, help='path to API')
    parser.add_argument('--corr_dir', type=str, default='./corr')
    parser.add_argument('--end', type=int, default=0, help='end index')
    parser.add_argument('--proxy_types', default='croze', type=str)
    parser.add_argument('--proxy_fpath', type=str, default=None)
    parser.add_argument('--topk', type=int, default=None)
    parser.add_argument('--topp', type=float, default=None)

    args = parser.parse_args()

    # Fail fast with a usage message instead of a late TypeError from
    # open(None) or a bare AssertionError inside the analysis functions.
    if args.proxy_fpath is None:
        parser.error('--proxy_fpath is required')
    if args.topk is None and args.topp is None:
        parser.error('one of --topk or --topp is required')

    # 'baselines' is a shortcut for a fixed set of proxies; anything else is
    # read as a comma-separated list of proxy names.
    if args.proxy_types == 'baselines':
        args.ZERO_COST_PROXY_LIST = ['grad_norm', 'snip', 'grasp', 'fisher', 'synflow']
    else:
        args.ZERO_COST_PROXY_LIST = args.proxy_types.split(',')

    if args.search_space == 'nasbench201':
        # --end 0 is replaced by 15625 records for this search space.
        if args.end == 0:
            args.end = 15625
        get_corr_zero_cost_nasbench(args, None, args.proxy_fpath)
    else:
        get_corr_zero_cost_nasbench101(args, None, args.proxy_fpath)