import os
from modules import constants
os.environ["HF_HOME"] = os.path.join(constants.DATA_DIR, "raw")
import torch
import pickle

from sae_lens import SAE
from datasets import load_dataset
from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate

from tqdm import tqdm
import argparse
import sys

# Command-line interface: selects which layer's baseline SAE is evaluated and
# supplies the tags that end up in the log banner and the output filename.
parser = argparse.ArgumentParser(description="Evaluate SAE L0 norm")
parser.add_argument("--layer_num", type=int, default=6, help="Layer number to evaluate")
parser.add_argument("--model_name", type=str, default="pythia", help="Model name to evaluate")
# `args.ef` is read by the banner below and by the result filename, but the
# option was never declared, so the script failed with AttributeError before
# doing any work. It is only ever formatted into strings, hence type=str.
# NOTE(review): "ef" is presumably an expansion-factor tag; the default value
# is a placeholder chosen because this script loads the *baseline* SAEs —
# confirm the intended default.
parser.add_argument(
    "--ef",
    type=str,
    default="baseline",
    help="Tag (e.g. expansion factor) written to the log and the output filename",
)
args = parser.parse_args()

print(f"Evaluating SAE L0 norm for layer {args.layer_num} with ef {args.ef}")

# Locate the baseline SAE checkpoint for the requested layer: every entry of
# <MODEL_DIR>/baseline_saes is a candidate, and the first one whose path
# contains "blocks.<layer_num>.hook" is used.
baseline_sae_folder = os.path.join(constants.MODEL_DIR, "baseline_saes")
# os.listdir returns entries in arbitrary order; sorting makes the choice of
# "first match" reproducible across machines and runs.
baseline_saes = sorted(os.listdir(baseline_sae_folder))
baseline_sae_paths = [os.path.join(baseline_sae_folder, baseline_sae) for baseline_sae in baseline_saes]

hook_tag = f"blocks.{args.layer_num}.hook"
matching_sae_paths = [path for path in baseline_sae_paths if hook_tag in path]
if not matching_sae_paths:
    # Fail with an actionable message instead of a bare IndexError.
    raise FileNotFoundError(
        f"No baseline SAE whose path contains '{hook_tag}' was found in {baseline_sae_folder}"
    )
baseline_sae_path = matching_sae_paths[0]


sae = SAE.load_from_pretrained(baseline_sae_path, device="cuda")

# Dataset variants to evaluate: "normal" is the unshuffled corpus, the integers
# select the "<n>gram-shuffled" datasets.
ngram_list = ["normal", 1, 2, 6, 10, 30]

# The transformer name is parsed out of the checkpoint path: third path
# component from the end, fourth "_"-separated field.
# NOTE(review): this relies on a specific directory naming scheme and on "/"
# as the path separator, and it ignores --model_name (which is only used in
# the output filename) — confirm this is the intended source of the name.
model_name = baseline_sae_path.split("/")[-3].split("_")[3]
print(f"Using model name: {model_name}")

model = HookedTransformer.from_pretrained(model_name, device="cuda")

# Main evaluation: five sweeps, each over a disjoint 1000-document slice of
# every dataset variant. For each variant the per-token L0 (number of SAE
# features with activation > 0) is collected and its mean/std recorded.
# One pickle per sweep maps ngram -> (l0_mean, l0_std).

# Loop-invariant setup.
sae.eval()
batch_size = 4
output_dir = os.path.join(constants.DATA_DIR, "interim", "sae_l0_comp", "pythia_robustness")
# open(..., "wb") does not create missing parent directories; make sure the
# results of a long GPU run are not lost to a FileNotFoundError at the end.
os.makedirs(output_dir, exist_ok=True)

for sweep_num in range(5):
    ngram_l0_dict = {}
    for ngram in ngram_list:
        if ngram == "normal":
            dataset = load_dataset("EleutherAI/fineweb-edu-dedup-10b", split="train")
        else:
            dataset = load_dataset(f"<DATASET>-{ngram}gram-shuffled", split="train")

        # Fixed seed => same permutation every sweep, so consecutive windows of
        # 1000 rows give each sweep its own non-overlapping sample.
        dataset = dataset.shuffle(seed=42).select(range(1000 * sweep_num, 1000 * (sweep_num + 1)))

        dataset_tok = tokenize_and_concatenate(
            dataset,
            model.tokenizer,
            streaming=False,
            max_length=sae.cfg.context_size,
            add_bos_token=sae.cfg.prepend_bos,
        )

        # Floor division: a trailing partial batch is dropped, so every batch
        # has exactly `batch_size` rows.
        num_batches = len(dataset_tok) // batch_size

        l0s = []
        with torch.no_grad():
            for i in tqdm(range(num_batches)):
                batch = dataset_tok[i * batch_size : (i + 1) * batch_size]
                batch_tokens = batch["tokens"]

                # Only the SAE's hook point is read below, so cache just that
                # activation instead of every hook in the model.
                _, cache = model.run_with_cache(
                    batch_tokens,
                    prepend_bos=True,
                    names_filter=sae.cfg.hook_name,
                )

                feature_acts = sae.encode(cache[sae.cfg.hook_name])
                del cache

                # Count active features per token, skipping index 0 of dim 1
                # (presumably the BOS position — TODO confirm).
                l0 = (feature_acts[:, 1:] > 0).float().sum(-1)
                l0s.append(l0)

        # Pool the per-token counts of all batches once and reuse for both
        # statistics.
        all_l0 = torch.hstack(l0s)
        l0_mean = all_l0.mean().item()
        l0_std = all_l0.std().item()

        ngram_l0_dict[ngram] = (l0_mean, l0_std)
        print(f"{ngram}gram dataset for layer {args.layer_num}: {l0_mean:.2f} ± {l0_std:.2f}")

    # Save the results using pickle
    with open(os.path.join(output_dir, f"ngram_l0_layer{args.layer_num}_{args.ef}_{args.model_name}_{sweep_num}.pkl"), "wb") as f:
        pickle.dump(ngram_l0_dict, f)

