import os
from modules import constants
# Point the Hugging Face cache root (the HF_HOME environment variable) at
# <DATA_DIR>/raw.
# NOTE(review): this assignment deliberately sits above the remaining imports,
# presumably so that the Hugging Face-backed libraries imported below see the
# value when they are first loaded -- do not move it below them without
# confirming.
os.environ["HF_HOME"] = os.path.join(constants.DATA_DIR, "raw")
import torch
import pandas as pd

from modules.helpers import shuffle_ngrams

from sae_lens import SAE
from datasets import load_dataset
from transformer_lens import HookedTransformer
import argparse
import zarr

# Command-line interface.
#
# ``--ngram`` is either the literal string "normal" (source passages are used
# unmodified) or an integer given as text, which is later converted with
# ``int`` and passed as the ``n`` argument of ``shuffle_ngrams``.
argparser = argparse.ArgumentParser()
argparser.add_argument(
    "--ngram",
    type=str,
    default="normal",
    help="shuffle to use",
)
args = argparser.parse_args()
ngram = args.ngram

# Fail fast on a malformed value.  Without this check a typo such as
# ``--ngram tree`` is only detected by ``int(ngram)`` in the processing loop,
# i.e. after the model, the dataset and the first SAE have all been loaded.
# ``ngram`` itself is left as the original string so that code further down
# (output path construction, the "normal" comparison) behaves as before.
if ngram != "normal":
    try:
        int(ngram)
    except ValueError:
        argparser.error(
            f"--ngram must be 'normal' or an integer, got {ngram!r}"
        )

# Everything produced by this run is written beneath
# <DATA_DIR>/interim/sae_activations/<ngram>/gemma-2b.
output_dir = os.path.join(
    constants.DATA_DIR,
    "interim",
    "sae_activations",
    str(ngram),
    "gemma-2b",
)
os.makedirs(output_dir, exist_ok=True)

# Language model whose activations are recorded.
model = HookedTransformer.from_pretrained("gemma-2b-it", device="cuda")

# Prompt template.  The "<PASSAGE>" marker is the placeholder that is later
# substituted with each source passage.
prompt = (
    "You are a chat bot answering questions using data. "
    "You must stick to the answers provided solely by the text in the passage provided. "
    "You are asked the question 'Provide a concise summary of the following passage, "
    "covering the core pieces of information described.' <PASSAGE>\n"
)

# Load the leaderboard results and keep only the rows whose "model" field is
# "microsoft/Phi-2"; the "source" column of those rows is the passage list.
# NOTE(review): presumably the dataset repeats every passage once per
# evaluated model, so fixing a single model yields each passage once -- confirm.
source_dataset = load_dataset(
    "vectara/leaderboard_results", split="train"
).filter(lambda example: example["model"] == "microsoft/Phi-2")
source_text = source_dataset["source"]

# Persist the passages (with their row index) next to the activations so that
# the per-passage output files can be matched back to their text.
source_df = pd.DataFrame(source_text, columns=["source"])
source_csv_path = os.path.join(output_dir, "source_text.csv")
source_df.to_csv(source_csv_path, index=True)

# Resolve the shuffle setting once.  ``ngram`` is "normal" (no shuffling) or
# an integer given as text; it is parsed into a separate name here instead of
# rebinding the global ``ngram`` from str to int in the middle of the loop.
ngram_size = None if ngram == "normal" else int(ngram)

# Build every prompt once, before the layer loop, so that all layers are
# evaluated on exactly the same text and the passages are not re-shuffled for
# each layer.
# NOTE(review): ``shuffle_ngrams`` is defined elsewhere; if it is randomised,
# calling it again per layer (as was done previously) would give every layer a
# different shuffle of the same passage -- confirm against its implementation.
prompts = []
for passage in source_df["source"]:
    if ngram_size is not None:
        passage = shuffle_ngrams(passage, n=ngram_size)
    prompts.append(prompt.replace("<PASSAGE>", passage))
num_prompts = len(prompts)

# Inference only: no gradients are needed for recording activations.
with torch.no_grad():
    for layer_num in [0, 6, 10, 12, 17]:
        layer_dir = os.path.join(output_dir, f"layer_{layer_num}")
        os.makedirs(layer_dir, exist_ok=True)

        # Load the SAE for this layer's residual stream.
        sae, cfg_dict, sparsity = SAE.from_pretrained(
            release="gemma-2b-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
            sae_id=f"blocks.{layer_num}.hook_resid_post",  # won't always be a hook point
            device="cuda",
        )
        hook_name = sae.cfg.hook_name

        for i, new_prompt in enumerate(prompts):
            print(f"Processing {i} of {num_prompts}")
            tokens = model.to_tokens(new_prompt)

            # Restrict the cache to the single hook this SAE reads from.
            _, cache = model.run_with_cache(
                tokens,
                prepend_bos=True,
                names_filter=[hook_name],
            )

            # Drop only the leading axis (assumed to be a batch axis of size 1
            # -- TODO confirm); a bare ``squeeze()`` would also collapse any
            # other axis that happened to have length 1.
            feature_acts = sae.encode(cache[hook_name]).squeeze(0).cpu().numpy()

            # Save the activations to a Zarr file, one per (layer, passage).
            zarr_path = os.path.join(layer_dir, f"acts_{i}.zarr")
            zarr_file = zarr.open(
                zarr_path,
                mode='w',
                shape=feature_acts.shape,
                dtype=feature_acts.dtype,
            )
            zarr_file[:] = feature_acts

            # Release the per-passage tensors before the next iteration.
            del cache
            del feature_acts
