import pickle

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch
from os.path import join

def _load_results(results_dir, model_sizes):
    """Load one pickled result dict per model size.

    Args:
        results_dir: Directory containing the ``<size>_results.pkl`` files.
        model_sizes: Model size names, in the order they should be plotted.

    Returns:
        A list of unpickled objects, one per entry in ``model_sizes``.

    Raises:
        OSError: If a result file cannot be opened.
    """
    results = []
    for model_size in model_sizes:
        # NOTE(security): pickle.load can execute arbitrary code; only run this
        # script on result files produced by a trusted source.
        with open(join(results_dir, model_size + '_results.pkl'), 'rb') as fp:
            results.append(pickle.load(fp))
    return results


def _draw_violin(ax, values, position, facecolor):
    """Draw a single violin at ``position`` with a coloured body and black lines."""
    parts = ax.violinplot(values, positions=[position], showmedians=True)
    for body in parts['bodies']:
        body.set_facecolor(facecolor)
        body.set_edgecolor('black')
        body.set_alpha(0.7)
    # Median, centre bar and min/max caps are all drawn in black.
    for line_key in ('cmedians', 'cbars', 'cmins', 'cmaxes'):
        parts[line_key].set_color('black')


def _plot_attention(data, model_sizes, out_path):
    """Save paired violin plots of the two attention distributions per model.

    For each model, ``mean_attention_global`` is drawn in sky blue (labelled
    'Register tokens') and ``mean_attention_local`` in light green (labelled
    'Patch Tokens'), side by side around the model's x tick.
    """
    # assumes the stored values are CPU torch tensors (``.numpy()``) — TODO confirm
    global_attention = [d["mean_attention_global"].numpy() for d in data]
    local_attention = [d["mean_attention_local"].numpy() for d in data]

    x = np.arange(len(data))
    # Horizontal distance of each violin from its group's tick position.
    offset = 0.35 / 2

    # 8 cm wide (converted to inches), 2 inches tall.
    fig, ax = plt.subplots(figsize=(8 / 2.54, 2))

    for center, global_vals, local_vals in zip(x, global_attention, local_attention):
        _draw_violin(ax, global_vals, center - offset, 'skyblue')
        _draw_violin(ax, local_vals, center + offset, 'lightgreen')

    ax.set_xticks(x)
    ax.set_xticklabels(model_sizes)
    ax.set_ylabel('Attention')
    ax.set_xlabel('Model')
    ax.set_title('Attention of cls token to register and patch tokens')

    # violinplot does not register legend entries, so use proxy patches.
    legend_elements = [Patch(facecolor='skyblue', edgecolor='black', label='Register tokens'),
                       Patch(facecolor='lightgreen', edgecolor='black', label='Patch Tokens')]
    ax.legend(handles=legend_elements)

    # Without this the axis labels/title can be clipped on such a small canvas;
    # the other two figures already did this.
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def _plot_cka(data, model_sizes, out_path):
    """Save a grouped bar chart of the mean CKA values per model.

    Plots the mean of ``cka_total_global`` (labelled 'Register tokens') next to
    the mean of ``cka_total_local`` (labelled 'Patch tokens') for each model.
    """
    global_cka = [d["cka_total_global"].mean().numpy() for d in data]
    local_cka = [d["cka_total_local"].mean().numpy() for d in data]

    x = np.arange(len(data))
    bar_width = 0.35

    # 7 cm wide (converted to inches), 2 inches tall.
    fig, ax = plt.subplots(figsize=(7 / 2.54, 2))

    # Shift the two series by half a bar width so they sit side by side.
    ax.bar(x - bar_width / 2, global_cka, bar_width, label='Register tokens')
    ax.bar(x + bar_width / 2, local_cka, bar_width, label='Patch tokens')

    ax.set_ylabel('CKA')
    ax.set_xlabel('Model')
    ax.set_title('CKA with global representations')
    ax.set_xticks(x)
    ax.set_xticklabels(model_sizes)
    ax.legend()

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def _plot_norm_activations(result, out_path):
    """Save stacked bar plots of ``pre_norm_hidden_states`` and ``post_norm_hidden_states``.

    Args:
        result: A single result dict providing both vectors.
        out_path: Destination file for the figure.
    """
    pre_norm = result['pre_norm_hidden_states']
    post_norm = result['post_norm_hidden_states']

    # 6 cm wide (converted to inches), 2.3 inches tall; x axes are independent.
    fig, (ax_pre, ax_post) = plt.subplots(nrows=2, figsize=(6 / 2.54, 2.3), sharex=False)

    ax_pre.bar(np.arange(len(pre_norm)), pre_norm, color='skyblue')
    ax_pre.set_title('Pre-Layer Norm Activations')
    ax_pre.set_xlabel('Pre-Norm Feature rank')
    ax_pre.set_ylabel('Value')

    ax_post.bar(np.arange(len(post_norm)), post_norm, color='salmon')
    ax_post.set_title('Post-Layer Norm Activations')
    # NOTE(review): label is identical to the upper panel's; presumably both
    # panels share the pre-norm ordering — confirm this is intended.
    ax_post.set_xlabel('Pre-Norm Feature rank')
    ax_post.set_ylabel('Value')

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def main():
    """Load the per-model result pickles and write three summary figures.

    Reads ``results/register_models/<size>_results.pkl`` for each model size
    and writes ``attention.pdf``, ``patch_cka.pdf`` and
    ``pre_post_norm_activations.pdf`` into the same directory.
    """
    results_dir = join('results', 'register_models')
    model_sizes = ['small', 'base', 'large', 'giant']
    data = _load_results(results_dir, model_sizes)

    # Use 8pt for all text, including axes titles, in every figure below.
    matplotlib.rcParams['font.size'] = 8
    plt.rcParams['axes.titlesize'] = 8

    _plot_attention(data, model_sizes, join(results_dir, 'attention') + '.pdf')
    _plot_cka(data, model_sizes, join(results_dir, 'patch_cka') + '.pdf')
    # Only the last entry in model_sizes ('giant') is used for this figure.
    _plot_norm_activations(data[-1], join(results_dir, 'pre_post_norm_activations') + '.pdf')

# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()