#!/usr/bin/env python
import os
import argparse
import numpy as np
import pandas as pd
import torch
import h5py
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer
from datasets import load_dataset

def autodetect_pad_id(tokenizer):
    """Pick a padding token ID for ``tokenizer`` using a chain of fallbacks.

    The candidates are tried in this order, and the first one available wins:

    1. ``tokenizer.pad_token_id``
    2. ``tokenizer.eos_token_id``
    3. the first ID produced by tokenizing the literal ``'<|endoftext|>'``
    4. ``0`` as a last resort

    Args:
        tokenizer: Object exposing ``pad_token_id`` and ``eos_token_id``
            attributes and callable on a string, returning an object with an
            ``input_ids`` sequence.

    Returns:
        The selected padding token ID. The choice is also reported on stdout.
    """
    # First try the tokenizer's pad_token_id
    if tokenizer.pad_token_id is not None:
        print(f"Using tokenizer's pad_token_id: {tokenizer.pad_token_id}")
        return tokenizer.pad_token_id

    # If no pad token, try eos_token_id (common fallback)
    if tokenizer.eos_token_id is not None:
        print(f"No pad_token_id found, using eos_token_id: {tokenizer.eos_token_id}")
        return tokenizer.eos_token_id

    # Try <|endoftext|> token. This is best-effort, so ordinary failures
    # (tokenizer error, empty result) fall through to the default below.
    # Only ``Exception`` is caught so that KeyboardInterrupt / SystemExit
    # are never swallowed.
    try:
        endoftext_id = tokenizer('<|endoftext|>').input_ids[0]
    except Exception:
        pass
    else:
        print(f"Using '<|endoftext|>' token ID: {endoftext_id}")
        return endoftext_id

    # Default to 0 (common for many models)
    print("Warning: Could not detect pad token, defaulting to 0")
    return 0

def _parse_args(argv=None):
    """Parse command-line options and check there is one output file per layer."""
    parser = argparse.ArgumentParser(description="Process dataset with transformer model")
    parser.add_argument("--max_length", type=int, default=2048, help="Maximum sequence length")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--dataset", type=str, default="timaeus/pile-github", help="Dataset to use")
    parser.add_argument("--truncate_to_max_length", action="store_true", help="Truncate sequences to max length.")
    parser.add_argument("--layers", type=int, nargs='+', default=[26], help="List of layers to hook into")
    parser.add_argument("--output_files", type=str, nargs='+', default=["Pile-github_Qwen2.5-1.5B_L26.h5"], help="Output HDF5 file paths")
    parser.add_argument("--model_path", type=str, default="Qwen/Qwen2.5-1.5B", help="Path to the model")
    parser.add_argument("--pad_id", type=int, default=None, help="Padding token ID (default: autodetect from tokenizer)")
    args = parser.parse_args(argv)

    # Validate explicitly rather than with ``assert`` so the check survives
    # ``python -O`` and the user gets a normal usage error (exit status 2).
    if len(args.layers) != len(args.output_files):
        parser.error(
            f"Must provide one output file per layer. Got {len(args.layers)} layers and {len(args.output_files)} files."
        )
    return args


def _tokenize_batch(tokenizer, texts, max_length, truncate, device):
    """Tokenize ``texts`` into a padded ``input_ids`` tensor moved to ``device``."""
    if truncate:
        encoded = tokenizer(
            texts,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )
    else:
        encoded = tokenizer(
            texts,
            return_tensors="pt",
            padding='longest',
        )
    return encoded.input_ids.to(device)


def main(argv=None):
    """Extract per-token activations for a dataset and store them in HDF5 files.

    For every requested layer the ``blocks.{layer}.hook_resid_post``
    activations of all non-padding tokens in the dataset's ``train`` split are
    appended to one HDF5 file. Each file contains:

    * ``non_padding_cache``: float32 array of shape ``(total_tokens, hidden_dim)``
    * ``start`` / ``end`` / ``length``: int64 arrays with one entry per dataset
      example, giving that example's row range ``[start, end)`` and token
      count inside ``non_padding_cache``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: If the arguments are invalid, including a mismatch between
            the number of ``--layers`` and ``--output_files``.
    """
    args = _parse_args(argv)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    ds = load_dataset(args.dataset)
    train = ds["train"]
    ds_len = len(train)

    # Load model and tokenizer
    tl_model = HookedTransformer.from_pretrained_no_processing(
        args.model_path, device=device
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    tl_model.eval()

    # Autodetect or use specified pad_id
    if args.pad_id is None:
        pad_id = autodetect_pad_id(tokenizer)
    else:
        pad_id = args.pad_id
        print(f"Using user-specified pad_id: {pad_id}")

    # The tokenizer does the padding, while the mask below is computed by
    # comparing against ``pad_id``; the two must agree.
    if tokenizer.pad_token_id is None:
        # Without a pad token the tokenizer cannot pad at all, so install the
        # resolved id.
        tokenizer.pad_token_id = pad_id
    elif tokenizer.pad_token_id != pad_id:
        print(f"Warning: pad_id {pad_id} differs from tokenizer pad_token_id "
              f"{tokenizer.pad_token_id}; padding positions will not be masked out")

    # Build hook names: residual stream after each requested block.
    layers = args.layers
    hook_names = [f"blocks.{L}.hook_resid_post" for L in layers]
    max_layer = max(layers)
    # Nothing past the deepest hooked layer is needed.
    stop_layer = max_layer + 1

    # Determine hidden_dim by probing the model with a single example; only
    # the trailing dimension of the cached activation is needed.
    probe_toks = _tokenize_batch(
        tokenizer, train[0:1]["text"], args.max_length, True, device
    )
    with torch.no_grad():
        _, probe_cache = tl_model.run_with_cache(
            probe_toks,
            stop_at_layer=stop_layer,
            names_filter=hook_names
        )
    hidden_dim = probe_cache[hook_names[-1]].shape[-1]
    del probe_cache, probe_toks

    open_files = []
    dsets = {}
    parts = {}
    try:
        # Open HDF5 files & create datasets
        for L, path in zip(layers, args.output_files):
            # A bare filename has an empty dirname, and os.makedirs("") raises.
            out_dir = os.path.dirname(path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            f = h5py.File(path, "w")
            open_files.append(f)
            dsets[L] = f.create_dataset(
                "non_padding_cache",
                shape=(0, hidden_dim),
                maxshape=(None, hidden_dim),
                dtype=np.float32,
                chunks=(args.batch_size, hidden_dim)
            )
            # Partition info per example
            parts[L] = {
                "start":  f.create_dataset("start",  shape=(ds_len,), dtype=np.int64),
                "end":    f.create_dataset("end",    shape=(ds_len,), dtype=np.int64),
                "length": f.create_dataset("length", shape=(ds_len,), dtype=np.int64)
            }

        total_trunc = 0
        # Number of activation rows written so far, i.e. the start row of the
        # next example. It is identical for every layer.
        row_offset = 0

        # Process in batches
        for i in tqdm(range(0, ds_len, args.batch_size), desc="Processing batches"):
            batch_texts = train[i:i + args.batch_size]["text"]
            toks = _tokenize_batch(
                tokenizer,
                batch_texts,
                args.max_length,
                args.truncate_to_max_length,
                device,
            )

            with torch.no_grad():
                _, cache = tl_model.run_with_cache(
                    toks,
                    stop_at_layer=stop_layer,
                    names_filter=hook_names
                )

            # NOTE: any real token whose id equals pad_id is treated as
            # padding too and is dropped from the output.
            nonpad_mask = toks != pad_id
            lengths = nonpad_mask.sum(dim=1).cpu().numpy()
            ends = row_offset + np.cumsum(lengths)
            starts = ends - lengths
            n_examples = len(lengths)

            # Write partition info into each HDF5
            for L in layers:
                parts[L]["start"][i:i + n_examples] = starts
                parts[L]["end"][i:i + n_examples] = ends
                parts[L]["length"][i:i + n_examples] = lengths
            row_offset = int(ends[-1])

            # Heuristic: a sequence whose last position holds a real token is
            # counted as truncated. With 'longest' padding the longest
            # sequence of each batch always matches this.
            total_trunc += (nonpad_mask[:, -1]).sum().item()

            # Write embeddings per layer
            for L, hook in zip(layers, hook_names):
                raw = cache[hook]               # (B, T, hidden_dim)
                flat = raw[nonpad_mask]         # (sum(lengths), hidden_dim)
                # The datasets are float32; casting first also keeps
                # reduced-precision activations convertible to NumPy.
                arr = flat.float().cpu().numpy()

                ds_ = dsets[L]
                old = ds_.shape[0]
                new = arr.shape[0]
                ds_.resize((old + new, hidden_dim))
                ds_[old:old + new] = arr

            if (i // args.batch_size) % 10 == 0:
                total_rows = sum(d.shape[0] for d in dsets.values())
                # The final batch may be short, so clamp to the dataset size.
                seen = min(i + args.batch_size, ds_len)
                print(f"Batch {i//args.batch_size}: total rows {total_rows},"
                      f" truncated frac {total_trunc/seen:.2%}")
    finally:
        # Close all HDF5 files, including when a batch fails part-way through.
        for f in open_files:
            f.close()

    for L, path in zip(layers, args.output_files):
        print(f"Finished writing layer {L} to {path}")

# Run the extraction only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
