import numpy as np
import pickle
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import matplotlib
from os.path import join


def _load_results(model_sizes):
    """Return the per-model result dicts for *model_sizes*, in that order.

    Reads ``results/residual_results.pkl``, which is expected to map each
    model-size name to a dict of statistics.
    """
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by a trusted run of this project.
    with open(join('results', 'residual_results') + '.pkl', 'rb') as fp:
        results = pickle.load(fp)
    return [results[model] for model in model_sizes]


def _style_violin(parts, facecolor):
    """Fill the violin bodies with *facecolor* and draw summary lines in black.

    *parts* is the dict returned by ``Axes.violinplot``.
    """
    for body in parts['bodies']:
        body.set_facecolor(facecolor)
        body.set_edgecolor('black')
        body.set_alpha(0.7)
    # Styled once per violin (not once per body); keys are only present
    # when the corresponding show* option is enabled.
    for key in ('cmedians', 'cbars', 'cmins', 'cmaxes'):
        if key in parts:
            parts[key].set_color('black')


def _plot_residual_attention(data, model_sizes, out_path):
    """Save side-by-side violin plots of skip vs. patch norm ratios per model.

    *data* holds one dict per model size. Each dict has ``"residual_attention"``
    and ``"patch_attention"`` entries that expose ``.numpy()``. Presumably these
    are torch tensors; verify against the producer of the pickle.
    """
    values_skip = [d["residual_attention"].numpy() for d in data]
    values_patch = [d["patch_attention"].numpy() for d in data]

    x = np.arange(len(data))
    # Horizontal shift of each violin away from its group's tick.
    offset = 0.35 / 2

    # 8 cm wide (≈ 3.15 in) and 2 in tall.
    fig, ax = plt.subplots(figsize=(8 / 2.54, 2))

    for xi, skip, patch in zip(x, values_skip, values_patch):
        _style_violin(ax.violinplot(skip, positions=[xi - offset], showmedians=True),
                      'skyblue')
        _style_violin(ax.violinplot(patch, positions=[xi + offset], showmedians=True),
                      'lightgreen')

    ax.set_xticks(x)
    ax.set_xticklabels(model_sizes)
    ax.set_ylabel('Norm ratio')
    ax.set_xlabel('Model')
    ax.set_title('Norm of skip connection and attention output at last layer')

    # violinplot registers no legend entries, so build them from proxy patches.
    legend_elements = [Patch(facecolor='skyblue', edgecolor='black', label='Skip features'),
                       Patch(facecolor='lightgreen', edgecolor='black', label='Patch features')]
    ax.legend(handles=legend_elements)

    fig.savefig(out_path)
    plt.close(fig)


def _plot_residual_cka(data, model_sizes, out_path):
    """Save grouped bars of mean CKA (skip vs. patch features) per model.

    Each dict in *data* has ``"cka_residual"`` and ``"cka_patch"`` entries that
    support ``.mean().numpy()``. Presumably these are torch tensors; verify
    against the producer of the pickle.
    """
    values_skip = [d["cka_residual"].mean().numpy() for d in data]
    values_patch = [d["cka_patch"].mean().numpy() for d in data]

    x = np.arange(len(data))
    bar_width = 0.35

    # 7 cm wide (≈ 2.76 in) and 2 in tall.
    fig, ax = plt.subplots(figsize=(7 / 2.54, 2))

    # Shift each series by half a bar so the pair sits centered on the tick.
    ax.bar(x - bar_width / 2, values_skip, bar_width, label='Skip features')
    ax.bar(x + bar_width / 2, values_patch, bar_width, label='Patch features')

    ax.set_ylabel('CKA')
    ax.set_title('CKA with global representations')
    ax.set_xlabel('Model')
    ax.set_xticks(x)
    ax.set_xticklabels(model_sizes)
    ax.legend()

    fig.savefig(out_path)
    plt.close(fig)


def main():
    """Plot skip-connection vs. patch statistics across model sizes.

    Loads ``results/residual_results.pkl`` and writes:

    * ``results/residual_attention.pdf``: violin plots of the norm ratios.
    * ``results/residual_cka.pdf``: grouped bars of mean CKA values.
    """
    model_sizes = ['small', 'base', 'large', 'giant']
    data = _load_results(model_sizes)

    # 8 pt text everywhere. Set before any figure is created so that both
    # figures pick it up.
    matplotlib.rcParams['font.size'] = 8
    matplotlib.rcParams['axes.titlesize'] = 8

    _plot_residual_attention(data, model_sizes,
                             join('results', 'residual_attention') + '.pdf')
    _plot_residual_cka(data, model_sizes,
                       join('results', 'residual_cka') + '.pdf')

if __name__ == "__main__":
    # Script entry point: build both figures from the pickled results.
    main()