import json
import numpy as np
import pandas as pd

# Row labels of every score board, in table order: the baseline zero-cost
# proxies (looked up by name in the JSON databases), then 'Ours' (ranking by
# the SR-designed score) and 'Optimal' (the best ground truth available).
COMPETITORS = [
    'params', 'flops', 'snip', 'fisher', 'jacov',
    'synflow', 'grasp', 'plain', 'grad_norm', 'epe_nas',
    'l2_norm', 'zen', 'nwot', 'zico', 'meco_opt',
    'swap',
    'Ours', 'Optimal',
]

# Search-space identifier -> zero-cost proxy database (JSON), relative to CWD.
_DATABASE_PATHS = {
    'nb101': 'database/NB101/zc_nasbench101.json',
    'nb201': 'database/NB201/zc_nasbench201.json',
    'nb301': 'database/NB301/zc_nasbench301.json',
    'tnb101-micro': 'database/TransNAS101/zc_transbench101_micro.json',
    'tnb101-macro': 'database/TransNAS101/zc_transbench101_macro.json',
}


def load_database(ss):
    """Load the zero-cost proxy database for search space *ss*.

    Args:
        ss: One of the keys of ``_DATABASE_PATHS``
            ('nb101', 'nb201', 'nb301', 'tnb101-micro', 'tnb101-macro').

    Returns:
        The parsed JSON content of the matching database file.

    Raises:
        ValueError: If *ss* is not a known search space.
    """
    try:
        path = _DATABASE_PATHS[ss]
    except KeyError:
        raise ValueError(f'Unknown search space: {ss!r}') from None
    # Use a context manager so the handle is closed even if parsing fails;
    # JSON is UTF-8 by specification, so don't depend on the locale encoding.
    with open(path, encoding='utf-8') as f:
        return json.load(f)

def get_best_arch(database, our_database, task, metric):
    """Return the ground-truth accuracy of the architecture ranked best by *metric*.

    Args:
        database: Proxy database as returned by ``load_database``; each value
            of ``database[task]`` must provide ``info[metric]['score']`` and
            ``info['val_accuracy']``. Unused for 'Ours' / 'Optimal'.
        our_database: Table (e.g. the DataFrame read from the SR-designed
            result CSV) with 'SR-designed' and 'GroundTruth' columns. Only
            used when *metric* is 'Ours' or 'Optimal'.
        task: Key into *database*. Unused for 'Ours' / 'Optimal'.
        metric: A proxy name from *database*; 'Ours' to rank by the
            'SR-designed' column; or 'Optimal' for the maximum 'GroundTruth'.

    Returns:
        The selected ground-truth value rounded to 2 decimals. On tied scores
        the first architecture encountered wins (``np.argmax`` semantics).
    """
    if metric in ('Ours', 'Optimal'):
        # Convert to plain arrays: np.argmax yields a *position*, while
        # ``series[i]`` on a CSV-loaded Series is a *label* lookup whenever the
        # index is integer, which picks the wrong row (or raises KeyError).
        ground_truth = np.asarray(our_database['GroundTruth'])
        if metric == 'Optimal':
            return np.round(np.max(ground_truth), 2)
        zc_score = np.asarray(our_database['SR-designed'])
    else:
        records = list(database[task].values())
        zc_score = [info[metric]['score'] for info in records]
        ground_truth = [info['val_accuracy'] for info in records]
    return np.round(ground_truth[np.argmax(zc_score)], 2)

# Task lists shared by the TransNAS-Bench-101 micro and macro tables.
_TNB101_TASKS = ['class_scene', 'class_object', 'jigsaw', 'autoencoder']


def _best_archs_for_task(database, ss, task):
    """Return one rounded accuracy per entry of COMPETITORS for a single task.

    'Ours' and 'Optimal' both come from ``result/{ss}-{task}_SR-designed.csv``;
    that file is read at most once per task rather than once per metric.
    """
    our_database = None
    if any(metric in ('Ours', 'Optimal') for metric in COMPETITORS):
        our_database = pd.read_csv(f'result/{ss}-{task}_SR-designed.csv', index_col=0)
    return [
        get_best_arch(database, our_database=our_database, task=task, metric=metric)
        for metric in COMPETITORS
    ]


def _build_score_board(benchmarks):
    """Build the metric-by-task score board for the given search spaces.

    Args:
        benchmarks: Iterable of ``(ss, tasks)`` pairs. Each ``ss`` is passed to
            ``load_database``; each task becomes a ``'{ss}-{task}'`` column.

    Returns:
        DataFrame with one row per entry of COMPETITORS (with 'meco_opt'
        labelled 'meco') and one column per task, in the order given.
    """
    columns, rows = [], []
    for ss, tasks in benchmarks:
        database = load_database(ss=ss)
        for task in tasks:
            columns.append(f'{ss}-{task}')
            rows.append(_best_archs_for_task(database, ss, task))
    # Relabel by value rather than by position so that editing COMPETITORS
    # cannot silently rename the wrong row.
    index = ['meco' if name == 'meco_opt' else name for name in COMPETITORS]
    # rows is (task x metric); the table is laid out metric x task.
    return pd.DataFrame(np.array(rows).T, index=index, columns=columns)


def main_1():
    """Write Appendix F Table 12 (NAS-Bench-101/201/301) to ``result/``."""
    score_board = _build_score_board([
        ('nb101', ['cifar10']),
        ('nb201', ['cifar10', 'cifar100', 'ImageNet16-120']),
        ('nb301', ['cifar10']),
    ])
    score_board.to_csv('result/appendixF_table12.csv', encoding='utf-8')


def main_2():
    """Write Appendix F Table 13 (TransNAS-Bench-101 micro) to ``result/``."""
    score_board = _build_score_board([('tnb101-micro', _TNB101_TASKS)])
    score_board.to_csv('result/appendixF_table13.csv', encoding='utf-8')


def main_3():
    """Write Appendix F Table 14 (TransNAS-Bench-101 macro) to ``result/``."""
    score_board = _build_score_board([('tnb101-macro', _TNB101_TASKS)])
    score_board.to_csv('result/appendixF_table14.csv', encoding='utf-8')

if __name__ == '__main__':
    # Regenerate Appendix F Tables 12, 13 and 14, in that order.
    for entry_point in (main_1, main_2, main_3):
        entry_point()