#%%
import pickle
import torch
from tqdm.auto import tqdm
import numpy as np
from globals import *

# SAE width: length of the per-latent count vector (one entry per activation column).
d_sae = 16_384
# Layer indices 0..25; each one selects a gemmascope/layer{N}_* file below.
layers = np.arange(26)

def get_freqs(data):
    """Return, per column, the fraction of rows (tokens) where it is non-zero.

    Args:
        data: mapping whose values are activation tensors supporting
            ``.to_dense()``. Each tensor's rows are counted as tokens, and
            non-zeros are counted per column (``.sum(dim=0)``).
            NOTE(review): assumes every value is a 2-D tensor and that all
            values have the same number of columns. The elephant loop later in
            this file reads the same pickle files as a mapping of mappings, so
            confirm the on-disk format.

    Returns:
        1-D float32 tensor with one entry per column. Each entry is the number
        of rows where that column is non-zero, divided by the total number of
        rows across all values.

    Raises:
        ValueError: if ``data`` contains no rows, since the frequency would be 0/0.
    """
    total_counts = None  # int64 per-column counts, exact for any data size
    total_tokens = 0
    for sparse_tensor in data.values():
        dense_tensor = sparse_tensor.to_dense()
        activation_counts = (dense_tensor != 0).sum(dim=0)
        if total_counts is None:
            total_counts = activation_counts
        else:
            total_counts = total_counts + activation_counts
        # Use the rows actually present rather than assuming 1024 tokens per
        # context, so shorter contexts do not deflate the frequencies.
        total_tokens += dense_tensor.shape[0]
    if total_counts is None or total_tokens == 0:
        raise ValueError("get_freqs: data contains no tokens")
    # Divide in float64, then cast back to float32 (the original return dtype)
    # so that large counts keep full precision.
    return (total_counts.to(torch.float64) / total_tokens).to(torch.float32)

#%%
# For each layer: load its pickled activation data, compute per-column
# non-zero frequencies with get_freqs, and save them as a .pt file.
for layer in tqdm(layers):
    # Drop the previous layer's data before unpickling the next one. Otherwise
    # both layers' data are alive during pickle.load (assuming nothing else
    # references them), doubling peak memory. After the loop, `data` still
    # holds the last layer's data, as before.
    data = None
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(f'gemmascope/layer{layer}_data_0-5000.pkl','rb') as f:
        data = pickle.load(f)
    freqs = get_freqs(data)
    torch.save(freqs,f'gemmascope/layer{layer}_freqs_0-5000.pt')

# For each layer: pick the "elephant" latents and copy their dense activations
# from every context into one flat tensor per latent, then pickle the
# resulting {latent: activations} dict.
for layer in tqdm(layers):
    # Release the previous layer's large objects before this layer's are
    # loaded and allocated, so peak memory holds roughly one layer's worth
    # instead of two (assuming nothing else references them). After the loop
    # both names still refer to the last layer's objects, as before.
    elephant_acts_dict = None
    data = None

    # get_mydata / print_elephants_with_pairs come from `globals`.
    # Presumably `elephants` is an iterable of column (latent) indices selected
    # with threshold 0.1 -- TODO confirm against their definitions.
    freqs,saes = get_mydata(layer,freqs=True,saes=True)
    elephants = print_elephants_with_pairs(freqs,saes,thres=0.1)
    # One flat trace per elephant latent, laid out as 5000 contexts x 1024
    # tokens. NOTE(review): fewer contexts leave a zero-filled tail, and more
    # than 5000 make the slice assignment below fail -- confirm the file
    # always has 5000.
    elephant_acts_dict = {e:torch.zeros(5000*1024) for e in elephants}
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(f'gemmascope/layer{layer}_data_0-5000.pkl','rb') as f:
        data = pickle.load(f)

    for j, (contexti, sparse_tensor_dict) in enumerate(data.items()):
        # NOTE(review): every tensor in sparse_tensor_dict writes the same
        # token slice [j*1024, (j+1)*1024), so only the last one survives.
        # This is correct only if each dict holds a single tensor -- confirm.
        for sparse_tensor in sparse_tensor_dict.values():
            dense_tensor = sparse_tensor.to_dense()
            for e in elephants:
                elephant_acts_dict[e][j*1024:(j+1)*1024] = dense_tensor[:,e]
    with open(f'gemmascope/layer{layer}_elephant_acts_dict_0-5000.pkl','wb') as f:
        pickle.dump(elephant_acts_dict,f)
