#%%
import os
import argparse
from tqdm.auto import tqdm
import numpy as np
import torch
import torch.nn as nn
from collections import defaultdict
import pickle
import matplotlib.pyplot as plt
from globals import *
from sae_lens.sae import SAE

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
import datasets
from huggingface_hub import hf_hub_download, list_repo_files, notebook_login
# %%
# Decoder layers of gemma-2-2b whose activations are projected (13..25 inclusive).
layers = np.arange(13, 26)
# NOTE(review): n_docs_start is not read anywhere in this file; kept so other
# cells/scripts that reference it keep working.
n_docs_start = 0
# Number of full-length (1024-token) documents to process.
n_docs = 5000

DATASET = "togethercomputer/RedPajama-Data-1T-Sample"
# Single source of truth for the GPU so `device` and `device_map` cannot drift apart.
GPU_INDEX = 1
device = torch.device(f"cuda:{GPU_INDEX}")
dtype = torch.float32
# Place the whole model on the same GPU as `device`.
device_map = {'': GPU_INDEX}

#%%
# Per-layer lookup tables, all keyed by layer index:
#   elephants_dict -> selection returned by print_elephants_with_pairs (threshold 0.1);
#                     later used to index rows of the decoder matrix
#                     (exact type depends on that helper — TODO confirm)
#   saes_dict      -> the SAE decoder matrix W_dec as a tensor on `device`
#   b_dec_dict     -> the SAE decoder bias b_dec on `device`
elephants_dict, saes_dict, b_dec_dict = {}, {}, {}

for layer in tqdm(layers):
    # get_mydata returns a 1-tuple when only freqs are requested.
    (freqs,) = get_mydata(layer, freqs=True)
    sae = load_sae_lens(f'layer_{layer}/width_16k/canonical', 'gemma_2_2b')

    # NumPy copy of W_dec: used both for the device tensor and for the helper below.
    saes = sae.W_dec.detach().cpu().numpy()
    b_dec = sae.b_dec.detach()

    saes_dict[layer] = torch.tensor(saes, device=device)
    b_dec_dict[layer] = b_dec.to(device)
    elephants_dict[layer] = print_elephants_with_pairs(freqs, saes, 0.1)

# Tokens per document; only documents that fill the whole window are kept.
ctx_len = 1024
# One CPU buffer per layer: rows = selected decoder directions,
# columns = ctx_len token positions for each of the n_docs documents.
# NOTE(review): this is n_elephants * ctx_len * n_docs float32 values per layer
# held in host RAM at once — check memory before raising n_docs.
all_data = {
    layer: torch.zeros((len(elephants_dict[layer]), ctx_len * n_docs))
    for layer in layers
}

# Never hard-code access tokens; read it from the environment (or rely on a
# cached `huggingface-cli login` when unset). Any previously committed token
# should be revoked.
hf_token = os.environ.get("HF_TOKEN")

with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b",
        device_map=device_map,
        torch_dtype=dtype,
        token=hf_token,
    )
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", token=hf_token)
    dataset = datasets.load_dataset(DATASET, split="train")
    dataset = dataset.shuffle(seed=42)

    # The selected decoder rows do not depend on the document, so gather them
    # once per layer instead of on every forward pass.
    elephant_saes_dict = {
        layer: saes_dict[layer][elephants_dict[layer]] for layer in layers
    }

    j = 0  # number of documents written into all_data so far
    for doc in tqdm(dataset):
        if j >= n_docs:
            break
        inputs = tokenizer(
            doc['text'],
            return_tensors="pt",
            add_special_tokens=True,
            max_length=ctx_len,
            truncation=True,
        ).to(device)
        # Skip short documents so every column block is a full window.
        if inputs['input_ids'].shape[1] < ctx_len:
            continue
        # No `labels=`: the loss was never used, so don't pay for computing it.
        output = model(**inputs, output_hidden_states=True)
        for layer in layers:
            # Assumes the HF convention that hidden_states[0] is the embedding
            # output, so index layer+1 is the output of decoder layer `layer`.
            # NOTE(review): in HF the final hidden_states entry is typically taken
            # after the model's final norm; if so, the last layer here (25) is not
            # the raw residual stream — confirm, and use a forward hook if needed.
            target_act = output.hidden_states[layer + 1][0]
            projected = elephant_saes_dict[layer] @ (target_act - b_dec_dict[layer]).T
            all_data[layer][:, j * ctx_len:(j + 1) * ctx_len] = projected

        j += 1

if j < n_docs:
    print(f"Warning: only {j} of {n_docs} documents had {ctx_len} tokens; "
          f"saving {j * ctx_len} columns per layer.")

os.makedirs('gemmascope/temp', exist_ok=True)
for layer in layers:
    # Drop never-filled zero columns. `.contiguous()` makes a compact copy of a
    # partial slice (torch.save would otherwise serialize the whole storage).
    torch.save(all_data[layer][:, :j * ctx_len].contiguous(),
               f'gemmascope/temp/layer{layer}_proj.pt')
# %%
