import os
import json
import argparse
from pathlib import Path
from typing import Dict, List, Optional

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def check_gpu_availability() -> None:
    """Print how many CUDA devices PyTorch reports, then the name of each one."""
    num_devices = torch.cuda.device_count()

    # Guard clause: nothing to enumerate when no device is visible.
    if num_devices <= 0:
        print("No GPU found on this machine.")
        return

    print(f"Found {num_devices} GPU(s) on this machine.")
    for device_index in range(num_devices):
        device_name = torch.cuda.get_device_name(device_index)
        print(f"GPU {device_index}: {device_name}")


def load_model_and_tokenizer(model_name: str):
    """
    Load a causal language model and its tokenizer by name.

    Args:
        model_name: Identifier passed unchanged to both `from_pretrained` calls.

    Returns:
        A `(model, tokenizer)` tuple. The model is loaded with
        `device_map='auto'` and switched to evaluation mode via `.eval()`.
    """
    print(f"Loading model and tokenizer: {model_name}")

    # Model first, tokenizer second; `.eval()` returns the module itself.
    causal_lm = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')
    causal_lm = causal_lm.eval()
    text_tokenizer = AutoTokenizer.from_pretrained(model_name)

    return causal_lm, text_tokenizer


def get_model_layers(model) -> int:
    """
    Return the layer count recorded on `model.config`.

    The config is probed for 'num_hidden_layers' first and 'n_layer' second;
    the first attribute that exists wins.

    Raises:
        AttributeError: If the config exposes neither attribute.
    """
    config = model.config
    # Order matters: 'num_hidden_layers' takes precedence when both exist.
    for attr_name in ('num_hidden_layers', 'n_layer'):
        if hasattr(config, attr_name):
            return getattr(config, attr_name)
    raise AttributeError("Model config missing 'num_hidden_layers' or 'n_layer' attribute.")


class CounterFactDataset(Dataset):
    """Thin wrapper to load the CounterFact JSON and optionally truncate size."""

    def __init__(self, data_path: str, size: Optional[int] = None):
        """
        Read the JSON file at `data_path` into memory.

        Args:
            data_path: Path to the JSON file; its parsed content is stored in
                `self.data` and is sliced/indexed like a sequence.
            size: If given, keep only `self.data[:size]`; `None` keeps everything.
        """
        data_path = Path(data_path)
        # Decode explicitly as UTF-8: without `encoding`, open() falls back to
        # the locale's preferred encoding, which can fail or mis-decode
        # non-ASCII text on systems whose locale is not UTF-8.
        with data_path.open("r", encoding="utf-8") as f:
            self.data = json.load(f)
        if size is not None:
            self.data = self.data[:size]
        print(f"Loaded CounterFact dataset from {data_path} with {len(self)} elements.")

    def __len__(self) -> int:
        """Return the number of loaded (possibly truncated) records."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return the record stored at position `idx`."""
        return self.data[idx]
    

def gather_all_activations(
    subjects: List[str],
    model,
    tokenizer,
    pos: str = "down_proj_input",
    device: str = "cuda",
) -> Dict[str, List[torch.Tensor]]:
    """
    Collect last-token activations for each subject across all layers.

    A forward hook is registered on one module per layer of
    `model.model.layers`; each subject is tokenized and run through the model
    once, and the hooked tensor's `[0, -1, :]` slice is recorded per layer.
    Hooks are always removed before returning, even if a forward pass fails.

    Args:
        subjects: List of text prompts (subjects). Repeated entries share one
            key in the result, so each distinct subject is processed once.
        model: Loaded HF model exposing `model.model.layers[i]`.
        tokenizer: Loaded HF tokenizer.
        pos:
            - "down_proj_input"  : input to MLP down_proj (use module inputs)
            - "down_proj_output" : output of MLP down_proj (use module outputs)
            - "post_feedforward_layernorm_output": output of post-FFN LayerNorm
        device: Target device for input tensors.

    Returns:
        Dict mapping subject -> list of tensors [layer0_vec, layer1_vec, ...].
        Each tensor is 1D (the feature dimension of the tapped tensor), taken
        from the last token position, and is returned via `.cpu()`.

    Raises:
        ValueError: If `pos` is not one of the supported positions.
    """
    supported_positions = (
        "down_proj_input",
        "down_proj_output",
        "post_feedforward_layernorm_output",
    )
    # Validate once, up front, instead of inside every hook call.
    if pos not in supported_positions:
        raise ValueError(f"Unsupported position: {pos}")

    # Only "down_proj_input" reads the module's input; the others read its output.
    capture_input = pos == "down_proj_input"
    tap_layernorm = pos == "post_feedforward_layernorm_output"

    num_layers = get_model_layers(model)

    # Duplicates would map to the same key anyway; dict.fromkeys keeps the
    # first-seen order while dropping the redundant forward passes.
    unique_subjects = list(dict.fromkeys(subjects))
    activations: Dict[str, List[torch.Tensor]] = {subject: [] for subject in unique_subjects}
    last_token_by_layer: Dict[int, torch.Tensor] = {}

    def make_hook(layer_idx: int):
        def hook(mod, inputs, outputs):
            tensor = inputs[0] if capture_input else outputs
            # Keep just batch item 0's last-token vector. clone() gives it its
            # own storage so the small vector does not keep the full
            # activation alive (a plain slice is a view of the whole tensor,
            # and `.cpu()` is a no-op for tensors already on the CPU).
            last_token_by_layer[layer_idx] = tensor.detach()[0, -1, :].clone()
            # Returning None leaves the module's output untouched.
        return hook

    handles = []
    try:
        # Register one hook per layer on the module selected by `pos`.
        for layer_idx in range(num_layers):
            layer = model.model.layers[layer_idx]
            module = layer.post_feedforward_layernorm if tap_layernorm else layer.mlp.down_proj
            handles.append(module.register_forward_hook(make_hook(layer_idx)))

        # Run forward passes and capture the last-token vector per layer.
        for subject in tqdm(unique_subjects, desc="Processing subjects"):
            inputs = tokenizer.encode(subject, return_tensors="pt", add_special_tokens=True).to(device)
            # Drop the previous subject's vectors so a hook that did not fire
            # surfaces as a KeyError below rather than as stale data.
            last_token_by_layer.clear()
            with torch.no_grad():
                model(inputs, use_cache=False)

            activations[subject] = [
                last_token_by_layer[layer_idx].cpu()
                for layer_idx in range(num_layers)
            ]
            torch.cuda.empty_cache()
    finally:
        for handle in handles:
            handle.remove()

    return activations


def main(args):
    """
    Extract activations for every dataset subject and save them one file per layer.

    Args:
        args: Parsed CLI namespace providing `model_name`, `dataset` and `pos`.

    Output is written to `data/Representation/<pos>/<model_name>/<dataset>/`
    as `layer_<i>.pt`, each holding a dict of subject -> activation vector.
    The model name is joined into the path as-is, so an id containing '/'
    produces nested directories.

    Raises:
        ValueError: If `args.dataset` is not a supported dataset name.
    """
    check_gpu_availability()

    # Load dataset and construct subject list. This happens before the model
    # is loaded so that an unsupported dataset name fails immediately instead
    # of after an expensive model load.
    if args.dataset == "Counterfact":
        cf_data = CounterFactDataset("data/counterfact.json")
        subjects = [record["requested_rewrite"]["subject"] for record in cf_data.data]
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    model, tokenizer = load_model_and_tokenizer(args.model_name)

    # Fall back to the CPU when CUDA is unavailable; sending inputs to "cuda"
    # unconditionally crashes on the machines for which
    # check_gpu_availability() has just reported "No GPU found".
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Collect activations
    acts = gather_all_activations(subjects, model, tokenizer, pos=args.pos, device=device)

    # Prepare output directory
    out_dir = os.path.join("data", "Representation", args.pos, args.model_name, args.dataset)
    os.makedirs(out_dir, exist_ok=True)

    # Split by layer and save as separate files
    num_layers = get_model_layers(model)
    per_layer: List[Dict[str, torch.Tensor]] = [{} for _ in range(num_layers)]
    for key, vec_list in acts.items():
        if len(vec_list) != num_layers:
            print(f"Warning: key '{key}' has {len(vec_list)} items (expected {num_layers}). Skipped.")
            continue
        # Lengths match here, so zip pairs layer i's dict with layer i's vector.
        for layer_dict, vec in zip(per_layer, vec_list):
            layer_dict[key] = vec

    for i, layer_dict in enumerate(per_layer):
        save_path = os.path.join(out_dir, f"layer_{i}.pt")
        torch.save(layer_dict, save_path)
        print(f"Saved: {save_path}")

    print("All layers saved successfully.")


if __name__ == "__main__":
    # Script entry point: declare the CLI options in one table, register them,
    # then hand the parsed namespace to main().
    parser = argparse.ArgumentParser(
        description="Extract per-layer activations for given subjects.",
    )
    cli_options = (
        (
            "--model_name",
            {
                "type": str,
                "default": "google/gemma-2-2b",
                "help": "Hugging Face model id (e.g., google/gemma-2-2b).",
            },
        ),
        (
            "--dataset",
            {
                "type": str,
                "default": "Counterfact",
                "help": "Dataset to use.",
            },
        ),
        (
            "--pos",
            {
                "type": str,
                "default": "down_proj_input",
                "choices": ["down_proj_input", "down_proj_output", "post_feedforward_layernorm_output"],
                "help": "Which module position to tap activations from.",
            },
        ),
    )
    for option_flag, option_spec in cli_options:
        parser.add_argument(option_flag, **option_spec)

    args = parser.parse_args()
    main(args)

# Examples:
# python3 Collect_Representation_Counterfact.py --model_name=google/gemma-2-9b --dataset=Counterfact --pos=down_proj_input
# CUDA_VISIBLE_DEVICES=8,9 python3 Collect_Representation_Counterfact.py --model_name=google/gemma-2-9b --dataset=Counterfact --pos=post_feedforward_layernorm_output
# python3 Collect_Representation_Counterfact.py --model_name=meta-llama/Meta-Llama-3.1-8B --dataset=Counterfact --pos=down_proj_input
# CUDA_VISIBLE_DEVICES=8,9 python3 Collect_Representation_Counterfact.py --model_name=meta-llama/Meta-Llama-3.1-8B --dataset=Counterfact --pos=down_proj_output
