import argparse
import os
import pickle

import fla  # noqa
import git
import torch
from nethook import TraceDict
from ssmworkbench.text_datamodule import TextArrowFileModule
from transformers import AutoModelForCausalLM, AutoTokenizer


@torch.no_grad()
def collect_activations(
    model,
    data_module,
    device,
    max_length,
    test,
    dataset_name: str,
    seq_len: int,
    model_name: str,
    repo_root: str,
    tokenizer: AutoTokenizer,
    num_layers: int = 24,
    bos_token_id: int = 1,
):
    """Run a single batch through ``model`` and pickle traced attention outputs.

    The first batch of the validation dataloader is used; if the validation
    dataloader is empty, the training dataloader is used instead. The outputs
    of the ``k_id``, ``v_id``, ``q_id`` and ``beta_id`` submodules of every
    ``model.layers.{i}.attn`` (for ``i`` in ``range(num_layers)``) are captured
    with ``TraceDict``. The outputs are written together with the input ids and the flat
    positions of ``bos_token_id`` to
    ``{repo_root}/data/activations/data_{dataset_name}_{seq_len}_{model_name}.pkl``.

    Args:
        model: Causal LM exposing ``model.layers.{i}.attn.{k_id,v_id,q_id,beta_id}``
            submodules (those names are traced and must exist).
        data_module: Object with ``val_dataloader()`` / ``train_dataloader()``
            whose batches are dicts containing a ``"src_seq"`` tensor.
        device: Device the input ids are moved to; its type also selects the
            autocast device.
        max_length: Unused; kept for interface compatibility.
        test: Unused; kept for interface compatibility.
        dataset_name: Used only in the output file name.
        seq_len: Used only in the output file name.
        model_name: Used only in the output file name.
        repo_root: Directory under which ``data/activations`` is created.
        tokenizer: Unused; kept for interface compatibility.
        num_layers: Number of layers to trace (default matches the previous
            hard-coded value of 24).
        bos_token_id: Token id whose flattened positions are stored under
            ``"s_start_idx"`` (default 1, the ``<s>`` token id).
    """
    model.eval()

    # Build the validation loader once; fall back to train when val is empty.
    data_loader = data_module.val_dataloader()
    if len(data_loader) == 0:
        data_loader = data_module.train_dataloader()

    batch = next(iter(data_loader))
    input_ids: torch.Tensor = batch["src_seq"].to(device)

    # Order is preserved: all k_id layers, then v_id, q_id and beta_id.
    layers = [
        f"model.layers.{m}.attn.{proj}"
        for proj in ("k_id", "v_id", "q_id", "beta_id")
        for m in range(num_layers)
    ]

    # Derive the autocast device type from `device` instead of assuming CUDA.
    autocast_device_type = torch.device(device).type
    with (
        TraceDict(
            model,
            layers=layers,
            retain_input=False,
            retain_output=True,
        ) as td,
        torch.autocast(
            device_type=autocast_device_type, dtype=torch.bfloat16, enabled=True
        ),
    ):
        # Only the traced submodule outputs are needed; the model output is discarded.
        model(input_ids)

    # Materialize everything on the host *before* touching the output file so
    # that a failure here cannot leave a truncated pickle behind.
    input_ids_np = input_ids.cpu().detach().numpy()
    # Positions are indices into the flattened (batch-concatenated) id array.
    s_start_idx = [
        i
        for i, token_id in enumerate(input_ids_np.flatten().tolist())
        if token_id == bos_token_id
    ]
    data = {k: v.output.float().cpu().detach().numpy() for k, v in td.items()}
    data["input_ids"] = input_ids_np
    data["s_start_idx"] = s_start_idx

    out_dir = os.path.join(repo_root, "data/activations")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(
        out_dir, f"data_{dataset_name}_{seq_len}_{model_name}.pkl"
    )
    # Write to a temporary file and atomically move it into place, so readers
    # never see a partially written pickle.
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "wb") as file:
        pickle.dump(data, file)
    os.replace(tmp_path, out_path)


def main():
    """Collect attention activations for every (model, dataset) pair.

    Each checkpoint in ``model_path_name_dict`` (paths relative to the git
    repository root) is loaded onto ``device``. A ``TextArrowFileModule`` is
    built for every dataset in ``datasets``, and ``collect_activations`` is
    called to pickle the traced outputs under ``<repo_root>/data/activations``.
    """
    # Find repository root so the relative checkpoint paths work from any cwd.
    repo = git.Repo(search_parent_directories=True)
    repo_root = repo.working_tree_dir

    parser = argparse.ArgumentParser(
        description="Collect attention activations from trained models"
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=16384,
        help="Maximum sample length; also used in the output file name.",
    )
    args = parser.parse_args()
    device = "cuda"
    dtype = torch.float
    torch.manual_seed(0)

    # Checkpoint directory (relative to repo root) -> short name used in output files.
    model_path_name_dict = {
        "legacy/training/runs/simple_combination/Paninetto-Array-sc1_iq0_nq1_1_nq2_0_nv1_paninetto0_irmn1_cc0_ane0-20250819_091519-3430590_1-20250819_091519-3430590_1": "sc1_an0",
        "legacy/training/runs/simple_combination/Paninetto-Array-sc1_iq0_nq1_1_nq2_0_nv1_paninetto0_irmn1_cc0_ane1-20250819_091519-3430591_0-20250819_091519-3430591_0": "sc1_an1",
        "legacy/training/runs/paninetto_on_norm_combinations_rerun/Paninetto-Array-q1_0_nv1_paninetto1_ane0-20250818_100708-3428691_0-20250818_100708-3428691_0": "paninetto1_an0",
        "legacy/training/runs/paninetto_on_norm_combinations_rerun/Paninetto-Array-q1_0_nv1_paninetto1_ane1-20250818_093629-3428670_0-20250818_093629-3428670_0": "paninetto1_an1",
        "legacy/training/runs/paninetto_off_all_norm/Paninetto-Array-q1_0_nv1_paninetto0_ane0-20250818_164335-3428708_1-20250818_164335-3428708_1": "paninetto0_an0",
        "legacy/training/runs/paninetto_off_all_norm/Paninetto-Array-q1_0_nv1_paninetto0_ane1-20250818_110208-3428760_0-20250818_110208-3428760_0": "paninetto0_an1",
    }

    datasets = [
        "trivia_qa",
        "codeparrot",
        "openthoughts-114k-math",
        "openthoughts-114k-code",
    ]

    for model_path, model_name in model_path_name_dict.items():
        checkpoint_dir = os.path.join(repo_root, model_path)
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_dir,
            device_map={"": device},
            torch_dtype=dtype,
        )
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
        model.eval()

        for dataset in datasets:
            data_module = TextArrowFileModule(
                tokenizer=checkpoint_dir,
                dataset_name=dataset,
                batch_size=1,
                num_cpu_worker=8,
                max_sample_len=args.max_len,
                data_dir=os.getenv("HF_HOME"),
                cache_dir=os.getenv("HF_DATASETS_CACHE"),
                val_ratio=0.0005,
                val_split_seed=2357,
                seed=42,
            )
            print(f"Collecting activations for {dataset} with {model_name}")
            collect_activations(
                model,
                data_module,
                device,
                args.max_len,
                test=False,
                dataset_name=dataset,
                seq_len=args.max_len,
                model_name=model_name,
                repo_root=repo_root,
                tokenizer=tokenizer,
            )

        # Drop this checkpoint before loading the next one. Otherwise the old
        # model stays referenced while `from_pretrained` allocates the new one,
        # roughly doubling peak GPU memory.
        del model, tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    # Entry point: only run the collection when executed as a script.
    main()
