#%%
import pickle
import numpy as np
from globals import *
layers = np.arange(0,26)
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
#%%
# Per-layer data collection. For each index in `layers`, keep:
#   'freqs'          - the single value unpacked from get_mydata(layer, freqs=True)
#   'cosine_sim_all' - cosine similarity between every row of the SAE decoder
#                      matrix W_dec and the decoder bias b_dec, flattened to 1-D.
# The SAE comes from load_sae_lens('layer_{layer}/width_16k/canonical', 'gemma_2_2b').
# NOTE(review): assumes W_dec is (n_latents, d_model), so each row is one
# latent's decoder direction and b_dec has length d_model — TODO confirm.
results = {}
for layer in tqdm(layers):
    # get_mydata hands back a one-element sequence here; unpack its only item.
    (freqs,) = get_mydata(layer, freqs=True)
    entry = results[layer] = {'freqs': freqs}

    sae = load_sae_lens(f'layer_{layer}/width_16k/canonical', 'gemma_2_2b')
    decoder_rows = sae.W_dec.detach().numpy()
    # cosine_similarity expects 2-D inputs, so present the bias as a single row.
    bias_row = sae.b_dec.detach().numpy().reshape(1, -1)

    # Result is (n_rows, 1); flatten to one similarity per decoder row.
    entry['cosine_sim_all'] = cosine_similarity(decoder_rows, bias_row).flatten()

# %%
# Scatter every SAE latent from every layer:
#   x = its entry in results[layer]['freqs'],
#   y = |cosine similarity| between its decoder row and the decoder bias,
#   color = the layer index it came from.
fig = go.Figure()

x_data = np.concatenate([results[layer]['freqs'] for layer in layers])
y_data = np.abs(np.concatenate([results[layer]['cosine_sim_all'] for layer in layers]))

# One layer index per plotted point. The per-layer counts are taken from the
# data rather than a hard-coded 16384, so the colors stay aligned with the
# points even if some layer's arrays are not exactly 16k long.
points_per_layer = [len(results[layer]['freqs']) for layer in layers]
layer_colors = np.repeat(layers, points_per_layer)

# x and y are built from independent sources (get_mydata vs. the SAE decoder);
# fail loudly instead of handing plotly arrays of different lengths.
if len(x_data) != len(y_data):
    raise ValueError(
        f'frequency/cosine length mismatch: {len(x_data)} freqs vs '
        f'{len(y_data)} cosine similarities'
    )

fig.add_trace(go.Scatter(
    x=x_data,
    y=y_data,
    mode='markers',
    marker=dict(
        size=2,
        color=layer_colors,
        colorscale='Viridis',
        colorbar=dict(title='Layer'),
        showscale=True,
    ),
    opacity=0.6,
))

fig.update_layout(
    xaxis_title='Frequency',
    yaxis_title='Absolute Cosine Similarity with Bias',
    # NOTE(review): a log axis cannot place non-positive values, so any
    # zero-frequency latents will not appear — confirm that is intended.
    xaxis_type='log',
    width=600,
    height=400,
    showlegend=False,
)

# scale=10 renders the 600x400 layout at 10x resolution for the saved PNG.
fig.write_image('bias.png', scale=10)
# fig.show()  # uncomment for an interactive view

# %%
