"""
Functions useful for plotting metrics tracked during training.
"""
import matplotlib.pyplot as plt
import math
import torch

def plot_df(df, save=False, include_layer_sparsity=False):
    """
    Plot every scalar metric column of a DataFrame in its own subplot.

    Columns are skipped when they are named 'step' or 'epoch', have object
    dtype, or have a name starting with 'Test'. Columns whose name contains
    'layer_sparsity' are also skipped unless ``include_layer_sparsity`` is True.

    Each remaining column is plotted against the DataFrame index. Each subplot
    is titled with the metric name and its last non-NaN value, or "nan" if the
    column has no non-NaN values. Metrics whose first '/'-separated component is
    'Train' get the x-label 'Step'; all others get 'Epoch'.

    Args:
        df: DataFrame of logged metrics.
        save: If falsy, the figure is not saved. Otherwise the figure is
            written to f'{save}.png' (note that ``save=True`` therefore writes
            'True.png'; pass a string to choose the file stem).
        include_layer_sparsity: Whether to also plot 'layer_sparsity' columns.

    If no column qualifies, nothing is drawn and the function returns None.
    """
    fontsize = 20
    linewidth = 3

    metrics_to_plot = [
        col for col in df.columns
        if col not in ('step', 'epoch')
        and df[col].dtype != 'object'
        and not col.startswith('Test')
    ]
    if not include_layer_sparsity:
        metrics_to_plot = [name for name in metrics_to_plot if 'layer_sparsity' not in name]

    num_metrics = len(metrics_to_plot)
    if num_metrics == 0:
        # Nothing to draw; also avoids dividing by zero when sizing the grid.
        return

    # Near-square grid: ceil(sqrt(n)) columns, then just enough rows for n plots.
    cols = math.ceil(num_metrics ** .5)
    rows = math.ceil(num_metrics / cols)

    # squeeze=False always returns a 2-D array of Axes, so ravel() yields a flat
    # 1-D sequence whatever the grid shape (including a single subplot).
    _, axs = plt.subplots(rows, cols, figsize=(6 * cols, 4 * rows), squeeze=False)
    axs = axs.ravel()

    for ax, metric in zip(axs, metrics_to_plot):
        series = df[metric]
        ax.plot(series, linewidth=linewidth)
        non_null = series.dropna()
        # An all-NaN column has no last value; show "nan" instead of crashing.
        final_value = non_null.iloc[-1] if not non_null.empty else float('nan')
        ax.set_title(f'{metric}\n{final_value:.3f}', fontsize=fontsize)
        xlabel = 'Step' if metric.split('/')[0] == 'Train' else 'Epoch'
        ax.set_xlabel(xlabel, fontsize=fontsize)
        ax.grid(True)
        ax.tick_params(axis='both', labelsize=fontsize * .6)

    # Hide the grid cells left over after the last metric.
    for ax in axs[num_metrics:]:
        ax.axis('off')
    plt.tight_layout()

    if save:
        plt.savefig(f'{save}.png')

    plt.show()



def display_sparsity_table(model, column_width=30):
    """
    Print a sparsity summary for every Linear and Conv2d sub-module of ``model``.

    Each such layer gets a row for its weight, plus a row for its bias when the
    bias is not None. A row shows the element count, the number of non-zero
    entries, and the percentage of entries that are zero.

    Args:
        model: torch.nn.Module; its sub-modules are visited via named_modules().
        column_width: Width of the left-aligned "Layer" column.
    """
    header = f'{"Layer":<{column_width}} || {"Numel":>7} | {"NNZ":>7} | {"Sparsity":>9}'
    print(header)
    # 33 = width of everything to the right of the Layer column in the header.
    print('-' * (column_width + 33))

    def _print_row(label, tensor):
        # One table row: total elements, non-zero count, percent zeros.
        total = tensor.numel()
        nonzero = tensor.count_nonzero()
        pct_zero = 100 - nonzero / total * 100
        print(f'{label:<{column_width}} || {total:7} | {nonzero:6}  | {pct_zero:.2f}%')

    tracked_types = (torch.nn.Linear, torch.nn.Conv2d)
    for layer_name, layer in model.named_modules():
        if not isinstance(layer, tracked_types):
            continue
        # Report whatever tensors are bound to .weight / .bias at call time.
        _print_row(layer_name + "_weight", layer.weight.detach())
        if layer.bias is not None:
            _print_row(layer_name + "_bias", layer.bias.detach())
