import argparse
import json
import os

import numpy as np
from scipy import stats

from utils import create_data, load_search_spaces

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy arrays and NumPy scalars.

    * ``np.ndarray`` values become (nested) Python lists via ``tolist()``.
    * NumPy scalar types (``np.int64``, ``np.float32``, ``np.bool_``, ...)
      become the equivalent Python scalar via ``item()``; the stock encoder
      rejects most of them because they do not subclass ``int``/``float``.

    Anything else is delegated to the base encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # e.g. np.int64 / np.float32 / np.bool_ -> int / float / bool
            return obj.item()
        return super().default(obj)

# Datasets evaluated for each search space (keys index the loaded database),
# in the order the problems are processed and written out.
_DATASETS_PER_SEARCH_SPACE = {
    'nb101': ['cifar10'],
    'nb201': ['cifar10', 'cifar100', 'ImageNet16-120'],
    'nb301': ['cifar10'],
    'transnb101_micro': ['class_scene', 'class_object', 'jigsaw', 'autoencoder'],
    'transnb101_macro': ['class_scene', 'class_object', 'jigsaw', 'autoencoder'],
}

# Name used in the problem key when it differs from the database key.
_PROBLEM_PREFIX = {
    'transnb101_micro': 'tnb101-micro',
    'transnb101_macro': 'tnb101-macro',
}


def main(kwargs):
    """Compute per-metric rank correlations for every search space/dataset.

    For each (search space, dataset) problem, correlates each metric column of
    ``X`` with ``y`` (both from ``create_data``), prints the values, and writes
    all results to ``result/<measure>-correlation.json``.

    Args:
        kwargs: object with a ``measure`` attribute (the parsed command-line
            arguments); ``measure`` must be ``'kendall'`` or ``'spearman'``.

    Raises:
        NotImplementedError: if ``measure`` is not a supported correlation.
            Checked before any data is loaded.
    """
    measure = kwargs.measure
    correlations = {'kendall': stats.kendalltau, 'spearman': stats.spearmanr}
    if measure not in correlations:
        # Fail fast instead of after loading the whole database.
        raise NotImplementedError(f'unsupported correlation measure: {measure!r}')
    correlate = correlations[measure]
    filename = f'{measure}-correlation'

    METRICS = ['epe_nas', 'fisher', 'flops', 'grad_norm', 'grasp', 'jacov', 'l2_norm', 'nwot', 'params', 'plain', 'snip',
               'synflow', 'zen', 'swap', 'meco_opt', 'zico', 'val_accuracy']
    all_ss = load_search_spaces('database')

    all_corr = {}

    for ss, list_dataset in _DATASETS_PER_SEARCH_SPACE.items():
        ss_ = _PROBLEM_PREFIX.get(ss, ss)
        for dataset in list_dataset:
            X, y = create_data(all_ss[ss], dataset)
            problem = f'{ss_}-{dataset}'
            all_corr[problem] = {}
            print('Problem:', problem)
            # 'val_accuracy' (last entry) is skipped; the remaining metric names
            # are assumed to match X's columns in order -- confirm in create_data.
            for i, metric in enumerate(METRICS[:-1]):
                corr = correlate(y, X[:, i])[0]
                print(f'{metric}: {corr:.2f}')
                all_corr[problem][metric] = corr
            print('-'*40)

    # open() does not create missing parent directories.
    os.makedirs('result', exist_ok=True)
    with open(f'result/{filename}.json', 'w') as fp:
        json.dump(all_corr, fp, indent=4, cls=NumpyEncoder)
    print('- Results are saved on result/')


if __name__ == '__main__':
    # Command-line entry point: choose the rank correlation measure, then run.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--measure',
        type=str,
        default='kendall',
        help='the rank correlation measure',
        choices=['kendall', 'spearman'],
    )
    main(cli.parse_args())
