# %%
from pathlib import Path
from typing import Dict, List, Optional, Union

import h5py
import numpy as np
import torch
from datasets import Dataset, DatasetDict, load_dataset
from peft import AutoPeftModelForCausalLM
from sandbagging_research_sprint.utils import load_hf_model
from torch.utils.data import DataLoader

# %%
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

from eliciting_contexts.utils.constants import DATA_DIR, DEVICE


def get_layer_indices(model: "AutoModelForCausalLM") -> List[int]:
    """Return the index of every decoder layer in ``model``.

    The layers are read from ``model.model.layers``; the result is
    ``[0, 1, ..., len(layers) - 1]`` (all layers, not a strided subset).
    """
    return list(range(len(model.model.layers)))


# Legacy module-level slot. ``collect_activations`` used to publish each batch's
# system-prompt positions here so that its forward hooks could read them; it now keeps
# that state in a closure and neither reads nor writes this name. The name is retained
# only so that existing imports of it do not break.
indices_system_prompts = []


def collect_activations(
    model: Union[AutoModelForCausalLM, AutoPeftModelForCausalLM],
    tokenizer: Union[AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast],
    dataset: Union[DatasetDict, Dataset],
    batch_size: int = 8,
    max_length: int = 512,
    num_samples: Optional[int] = None,
    layer_indices: Optional[List[int]] = None,
    act_position: int = -1,  # Position of the activation to collect
    extract_at_system_prompt_position: bool = False,
) -> Dict:
    """Collect one activation vector per sample from each requested decoder layer.

    A forward hook is attached to ``model.model.layers[i]`` for every ``i`` in
    ``layer_indices``. Each batch of ``dataset`` is tokenized (padded, truncated to
    ``max_length``) and run through the model under ``torch.no_grad()``; the hooks copy
    the selected token's hidden state to the CPU as float32. All hooks are removed
    before the function returns, including when an exception is raised.

    Args:
        model: Model whose decoder layers are reachable as ``model.model.layers``.
        tokenizer: Tokenizer applied to ``batch["prompt"]``.
        dataset: Rows must provide ``prompt``, ``correct_answer``, ``desired_answer``
            and ``sandbagging_environment``; ``system_prompt_tokens`` is also read when
            ``extract_at_system_prompt_position`` is true.
        batch_size: Batch size handed to the ``DataLoader``.
        max_length: Truncation length handed to the tokenizer.
        num_samples: Maximum number of samples to return. ``None`` or ``0`` disables
            the limit. Processing stops after the batch that reaches the limit and the
            results are trimmed to exactly this many rows.
        layer_indices: Layers to hook. Defaults to ``get_layer_indices(model)``.
        act_position: Sequence index whose hidden state is stored. Not used when
            ``extract_at_system_prompt_position`` is true.
        extract_at_system_prompt_position: When true, store for each sample the hidden
            state at the position given by that sample's ``system_prompt_tokens`` value.

    Returns:
        A dict with three keys: ``"activations"`` maps each layer index to a float32
        numpy array whose first axis indexes samples; ``"metadata"`` is a list with one
        dict per sample; ``"layer_indices"`` is the list of hooked layers.
    """
    if layer_indices is None:
        layer_indices = get_layer_indices(model)

    dataloader = DataLoader(dataset, batch_size=batch_size)
    device = next(model.parameters()).device

    # Per-layer list of per-batch activation tensors, plus one metadata dict per sample
    activations_by_layer = {layer_idx: [] for layer_idx in layer_indices}
    metadata = []

    # Every metadata row is stamped with a single layer index. Activations are collected
    # for all of ``layer_indices``, so this is the last requested layer (the value the
    # rows have always carried), or None when no layer was requested.
    recorded_layer_idx = layer_indices[-1] if layer_indices else None

    # Token positions for the batch currently being run; written by the batch loop just
    # before the forward pass and read by the system-prompt hooks during it.
    batch_state = {"system_prompt_positions": None}

    def hidden_states_of(layer_output):
        # Decoder layers return either a tuple/list whose first element is the hidden
        # states or, in some library versions, the hidden-state tensor itself.
        if isinstance(layer_output, (tuple, list)):
            return layer_output[0]
        return layer_output

    def create_hook(layer_idx):
        def hook_fn(module, args, output):
            hidden = hidden_states_of(output)
            # NOTE(review): with right-padded batches ``act_position=-1`` selects a pad
            # token for the shorter prompts - confirm the tokenizer pads on the left.
            selected = hidden[:, act_position, :].to(dtype=torch.float32).detach().cpu()
            activations_by_layer[layer_idx].append(selected)

        return hook_fn

    def create_system_prompt_hook(layer_idx):
        def hook_fn(module, args, output):
            hidden = hidden_states_of(output)
            # Build both index tensors on the layer's own device so the gather works
            # even when layers are spread across devices.
            positions = torch.as_tensor(batch_state["system_prompt_positions"], device=hidden.device)
            rows = torch.arange(hidden.shape[0], device=hidden.device)
            # NOTE(review): the positions are used as-is; if they were computed on
            # unpadded prompts and the batch is left-padded they will be shifted -
            # verify against how ``system_prompt_tokens`` was produced.
            selected = hidden[rows, positions].to(dtype=torch.float32).detach().cpu()
            activations_by_layer[layer_idx].append(selected)

        return hook_fn

    make_hook = create_system_prompt_hook if extract_at_system_prompt_position else create_hook
    hooks = []

    try:
        # Register inside the try block so that a failing registration (e.g. a bad
        # layer index) still removes the hooks that were already attached.
        for layer_idx in layer_indices:
            hooks.append(model.model.layers[layer_idx].register_forward_hook(make_hook(layer_idx)))

        samples_processed = 0
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Collecting activations"):
                prompts = batch["prompt"]

                # Tokenize inputs and move them to the model's device
                encoded = tokenizer(
                    prompts,
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt",
                )
                inputs = {k: v.to(device) for k, v in encoded.items()}

                if extract_at_system_prompt_position:
                    batch_state["system_prompt_positions"] = batch["system_prompt_tokens"]

                # Forward pass (the hooks fire here)
                outputs = model(**inputs)

                # Greedy prediction read from sequence position -2.
                # NOTE(review): -2 is the second-to-last position, not the last one;
                # presumably the prompts end with one trailing token - confirm.
                answer_logits = outputs.logits[:, -2, :]
                predicted_tokens = torch.argmax(answer_logits, dim=-1)
                predicted_answers = [tokenizer.decode(token).strip() for token in predicted_tokens]

                # Case-insensitive substring match. An empty prediction is a substring
                # of every string, so it is explicitly treated as incorrect.
                is_correct = [
                    bool(pred) and pred.lower() in answer.lower()
                    for pred, answer in zip(predicted_answers, batch["correct_answer"])
                ]

                # Store metadata for each example
                for i, prompt in enumerate(prompts):
                    metadata.append(
                        {
                            "sandbagging_environment": batch["sandbagging_environment"][i],
                            "desired_answer": batch["desired_answer"][i],
                            "correct_answer": batch["correct_answer"][i],
                            "predicted_answer": predicted_answers[i],
                            "is_correct": is_correct[i],
                            "prompt": prompt,
                            "act_layer_idx": recorded_layer_idx,
                            "act_position": act_position,
                            "extract_at_system_prompt_position": extract_at_system_prompt_position,
                        }
                    )

                samples_processed += len(prompts)

                # Release the batch before deciding whether to stop, so the final
                # batch is cleaned up as well.
                del inputs, outputs
                torch.cuda.empty_cache()

                if num_samples and samples_processed >= num_samples:
                    break

    finally:
        # Remove hooks
        for hook in hooks:
            hook.remove()

    # The loop stops on a batch boundary and may overshoot; trim to the requested size.
    # A slice bound of None keeps everything.
    limit = num_samples if num_samples else None

    # Concatenate the per-batch tensors (already float32 on CPU) into numpy arrays
    for layer_idx in layer_indices:
        stacked = torch.cat(activations_by_layer[layer_idx], dim=0)
        activations_by_layer[layer_idx] = stacked[:limit].numpy()

    return {
        "activations": activations_by_layer,
        "metadata": metadata[:limit],
        "layer_indices": layer_indices,
    }


def save_activations(activations_dict: Dict, output_path: Union[str, Path]):
    """Write activations and their metadata into one HDF5 file.

    The file is created at ``output_path`` with its suffix replaced by ``.h5``; missing
    parent directories are created first. The file contains:

    * ``layer_indices`` - dataset built from ``activations_dict["layer_indices"]``
    * ``activations/<layer>`` - one dataset per entry of ``activations_dict["activations"]``
    * ``metadata`` - either the metadata fields directly (metadata given as a list), or
      the sub-groups ``pos_metadata`` and ``neg_metadata`` (metadata given as a dict,
      as for difference activations).
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    h5_filename = str(destination.with_suffix(".h5"))

    with h5py.File(h5_filename, "w") as f:
        # Layer indices first, then one dataset per layer
        f.create_dataset("layer_indices", data=np.array(activations_dict["layer_indices"]))

        layer_group = f.create_group("activations")
        for layer_idx, layer_acts in activations_dict["activations"].items():
            layer_group.create_dataset(str(layer_idx), data=layer_acts)

        # Metadata layout depends on whether it is a flat list or a pos/neg dict
        metadata = activations_dict["metadata"]
        meta_group = f.create_group("metadata")
        if not isinstance(metadata, dict):
            # Single-type activations: fields are stored directly under "metadata"
            _save_metadata_group(meta_group, metadata)
        else:
            # Difference activations: one sub-group per polarity
            for meta_type in ("pos_metadata", "neg_metadata"):
                _save_metadata_group(meta_group.create_group(meta_type), metadata[meta_type])


def _save_metadata_group(group, metadata_list):
    """Helper function to save metadata fields to an HDF5 group."""
    if not metadata_list:  # Handle empty metadata case
        return

    # Get all fields from first metadata entry
    fields = metadata_list[0].keys()

    for field in fields:
        data = [m[field] for m in metadata_list]

        # Convert data if needed and determine dtype
        if isinstance(data[0], bool):
            dtype = bool
        elif isinstance(data[0], (int, np.integer)):
            dtype = int
        elif isinstance(data[0], (float, np.floating)):
            dtype = float
        elif isinstance(data[0], (str, np.str_)):
            dtype = h5py.special_dtype(vlen=str)
        else:
            # Convert non-string objects to strings
            data = [str(item) for item in data]
            dtype = h5py.special_dtype(vlen=str)

        # Convert to numpy array with proper dtype
        if dtype == h5py.special_dtype(vlen=str):
            data = np.array(data, dtype=object)
        else:
            data = np.array(data, dtype=dtype)

        group.create_dataset(field, data=data)


def main():
    """Collect and save activations for every configured model and dataset split.

    Each model is run on the ``test`` and ``validation`` splits of the
    ``contextmodification/sandbagging-sciq`` dataset via ``collect_activations``; the
    result is passed to ``save_activations`` with a target path of
    ``DATA_DIR / activations/<last component of the model name>/<split>``.
    """
    # Checkpoints to process (currently only the sandbagging model)
    model_ids = ["AtticusAttlee/train-sandbagging-f1dpbban-step11264"]
    split_names = ("test", "validation")

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
    # Batches are padded, so fall back to the EOS token when no pad token is defined
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset("contextmodification/sandbagging-sciq")
    assert isinstance(dataset, DatasetDict)

    for model_name in model_ids:
        print(f"\nProcessing model: {model_name}")

        # Load the model and make its padding id agree with the tokenizer
        model = load_hf_model(model_name, device_map=DEVICE, attn_implementation="eager")
        model.config.pad_token_id = tokenizer.pad_token_id

        for split in split_names:
            print(f"\nProcessing {split} split")

            split_result = collect_activations(
                model=model,
                tokenizer=tokenizer,
                dataset=dataset[split],
                batch_size=8,
                num_samples=1000,  # Cap on the number of samples taken from the split
                max_length=799,
            )

            # Write the collected activations for this model/split pair
            output_path = DATA_DIR / Path(f"activations/{model_name.split('/')[-1]}/{split}")
            print(f"Saving activations to {output_path}")
            save_activations(split_result, output_path)

        # Drop the model before the next one is loaded
        del model
        torch.cuda.empty_cache()


# Run the collection pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
