import argparse
import os

# Command-line configuration for this run.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="clean_v3")
# Interpolated into the checkpoint directory name as "step=<STEP>".
parser.add_argument("--step", type=str, default="19900")
# Number of leading dataset rows to process.
parser.add_argument("--num_examples", type=int, default=200)
parser.add_argument("--batch_size", type=int, default=1)
args = parser.parse_args()

BATCH_SIZE = args.batch_size
NUM_EXAMPLES = args.num_examples
MODEL_ID = args.model
STEP = args.step
# Expand "~" explicitly: Hugging Face `from_pretrained` checks the string with
# os.path.isdir and does not expand the home directory itself, so an
# unexpanded "~/..." path is not recognized as a local checkpoint directory.
MODEL_KEY = os.path.expanduser(
    f"~/pythia_replicate_public_models/{MODEL_ID}/step={STEP}"
)
# NOTE(review): the file name encodes MODEL_ID only, not STEP or NUM_EXAMPLES,
# and get_means_per_head skips recomputation when this file already exists —
# a run for a different step can silently reuse stale means.
OUTPUT_PATH = f"./data/means_per_head_{MODEL_ID}.pt"


import torch

from nnsight import LanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from datasets import load_dataset
from collections import defaultdict
from torch.utils.data import DataLoader
from tqdm import tqdm
import os




# Load the checkpoint in fp16 on the first GPU with eager attention
# (presumably required so the attention ops traced in get_means_per_head via
# `.source` exist — confirm), then wrap it in nnsight. The tokenizer is taken
# from the public EleutherAI/pythia-160m repo; the config from the checkpoint.
_hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_KEY,
    attn_implementation="eager",
    torch_dtype=torch.float16,
    device_map="cuda:0",
)
_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
_hf_config = AutoConfig.from_pretrained(MODEL_KEY)
pythia = LanguageModel(_hf_model, tokenizer=_tokenizer, config=_hf_config)

# Pad with the EOS token so batched prompts can be padded.
pythia.tokenizer.pad_token = pythia.tokenizer.eos_token
pythia.eval()





# Load the parquet-converted OpenWebText 10k sample (train split).
dataset = load_dataset(
    "stas/openwebtext-10k",
    split="train",
    revision="refs/convert/parquet",
    data_files="plain_text/train/*.parquet",
)

# Keep the first NUM_EXAMPLES rows, in order. `select` rejects indices past
# the end of the split, so clamp to the number of available rows.
dataset = dataset.select(range(min(NUM_EXAMPLES, len(dataset))))

# shuffle=False keeps the batch order deterministic across runs.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)



def get_means_per_head(pythia_model, dataloader):
    """Compute (or load cached) per-token mean activations for every attention head.

    For each layer ``l`` and head ``h``, the output of ``torch_matmul_0`` inside
    the layer's attention forward is summed over every non-padding token seen
    in ``dataloader`` and divided by the total number of such tokens. The
    resulting dict is saved to ``OUTPUT_PATH``.

    Args:
        pythia_model: nnsight ``LanguageModel`` wrapping a GPT-NeoX model
            (accessed via ``gpt_neox.layers``, ``config`` and ``tokenizer``).
        dataloader: iterable of batches; each batch is indexed with
            ``batch['text']`` to obtain the prompt(s).

    Returns:
        dict mapping ``"l_{layer}.h_{head}"`` to a float32 CPU tensor holding
        that head's per-token mean. If ``OUTPUT_PATH`` already exists, its
        contents are loaded and returned without recomputation.
    """
    if os.path.exists(OUTPUT_PATH):
        # NOTE(review): the cache is keyed only by OUTPUT_PATH; confirm that
        # path uniquely identifies the model/step/dataset being measured.
        return torch.load(OUTPUT_PATH)

    num_heads = pythia_model.config.num_attention_heads
    sums_per_head = defaultdict(int)
    num_tokens = 0

    for batch in tqdm(dataloader, desc="Processing batches"):
        prompt = batch['text']

        # Tokenize with the model's own tokenizer (padding=True) to count real
        # tokens and build a padding mask. Doing this outside the trace keeps
        # `num_tokens` a plain Python int — ints have no `.save()`.
        encoding = pythia_model.tokenizer(prompt, return_tensors="pt", padding=True)
        token_mask = encoding["attention_mask"].float()  # (batch, seq)
        num_tokens += int(token_mask.sum().item())

        with torch.no_grad():
            with pythia_model.trace() as tracer:
                with tracer.invoke(prompt):
                    for l_idx, layer in enumerate(pythia_model.gpt_neox.layers):
                        # NOTE(review): shape assumed (batch, heads, seq, head_dim);
                        # confirm torch_matmul_0 is the attn_weights @ value
                        # product rather than the query @ key^T scores.
                        attn_out = layer.attention.source.None_0.source.torch_matmul_0.output
                        # Accumulate in float32 on CPU: summing many fp16
                        # activations can exceed fp16's range.
                        attn_out = attn_out.detach().float().cpu()
                        # Zero padding positions, then sum over batch and
                        # sequence so every real token contributes once.
                        masked = attn_out * token_mask[:, None, :, None]
                        layer_sums = masked.sum(dim=(0, 2))  # (heads, ...)
                        for h_idx in range(num_heads):
                            sums_per_head[f"l_{l_idx}.h_{h_idx}"] += layer_sums[h_idx]

    # Token sums / token count = per-token mean (not a mean of per-sequence means).
    means_per_head = {k: v / num_tokens for k, v in sums_per_head.items()}

    out_dir = os.path.dirname(OUTPUT_PATH)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(means_per_head, OUTPUT_PATH)
    return means_per_head

# Entry point: compute the per-head means, or reuse the file at OUTPUT_PATH if present.
get_means_per_head(pythia_model=pythia, dataloader=dataloader)

