import pickle as p
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
import seaborn
seaborn.set_theme(style='white')

# Global font sizes for every figure drawn by this script. These are applied
# after set_theme so that the theme call does not overwrite them.
plt.rcParams.update({
    'legend.fontsize': 20,
    'xtick.labelsize': 20,
    'ytick.labelsize': 28,
    'axes.labelsize': 13,
})

# Column names of the GP feature matrix: an `X<i>` token in a GP program is
# rewritten to METRICS[i] when counting. The last entry is excluded from the
# proxy counts (presumably it is the regression target — confirm).
METRICS = [
    'epe_nas', 'fisher', 'flops', 'grad_norm', 'grasp', 'jacov', 'l2_norm',
    'nwot', 'params', 'plain', 'snip', 'synflow', 'zen', 'swap', 'meco_opt',
    'zico',
    'val_accuracy',
]

def _count_proxy_usage(run_ids, exp_dir):
    """Count, for each proxy in METRICS[:-1], how many runs' GP programs use it.

    Each run's model is unpickled from ``{exp_dir}/GP-Model_multiple_run{rid}.p``
    and its ``our_program['program']`` is converted to a string in which
    inputs appear as ``X<i>`` tokens, with ``i`` indexing ``METRICS``.

    Returns a dict mapping every name in ``METRICS[:-1]`` to the number of
    runs whose program mentions it (at most once per run).
    """
    counts = {metric: 0 for metric in METRICS[:-1]}
    for rid in run_ids:
        # NOTE(review): unpickling can execute arbitrary code; only load
        # model files produced by this project.
        with open(f'{exp_dir}/GP-Model_multiple_run{rid}.p', 'rb') as fh:
            gp_model = p.load(fh)
        program = str(gp_model.our_program['program'])

        # Rename X<i> -> METRICS[i], highest index first, so that 'X1' is never
        # substituted inside 'X10'..'X16'. Every index (including the last one)
        # is rewritten so no half-replaced residue such as 'fisher6' remains.
        for i in range(len(METRICS) - 1, -1, -1):
            program = program.replace(f'X{i}', METRICS[i])

        # Only proxies are counted. A reference to the last column is ignored
        # rather than raising a KeyError as before.
        for metric in counts:
            if metric in program:
                counts[metric] += 1
    return counts


def _plot_frequencies(labels, freq, out_path):
    """Draw a bar chart of ``freq`` against ``labels`` and save it to ``out_path``."""
    positions = np.arange(len(labels))
    fig, ax = plt.subplots(figsize=(14, 10))
    ax.bar(positions, freq)
    # Pin the ticks to the bar positions before labelling them. Relabelling
    # the auto-generated ticks via get_xticklabels() depends on whether they
    # have been populated yet.
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, fontsize=30, va='top', ha='right',
                       rotation_mode='anchor')
    # Print each count just above its bar.
    for x, count in zip(positions, freq):
        ax.text(x, count + 0.05, count, weight='bold', ha='center', fontsize=30)
    ax.set_ylabel('Frequency', fontsize=34)

    fig.tight_layout()
    fig.savefig(out_path, bbox_inches='tight')


def get_freq_each_proxy(run_ids=range(1, 32), exp_dir='exp', out_path='fig/freq.pdf'):
    """Plot how often each zero-cost proxy is used by the GP programs of several runs.

    Args:
        run_ids: run indices whose pickled models are loaded (default: 1..31).
        exp_dir: directory containing ``GP-Model_multiple_run{rid}.p`` files.
        out_path: path of the saved bar chart (its directory must already exist).

    Raises:
        ValueError: if the display-label list does not match METRICS[:-1] in length.
    """
    counts = _count_proxy_usage(run_ids, exp_dir)

    # Display names, aligned one-to-one with METRICS[:-1] (the keys of counts).
    labels = ['EPE-NAS', 'Fisher', 'FLOPs', 'Grad-norm', 'Grasp', 'Jacov',
              'L2-norm', 'NWOT', 'Params', 'Plain', 'Snip', 'Synflow', 'Zen',
              'SWAP', 'MeCo', 'ZiCo']
    if len(labels) != len(counts):
        raise ValueError(
            f'{len(labels)} display labels for {len(counts)} proxies; '
            'update the label list to match METRICS')

    freq = np.array(list(counts.values()))
    # Ascending by frequency. A stable sort keeps tied proxies in METRICS order.
    order = np.argsort(freq, kind='stable')
    _plot_frequencies(np.array(labels)[order], freq[order], out_path)

# Generate the figure only when run as a script, not as a side effect of import.
if __name__ == '__main__':
    get_freq_each_proxy()