#%%
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import json
import torch
from sae_lens import SAE
from globals import get_elephants_thres, load_pickle, vecs_in_subspace_norm

# Sweep configuration.
# L0 values that have their own results file (only iterated when use_l0 is True).
l0s = [22, 41, 73, 176, 445]
# Dictionary sizes (the length of the per-latent frequency vector) and, at the
# same positions, the short labels that appear in the result file names and as
# x-axis tick text. The two lists are paired up by index.
d_saes = [
    16_384,
    32_768,
    65_536,
    262_144,
    524_288,
    1_048_576,
]
d_saes_strs = ['16k', '32k', '65k', '262k', '524k', '1m']
use_l0 = False  # If True, vary L0. If False, vary dictionary size

#%%
# Per-sweep-value results, keyed by the swept parameter (an L0 value or a
# dictionary size, depending on use_l0).
freqs_dict = {}      # param -> float32 tensor of length d_sae filled from each latent's 'actDensity'
saes_dict = {}       # param -> rows of the SAE decoder matrix (W_dec) selected by `elephants`
elephants_dict = {}  # param -> return value of get_elephants_thres(freqs, 0.1)

# Dictionary size -> label used in the result file names (16384 -> '16k', ...).
# Built unconditionally so the name exists in both sweep modes.
d_saes_dict = dict(zip(d_saes, d_saes_strs))
params_to_try = l0s if use_l0 else d_saes

for param in tqdm(params_to_try):
    if use_l0:
        # L0 sweep: the dictionary size is pinned to the first entry of d_saes,
        # which matches the '16k' in the file names below.
        l0 = param
        d_sae = d_saes[0]
        if l0 == 73:
            # The l0 == 73 results are stored in the file without an '__l0-' suffix.
            path = 'neel/results/gemma-2-2b/12-gemmascope-res-16k.json'
        else:
            path = f'neel/results/gemma-2-2b/12-gemmascope-res-16k__l0-{l0}.json'
    else:
        d_sae = param
        path = f'neel/results/gemma-2-2b/12-gemmascope-res-{d_saes_dict[param]}.json'

    # Decode explicitly as UTF-8 instead of relying on the platform default.
    with open(path, 'r', encoding='utf-8') as f:
        original = json.load(f)

    # Latents that do not appear in the file keep a frequency of 0.
    freqs = torch.zeros(d_sae, dtype=torch.float32)
    for item in original['latents']:
        freqs[int(item['index'])] = item['actDensity']

    # NOTE(review): the [0] assumes from_pretrained returns a sequence whose
    # first element is the SAE object - confirm against the installed sae_lens.
    sae = SAE.from_pretrained(
        release=original['saelensRelease'],
        sae_id=original['saelensSaeId'],
        device="cpu",
    )[0]
    w_dec = sae.W_dec.detach().cpu()

    # Keep only the decoder rows picked out by the 0.1 threshold on freqs.
    elephants = get_elephants_thres(freqs, 0.1)
    elephants_dict[param] = elephants
    freqs_dict[param] = freqs
    saes_dict[param] = w_dec[elephants]

#%%
import plotly.graph_objects as go
from sklearn.metrics.pairwise import cosine_similarity

# Load the pickled PCA object and take its first component as a (1, n_features)
# row, the 2-D shape sklearn's pairwise cosine_similarity expects.
# NOTE(review): presumably fitted on layer-12 model activations, going by the
# file name - confirm.
pca = load_pickle('gemmascope/model_acts/layer12_0-500_pca.pkl')
pca_components = pca.components_[0].reshape(1, -1)

fig = go.Figure()

# Sweep values are placed at evenly spaced integer positions 0..n-1 rather than
# at their numeric value; the tick labels below map positions back to values.
x_positions = list(range(len(params_to_try)))
param_to_pos = dict(zip(params_to_try, x_positions))

for param in params_to_try:
    # |cos| between every stored decoder row for this sweep value and PC1.
    cos_sims = np.abs(cosine_similarity(saes_dict[param], pca_components).flatten())

    # One trace per sweep value: all of its points share the same x position.
    fig.add_trace(go.Scatter(
        x=[param_to_pos[param]] * len(cos_sims),
        y=cos_sims,
        mode='markers',
        marker=dict(
            size=7,
            opacity=0.5,
        ),
    ))

# Tick labels must describe the parameter actually being swept: the L0 values
# when use_l0 is True, otherwise the short dictionary-size labels.
tick_labels = [str(p) for p in params_to_try] if use_l0 else d_saes_strs

fig.update_layout(
    xaxis_title='L0' if use_l0 else 'Dictionary Size',
    yaxis_title='Absolute Cosine Similarity with PC1',
    margin=dict(l=50, r=10, t=10, b=50),
    showlegend=False,
    height=300,
    width=400,
    xaxis=dict(
        tickmode='array',
        ticktext=tick_labels,
        tickvals=x_positions,
    ),
)

#%%
# Write the figure to a PDF file, then display it.
output_path = "*PLOTS/pc1_cosine_similarity_diffwidths.pdf"
fig.write_image(output_path, scale=20)
fig.show()

