#%%
import numpy as np
import pickle
import torch
import matplotlib.pyplot as plt
import re
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from globals import *
layers = np.arange(26)  # layer indices 0..25 to iterate over (arange excludes the stop value)

#%%
# For every layer, score each "elephant" SAE latent against the leading PCA
# direction(s) of that layer's saved PCA fit.
#   n_components == 1 -> absolute cosine similarity with PC1
#   otherwise         -> vec_in_subspace_norm(sae, top n_components PCs)
n_components = 1
perc_dict = {}  # layer -> {elephant index -> score}
for layer in tqdm(layers, desc='Layers'):
    # Register the (initially empty) per-layer dict before any loading happens.
    layer_scores = perc_dict[layer] = {}
    freqs, saes = get_mydata(layer, freqs=True, saes=True)
    # Latent selection is delegated to get_elephants_thres (threshold 0.1);
    # its exact criterion lives in `globals`.
    elephants = get_elephants_thres(freqs, 0.1)
    pca_path = f'gemmascope/model_acts/layer{layer}_0-500_pca.pkl'
    pca = load_pickle(pca_path)

    if n_components != 1:
        # Multi-component case: measure each latent against the top-k subspace.
        pcs = pca.components_[:n_components]
        for e in elephants:
            sae = saes[e]
            layer_scores[e] = vec_in_subspace_norm(sae, pcs)
        continue

    # Single-component case: |cos| between the latent and PC1 (both as 1xD rows).
    pc1 = pca.components_[0].reshape(1, -1)
    for e in elephants:
        sae = saes[e].reshape(1, -1)
        similarity = cosine_similarity(sae, pc1)
        layer_scores[e] = np.abs(similarity)[0, 0]

#%%
import plotly.express as px
import plotly.graph_objects as go
# Scatter of per-elephant scores (from perc_dict) against layer index.
# All points are collected into ONE trace. Creating one trace per point (as
# before) produces thousands of trace objects, which makes building the
# figure, fig.show() and PDF export very slow for no visual benefit.
xs, ys, labels = [], [], []
for layer, scores in perc_dict.items():
    for e, score in scores.items():
        xs.append(layer)
        ys.append(score)
        labels.append(f'Elephant #{e}')

fig = go.Figure(go.Scatter(
    x=xs,
    y=ys,
    mode='markers',
    marker=dict(size=8, color='#636EFA', opacity=0.5),
    text=labels,
    # Keep the elephant id visible on hover, as the per-trace `name` did before.
    hovertemplate='%{text}<br>layer=%{x}<br>score=%{y}<extra></extra>',
))

# Axis label matches the metric computed in the scoring cell.
y_title = f'Fraction norm in top {n_components} PCs' if n_components > 1 else 'Absolute cosine similarity with PC1'
fig.update_layout(
    xaxis_title='Layer',
    yaxis_title=y_title,
    showlegend=False,
    height=300,
    width=400,
    margin=dict(l=50, r=10, t=10, b=50),
)

fig.write_image(f"*PLOTS/pc{n_components}_percent_pca.pdf",scale=20)

fig.show()
