import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

from scipy.interpolate import interp1d

_PARAM_LIST = [
    r'$\dot{L}_{\text{eval}}$', r'$\sigma_{\theta}$', 
    r'$\min C_{NTK}$', 
    r'$\text{CKA}(K_Y, K_{NTK})$', 'edge_alignment', r'$\lambda_0$', 
    r'$C_{\infty}$', 'norm_spatial_mean_log_D', #r'$\beta_{FORT}$', 
]

# Legend text shown in place of the raw parameter name for selected entries;
# any parameter not listed here is labelled with its own name.
_LEGEND_LABELS = {
    'norm_spatial_mean_log_D': 'MAG-Ma',
    'edge_alignment': r'$\text{AUC}(v_0, \nabla I)$',
    r'$\dot{L}_{\text{eval}}$': r'$\frac{d}{dt}L_{\text{eval}}$',
}


def plot_stats(pk_info, p_df, save_name=None, param_list=_PARAM_LIST):
    """Plot each parameter against epoch in a stacked figure with window markers.

    One subplot row is created per entry of ``param_list``; all rows share a
    log-scaled x axis. In every row the parameter curve is drawn in blue, a
    red dashed vertical line marks the ``pk`` position and green dashed
    vertical lines mark the ``start`` and ``end`` positions.

    Args:
        pk_info: Mapping-like object indexed with the keys ``'<p> start'``,
            ``'<p> end'`` and ``'<p> pk'`` for every ``p`` in ``param_list``.
            The values are used directly as x positions of the marker lines.
        p_df: Table-like object exposing an ``epoch`` attribute and one
            column per entry of ``param_list`` (presumably a pandas
            DataFrame -- confirm against the caller).
        save_name: Optional output path. When not ``None`` the figure is
            written there as a PDF before being shown.
        param_list: Names of the parameters to plot, one subplot row each.

    Raises:
        ValueError: If ``param_list`` is empty.
    """
    n_rows = len(param_list)
    if n_rows == 0:
        raise ValueError("param_list must contain at least one parameter")

    # squeeze=False always yields a 2-D array of axes, so a single-parameter
    # call works the same way as a multi-parameter one.
    fig, axs = plt.subplots(
        n_rows, 1, figsize=(8, 10), sharex=True, squeeze=False
    )
    axs = axs.ravel()

    # Epochs are shifted by one, presumably so that epoch 0 stays
    # representable on the log-scaled x axis.
    x_values = 1 + p_df.epoch

    for ax, p in zip(axs, param_list):
        ax.plot(x_values, p_df[p], color='b', label=_LEGEND_LABELS.get(p, p))

        window_start = pk_info[f'{p} start']
        window_end = pk_info[f'{p} end']
        peak = pk_info[f'{p} pk']

        # Capture the limits produced by the curve alone and restore them
        # afterwards, so the marker lines span the full height without
        # altering the y range.
        # NOTE(review): the curve is drawn at 1 + epoch while the markers use
        # the pk_info values unshifted -- confirm pk_info is already expressed
        # in the same x units.
        y_low, y_high = ax.get_ylim()
        ax.vlines([peak], y_low, y_high, color='r', linestyle='--')
        ax.vlines([window_start, window_end], y_low, y_high,
                  color='g', linestyle='--')
        ax.set_ylim((y_low, y_high))

    for ax in axs:
        ax.set_xscale('log')
        ax.legend(loc='upper left', fontsize=15)
        ax.yaxis.set_tick_params(labelsize=15)
        # Keep the y axes uncluttered: only a couple of major ticks per row.
        ax.yaxis.set_major_locator(MaxNLocator(nbins=2))

    # With a shared x axis only the bottom row carries the x label and ticks.
    axs[-1].xaxis.set_tick_params(labelsize=15)
    axs[-1].set_xlabel('Epoch', fontsize=20)

    if save_name is not None:
        # Save through the figure handle rather than pyplot's implicit
        # "current figure" state.
        fig.savefig(save_name, format="pdf", bbox_inches="tight")

    plt.show()