import os
import sys
from collections import defaultdict
import argparse
from tqdm.auto import tqdm

# Ensure the lmkit library is in the Python path
# This assumes you run the script from the root of the project directory
if "." not in sys.path:
    sys.path.insert(0, ".")

from lmkit.sparse.sae import SAEKit
from lmkit.sparse.atlas_utils import StatsCollector, save_collector
from lmkit.tools import data


def main(args):
    """
    Extract one StatsCollector per SAE layer and save each to disk.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model_dir``, ``saes_base_dir``, ``sae_name``, ``ckpt_id``,
            ``dataset_dir``, ``batch_size``, ``num_processes``,
            ``num_molecules``, ``output_dir``, ``layer``, ``top_sequences``
            and ``top_tokens``.

    Side effects:
        Creates ``<output_dir>/<sae_name>/`` and writes one
        ``<sae_name>_layer<N>.pkl`` file per processed layer.

    Exits with status 1 when the SAEKit cannot be loaded (FileNotFoundError),
    when ``--layer`` is out of range, or when the dataset cannot be loaded.

    A failure inside a single batch does not abort the run: the batch is
    skipped, the failure is counted, and the layer's collector is still
    saved. The per-layer and final messages state how many batches were
    skipped so a partial collector is never reported as a clean success.
    """
    print(f"Loading SAEKit for model '{args.model_dir}' and SAE '{args.sae_name}'...")
    sae_dir = os.path.join(args.saes_base_dir, args.sae_name)

    # 1. Load the SAE Kit which contains the models, tokenizers, and configs
    try:
        sae_kit = SAEKit.load(
            model_dir=args.model_dir, checkpoint_id=args.ckpt_id, sae_dir=sae_dir
        )
    except FileNotFoundError as e:
        print(f"Error loading models: {e}")
        print("Please ensure your model_dir, saes_base_dir, and ckpt_id are correct.")
        sys.exit(1)

    print("SAEKit loaded successfully.")

    # 2. Determine which layers to process.
    # This only needs the SAE configs, so it is validated *before* the
    # dataset is loaded and tokenized: a bad --layer value fails immediately
    # instead of after the expensive data preparation step.
    num_layers = len(sae_kit.sae_configs)
    if args.layer is None:
        layers_to_process = range(num_layers)
        print(f"Found {num_layers} layers to process.")
    elif 0 <= args.layer < num_layers:
        layers_to_process = [args.layer]
        print(f"Only processing specified layer: {args.layer}")
    else:
        print(f"Error: --layer must be between 0 and {num_layers - 1}.")
        sys.exit(1)

    # 3. Load and tokenize the dataset
    print(f"Loading and tokenizing dataset from '{args.dataset_dir}'...")
    try:
        ds = data.load_and_tokenize(
            args.dataset_dir,
            sae_kit.tokenizer,
            args.batch_size,
            num_processes=args.num_processes,
            seed=2002,
            caching=True,
            limit=args.num_molecules,
        )
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print(
            "Please check the dataset path and ensure it's a valid Hugging Face dataset directory."
        )
        sys.exit(1)

    num_batches = len(ds)
    print(f"Dataset loaded with {num_batches} batches.")

    # 4. Create the output directory
    output_dir = os.path.join(args.output_dir, args.sae_name)
    os.makedirs(output_dir, exist_ok=True)
    print(f"Collectors will be saved to: {output_dir}")

    # 5. Iterate through the selected layers, collect stats, and save the collector
    failed_by_layer = {}  # layer_id -> number of batches that raised
    for layer_id in layers_to_process:
        print(f"\n----- Processing Layer {layer_id} -----")

        # Instantiate a StatsCollector for the current layer
        collector = StatsCollector(
            latent_size=sae_kit.sae_configs[layer_id].latent_size,
            top_sequences=args.top_sequences,
            top_tokens=args.top_tokens,
        )

        # Iterate through the dataset batches and update the collector.
        # A failing batch is skipped (best effort) but counted so that the
        # result can be reported as partial.
        num_failed = 0
        for batch in tqdm(ds, desc=f"Layer {layer_id}", total=num_batches):
            try:
                collector.update(batch, sae_kit, layer_id=layer_id)
            except Exception as e:
                num_failed += 1
                print(
                    f"An error occurred during batch processing for layer {layer_id}: {e}"
                )

        # Save the populated collector to a file
        output_path = os.path.join(output_dir, f"{args.sae_name}_layer{layer_id}.pkl")
        save_collector(collector, output_path)
        if num_failed:
            failed_by_layer[layer_id] = num_failed
            print(
                f"⚠️ Saved PARTIAL collector for layer {layer_id} to {output_path} "
                f"({num_failed} of {num_batches} batches failed and were skipped)"
            )
        else:
            print(f"✅ Saved collector for layer {layer_id} to {output_path}")

    if failed_by_layer:
        print("\nFinished with errors; the following layers have partial collectors:")
        for layer_id, num_failed in failed_by_layer.items():
            print(f"  layer {layer_id}: {num_failed} of {num_batches} batches skipped")
    else:
        print("\nAll selected layers processed successfully.")


def build_parser():
    """Build the command-line parser for the collector extraction script.

    Returns:
        argparse.ArgumentParser: Parser exposing the model/SAE paths, dataset
        options, processing options, collector sizes and output directory.
        Only ``--sae_name`` is required; every other option has a default.
    """
    parser = argparse.ArgumentParser(
        description="Extract and save SAE StatsCollectors for each layer of a model."
    )

    # Model and SAE paths
    parser.add_argument(
        "--model_dir",
        type=str,
        default="models/transformer_sm",
        help="Directory of the base transformer model.",
    )
    parser.add_argument(
        "--saes_base_dir",
        type=str,
        default="models/saes",
        help="Base directory containing different SAE runs.",
    )
    parser.add_argument(
        "--sae_name",
        type=str,
        required=True,
        help="Name of the specific SAE run to process (e.g., 'relu_4x_e9a211').",
    )
    parser.add_argument(
        "--ckpt_id",
        type=int,
        default=59712,
        help="Checkpoint ID of the transformer model.",
    )

    # Dataset arguments
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="~/data/z20ll_filtered_scafsplit/valid",
        help="Path to the tokenized dataset directory.",
    )
    parser.add_argument(
        "--num_molecules",
        type=int,
        default=1_000_000,
        help="Number of molecules to process from the dataset.",
    )

    # Processing arguments
    parser.add_argument(
        "--batch_size", type=int, default=512, help="Batch size for processing."
    )
    parser.add_argument(
        "--num_processes",
        type=int,
        default=4,
        help="Number of parallel processes for data loading.",
    )
    # Layer selection
    parser.add_argument(
        "--layer",
        type=int,
        default=None,
        help="Specify a single layer to process. If not set, all layers will be processed.",
    )

    # Collector arguments
    parser.add_argument(
        "--top_sequences",
        type=int,
        default=100,
        help="Number of top activating sequences to store per neuron.",
    )
    parser.add_argument(
        "--top_tokens",
        type=int,
        default=100,
        help="Number of top activating tokens to store per neuron.",
    )

    # Output directory
    parser.add_argument(
        "--output_dir",
        type=str,
        default="atlas",
        help="Directory to save the output collector files.",
    )

    return parser


def parse_args(argv=None):
    """Parse command-line arguments and normalise the path options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed arguments with a leading ``~`` expanded in
        ``model_dir``, ``saes_base_dir``, ``dataset_dir`` and ``output_dir``.
    """
    args = build_parser().parse_args(argv)

    # Expand "~" in every path option, not just the dataset directory. The
    # shell leaves "~" untouched in forms such as --output_dir=~/atlas, which
    # would otherwise create or look up a directory literally named "~".
    for path_option in ("model_dir", "saes_base_dir", "dataset_dir", "output_dir"):
        setattr(args, path_option, os.path.expanduser(getattr(args, path_option)))

    return args


if __name__ == "__main__":
    main(parse_args())
