#%%
import os
import argparse
from tqdm.auto import tqdm
import numpy as np
import torch
import torch.nn as nn
from collections import defaultdict
import pickle
import matplotlib.pyplot as plt
from globals import *
from sae_lens.sae import SAE

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
import datasets
from huggingface_hub import hf_hub_download, list_repo_files, notebook_login
# %%
# --- Experiment configuration -------------------------------------------------

# Hugging Face dataset identifier the documents are drawn from.
DATASET = "togethercomputer/RedPajama-Data-1T-Sample"

# Layer indices 0..25 (inclusive) whose activations are collected.
layers = np.arange(26)

# n_docs caps how many documents are processed; n_docs_start is not referenced
# in the visible part of this script.
n_docs_start = 0
n_docs = 5000

# Accumulator indexed as all_data[layer][document_index] -> activations.
all_data = defaultdict(dict)

# Everything is placed on the second CUDA device; the '' key of device_map
# assigns the whole model to GPU index 1 when it is loaded.
device = torch.device("cuda", 1)
device_map = {"": 1}
dtype = torch.float32

#%%
# Collect residual-stream activations and their SAE encodings for up to
# `n_docs` documents of exactly 1024 tokens, checkpointing both to pickle files
# under gemmascope/.
with torch.no_grad():
    # SECURITY: the access token is read from the environment instead of being
    # embedded in the source. Any token that was previously committed here
    # should be revoked. When HF_TOKEN is unset, None is passed and
    # transformers falls back to its default credential lookup.
    hf_token = os.environ.get("HF_TOKEN")

    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b",
        device_map=device_map,
        torch_dtype=dtype,
        token=hf_token,
    )
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", token=hf_token)
    dataset = datasets.load_dataset(DATASET, split="train")
    dataset = dataset.shuffle(seed=42)

    # One SAE per layer, moved to the model's device.
    # NOTE(review): load_sae_lens is not defined in this file; presumably it
    # comes from `globals` - confirm.
    res_saes = {}
    for layeri in layers:
        sae = load_sae_lens(f'layer_{layeri}/width_16k/canonical','gemma_2_2b').to(device)
        res_saes[layeri] = sae

    # Make sure both output locations exist before the first checkpoint.
    os.makedirs('gemmascope/model_acts', exist_ok=True)

    # Raw hidden states get their own buffer. Previously they shared
    # `all_data` with the SAE activations, so each raw tensor was overwritten
    # by its SAE encoding and every raw checkpoint also wiped the pending SAE
    # data.
    raw_acts = defaultdict(dict)

    j = 0  # number of documents accepted (i.e. long enough) so far
    for i, doc in tqdm(enumerate(dataset)):
        if j >= n_docs:
            break
        inputs = tokenizer(
            doc['text'],
            return_tensors="pt",
            add_special_tokens=True,
            max_length=1024,
            truncation=True,
        ).to(device)
        # Keep only documents that fill the whole 1024-token window.
        if inputs['input_ids'].shape[1] < 1024:
            continue

        # A single forward pass yields the hidden states of every layer; no
        # labels are passed because the loss is never used.
        hidden_states = model(**inputs, output_hidden_states=True).hidden_states

        for layeri in layers:
            # hidden_states[0] is the embedding output, so layer k's output
            # lives at index k + 1; [0] selects the only batch element.
            target_act = hidden_states[layeri + 1][0]
            # Offload to host memory so 50 documents x 26 layers of dense
            # activations do not pile up on the GPU between checkpoints.
            raw_acts[layeri][i] = target_act.cpu()
            all_data[layeri][i] = res_saes[layeri].encode(target_act).to_sparse().cpu()

        # Dense activations are large: checkpoint them every 50 documents.
        if (j+1) % 50 == 0:
            for layeri in layers:
                with open(f'gemmascope/model_acts/layer{layeri}_{j}.pkl', "wb") as f:
                    pickle.dump(raw_acts[layeri], f)
                raw_acts[layeri] = {}

        # Sparse SAE activations are checkpointed every 500 documents.
        if (j+1) % 500 == 0:
            for layeri in layers:
                with open(f'gemmascope/layer{layeri}_data_{j}.pkl', "wb") as f:
                    pickle.dump(all_data[layeri], f)
                all_data[layeri] = {}
        j += 1

    # Final flush of whatever has accumulated since the last checkpoint.
    for layeri in layers:
        if raw_acts[layeri]:
            with open(f'gemmascope/model_acts/layer{layeri}_{j}.pkl', 'wb') as f:
                pickle.dump(raw_acts[layeri], f)
        with open(f'gemmascope/layer{layeri}_data_{j}.pkl','wb') as f:
            pickle.dump(all_data[layeri],f)

# %%
