import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats

# Embed TrueType (Type 42) fonts in PDF/PS output rather than Type 3 fonts.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})
# Global font sizes for legends, tick labels and axis labels.
plt.rcParams.update({
    'legend.fontsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'axes.labelsize': 14,
})

# Names of zero-cost proxy metrics plus the 'val_accuracy' target.
# NOTE(review): not referenced in the code visible here; presumably consumed elsewhere.
METRICS = [
    'epe_nas', 'fisher', 'flops', 'grad_norm', 'grasp', 'jacov',
    'l2_norm', 'nwot', 'params', 'plain', 'snip', 'synflow',
    'zen', 'swap', 'meco_opt', 'zico', 'val_accuracy',
]

# Display titles keyed by the '<search-space>-<dataset>' problem id built in plot().
_PROBLEM_TITLES = {
    'nb201-cifar10': 'NB201-CF10',
    'nb201-cifar100': 'NB201-CF100',
    'nb201-ImageNet16-120': 'NB201-IMGNT',
    'nb101-cifar10': 'NB101-CF10',
    'nb301-cifar10': 'NB301-CF10',
    'tnb101-micro-class_object': 'TNB101-Micro-Object',
    'tnb101-micro-class_scene': 'TNB101-Micro-Scene',
    'tnb101-micro-autoencoder': 'TNB101-Micro-AutoEnc',
    'tnb101-micro-jigsaw': 'TNB101-Micro-Jigsaw',
    'tnb101-macro-class_object': 'TNB101-Macro-Object',
    'tnb101-macro-class_scene': 'TNB101-Macro-Scene',
    'tnb101-macro-autoencoder': 'TNB101-Macro-AutoEnc',
    'tnb101-macro-jigsaw': 'TNB101-Macro-Jigsaw',
}


def convert_problem_2_title(problem):
    """Return a short display title for a benchmark problem id.

    Args:
        problem: Problem id of the form '<search-space>-<dataset>',
            e.g. 'nb201-cifar10'.

    Returns:
        The mapped title, e.g. 'NB201-CF10'. Unknown ids are returned
        unchanged, so an unmapped problem never renders as 'None' in a
        plot title.
    """
    return _PROBLEM_TITLES.get(problem, problem)


def plot(result_dir='result', fig_dir='fig'):
    """Scatter the SR-designed ZC score against ground truth for every benchmark problem.

    For each search space / dataset pair this reads
    ``{result_dir}/{problem}_SR-designed.csv``. The file's first column is used
    as the index, and it must provide the columns ``'SR-designed'`` and
    ``'GroundTruth'``. The function then:

    - computes Kendall's tau between the two columns;
    - highlights in red the point with the highest ZC score;
    - prints that point's ground-truth value;
    - saves the figure to ``{fig_dir}/{problem}.pdf``.

    Args:
        result_dir: Directory containing the per-problem CSV files.
        fig_dir: Output directory for the PDFs; created if it does not exist.
    """
    # Datasets evaluated in each search space, in plotting order.
    datasets_per_space = {
        'nb101': ['cifar10'],
        'nb201': ['cifar10', 'cifar100', 'ImageNet16-120'],
        'nb301': ['cifar10'],
        'transnb101_micro': ['class_scene', 'class_object', 'jigsaw', 'autoencoder'],
        'transnb101_macro': ['class_scene', 'class_object', 'jigsaw', 'autoencoder'],
    }
    # Result-file prefixes that differ from the search-space key.
    file_prefix = {
        'transnb101_micro': 'tnb101-micro',
        'transnb101_macro': 'tnb101-macro',
    }

    os.makedirs(fig_dir, exist_ok=True)

    for ss, list_dataset in datasets_per_space.items():
        ss_ = file_prefix.get(ss, ss)
        for dataset in list_dataset:
            problem = f'{ss_}-{dataset}'

            df = pd.read_csv(os.path.join(result_dir, f'{problem}_SR-designed.csv'), index_col=0)
            zc_scores = df['SR-designed']
            ground_truth = df['GroundTruth']

            correlation = stats.kendalltau(zc_scores, ground_truth)[0]
            # Series.argmax returns a *position*, so read the values with .iloc.
            # The index comes from the CSV's first column, so labels need not be 0..n-1.
            best_pos = zc_scores.argmax()
            best_zc = zc_scores.iloc[best_pos]
            best_gt = ground_truth.iloc[best_pos]
            print(f'{problem}: ground truth of best-ZC-score point = {best_gt}')

            # A fresh figure per problem, closed after saving, so no state leaks between plots.
            fig, ax = plt.subplots()
            ax.scatter(zc_scores, ground_truth, edgecolors='k', facecolor='royalblue')
            ax.scatter(best_zc, best_gt, edgecolors='k', facecolor='tab:red',
                       label='Best ZC Score')
            ax.set_xlabel('Our ZC metric Score')
            ax.set_ylabel('Ground-truth Score')
            ax.set_title(rf"{convert_problem_2_title(problem)} (Kendall's $\tau$: {correlation:2.4f})",
                         fontsize=14)
            ax.legend(loc=4)  # lower right
            fig.savefig(os.path.join(fig_dir, f'{problem}.pdf'), bbox_inches='tight')
            plt.close(fig)

if __name__ == '__main__':
    # Script entry point: render and save the scatter plot for every benchmark problem.
    plot()
