#!/usr/bin/env python3
"""
Command-line script to apply trained steering vectors to a user prompt.

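This script loads trained steering vectors (together with the `config.yaml` that must sit in
the same directory as the vectors file) and applies one of them to a user-provided prompt,
printing an unsteered (baseline) generation followed by a steered one for comparison.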

Usage:
    python scripts/run_steering.py --vectors path/to/steering_vectors.pt --vector-idx 42 --prompt "Your prompt here"
    python scripts/run_steering.py --vectors path/to/steering_vectors.pt --vector-idx 42 --coefficient 1.5 --prompt "Your prompt here"
"""

import argparse
import logging
import random
import sys
from pathlib import Path
from typing import Dict, Any

import torch
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
import torch.nn as nn

# Make the repository's `src` directory (located next to this script's parent
# directory) importable, so `steering` resolves when the script is run directly
# without installing the package. Inserting at index 0 gives it precedence over
# other entries on sys.path.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from steering import ActivationSteering


def extract_continuations(
    generated_ids: torch.Tensor, input_ids: torch.Tensor
) -> list[list[int]]:
    """Split newly generated token ids off the prompt, one list per batch row.

    Every row of ``generated_ids`` is taken to begin with the full (possibly
    padded) prompt, so the continuation is whatever lies past column
    ``input_ids.shape[1]``.

    Args:
        generated_ids: 2-D ``(B, T_out)`` tensor of token ids from generation.
        input_ids: 2-D ``(B, T_in)`` tensor of the prompt token ids.

    Returns:
        ``B`` Python lists, each holding the continuation token ids of one row.
    """
    # Shape sanity checks (note: assertions are skipped under ``python -O``).
    assert generated_ids.dim() == 2 and input_ids.dim() == 2, "expected (B, T) tensors"
    assert generated_ids.shape[0] == input_ids.shape[0], "batch sizes must match"

    prompt_width = int(input_ids.shape[1])
    # Iterating a 2-D tensor yields its rows, so each row is one sequence's tail.
    return [row.tolist() for row in generated_ids[:, prompt_width:]]


# Module-wide logging: records at INFO and above, formatted as
# "<timestamp> - <LEVEL> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def get_module(model: PreTrainedModel, layer_idx: int) -> nn.Module:
    """Look up the transformer block at index ``layer_idx`` inside ``model``.

    Supported layouts, tried in order:
      * ``model.model.layers``  (Llama/Qwen style models)
      * ``model.transformer.h`` (GPT style models)

    Raises:
        ValueError: if ``model`` matches neither layout.
    """
    for outer_name, stack_name in (("model", "layers"), ("transformer", "h")):
        outer = getattr(model, outer_name, None)
        if outer is not None and hasattr(outer, stack_name):
            return getattr(outer, stack_name)[layer_idx]
    raise ValueError(f"Unsupported model architecture: {type(model)}")


def get_submodule(block: nn.Module, submodule_path: str | None) -> nn.Module:
    """Resolve a dotted submodule path within a transformer block, or return block if None."""
    if submodule_path is None or submodule_path == "":
        return block
    cur: nn.Module = block
    for name in submodule_path.split("."):
        if not hasattr(cur, name):
            raise ValueError(
                f"Block {type(block)} does not have submodule '{submodule_path}' (missing '{name}')"
            )
        cur = getattr(cur, name)
    if not isinstance(cur, nn.Module):
        raise ValueError(
            f"Resolved path '{submodule_path}' is not a module: {type(cur)}"
        )
    return cur


def load_config(config_path: Path) -> Dict[str, Any]:
    """Load the YAML configuration stored alongside the steering vectors.

    Parsing is first attempted with ``yaml.safe_load``. If the file contains
    Python-specific tags (``yaml.constructor.ConstructorError``), it is re-read
    with ``yaml.unsafe_load`` and a warning is logged.

    SECURITY: ``yaml.unsafe_load`` can construct arbitrary Python objects and
    therefore run code embedded in the file. Only use config files you trust.

    Every top-level value that implements ``__fspath__`` (e.g. ``pathlib.Path``)
    is converted to ``str`` for compatibility.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The parsed configuration mapping.

    Raises:
        ValueError: if the file is empty or its top level is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        try:
            # Try safe loading first
            config = yaml.safe_load(f)
        except yaml.constructor.ConstructorError:
            # Python object tags (e.g. dumped Path objects) need the unsafe
            # loader; make the fallback visible since it can execute code.
            logging.warning(
                "Config %s contains Python-specific YAML tags; falling back to "
                "yaml.unsafe_load. Only load config files you trust.",
                config_path,
            )
            f.seek(0)
            config = yaml.unsafe_load(f)

    # safe_load returns None for an empty document; fail with a clear message
    # instead of a TypeError further down.
    if config is None:
        raise ValueError(f"Config file is empty: {config_path}")
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path} must contain a mapping at the top level, "
            f"got {type(config).__name__}"
        )

    # Convert any Path-like values (not only output_dir) to strings.
    return {
        key: str(value) if hasattr(value, "__fspath__") else value
        for key, value in config.items()
    }


def setup_model_and_tokenizer(
    config: Dict[str, Any], device: str
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the causal LM and tokenizer named by ``config["model_name"]``.

    The tokenizer is loaded first; when it defines no pad token, its EOS token
    is reused for padding. The model is loaded in bfloat16 with
    ``device_map=device``, put in eval mode, and every parameter is frozen
    (``requires_grad`` disabled).

    NOTE(review): ``trust_remote_code=True`` allows the checkpoint repository to
    run its own Python code; only use model names you trust.

    Returns:
        ``(model, tokenizer)``.
    """
    model_name = config["model_name"]
    logging.info(f"Loading model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Ensure a pad token (and thus pad_token_id) exists; reuse EOS for it.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    model.eval()
    # Freeze every parameter in one call.
    model.requires_grad_(False)

    return model, tokenizer


def setup_steering_hook(
    model: AutoModelForCausalLM, steering_vectors: torch.Tensor, config: Dict[str, Any]
) -> ActivationSteering:
    """Build an ``ActivationSteering`` hook for the module selected by ``config``.

    ``config["source_layer"]`` chooses the transformer block, and the optional
    ``config["source_submodule"]`` dotted path chooses a child of that block
    (absent/None/empty means the block itself).

    Args:
        model: The loaded causal LM.
        steering_vectors: Bank of steering vectors handed to the hook.
        config: Training configuration loaded from ``config.yaml``.

    Returns:
        The constructed ``ActivationSteering`` instance.
    """
    layer = config["source_layer"]
    sub_path = config.get("source_submodule")

    target = get_submodule(get_module(model, layer), sub_path)
    hook = ActivationSteering(
        source_module=target,
        steering_vector_bank=steering_vectors,
    )

    logging.info(f"Setup steering hook on layer {layer}, submodule: {sub_path}")
    return hook


def tokenize_prompt(
    prompt: str, tokenizer: AutoTokenizer, max_length: int = 512
) -> Dict[str, torch.Tensor]:
    """Tokenize a single user prompt for generation.

    When the tokenizer's ``name_or_path`` contains "chat" or "instruct" and the
    tokenizer provides a chat template, the prompt is first rendered as a single
    user message with the generation prompt appended. That rendered text is
    tokenized with ``add_special_tokens=False``: chat templates are expected to
    emit the special tokens they need (e.g. BOS) themselves, and letting the
    tokenizer add them again would duplicate them.

    Args:
        prompt: Raw user prompt text.
        tokenizer: Hugging Face tokenizer for the model.
        max_length: Maximum prompt length in tokens; longer prompts are
            truncated. NOTE(review): truncation follows the tokenizer's
            ``truncation_side``; on a right-truncating tokenizer this can cut
            off the appended generation prompt of a very long input.

    Returns:
        The tokenizer output with PyTorch tensors (unpadded), e.g. ``input_ids``.
    """
    # name_or_path may be a hub id or a local path; guard against it being unset.
    name = (getattr(tokenizer, "name_or_path", None) or "").lower()
    is_chat_model = "chat" in name or "instruct" in name
    has_template = hasattr(tokenizer, "apply_chat_template") and bool(
        getattr(tokenizer, "chat_template", None)
    )

    rendered_with_template = False
    if is_chat_model and has_template:
        messages = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        rendered_with_template = True

    return tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=max_length,
        # Avoid a duplicated BOS/special prefix on template-rendered text.
        add_special_tokens=not rendered_with_template,
    )


def generate_with_steering(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    steering_hook: ActivationSteering,
    prompt_tokens: Dict[str, torch.Tensor],
    vector_idx: int | None = None,
    coefficient: float = 1.0,
    **generation_kwargs,
) -> str:
    """Run ``model.generate`` on one prompt, optionally with steering enabled.

    With ``vector_idx`` set, the hook is pointed at that vector and scaled by
    ``coefficient``; with ``vector_idx=None`` steering is cleared so the run is
    a plain baseline. Only the first sequence's continuation is decoded.

    Args:
        model: The causal LM to generate with.
        tokenizer: Tokenizer used to decode the continuation.
        steering_hook: Hook whose steering state is configured before generating.
        prompt_tokens: Tokenized prompt; tensors are moved to the model's device.
        vector_idx: Index of the steering vector to apply, or None for no steering.
        coefficient: Steering strength (used only when ``vector_idx`` is set).
        **generation_kwargs: Forwarded to ``model.generate``.

    Returns:
        The decoded continuation text with special tokens skipped.
    """
    target_device = next(model.parameters()).device
    inputs = {name: tensor.to(target_device) for name, tensor in prompt_tokens.items()}

    if vector_idx is None:
        steering_hook.clear_steering()
    else:
        steering_hook.set_vector_idxs(torch.tensor([vector_idx], device=target_device))
        steering_hook.set_coefficient(coefficient)

    with torch.inference_mode():
        output_ids = model.generate(**inputs, **generation_kwargs)

    new_tokens = extract_continuations(output_ids, inputs["input_ids"])[0]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def _parse_args() -> argparse.Namespace:
    """Define and parse this script's command-line interface."""
    parser = argparse.ArgumentParser(description="Run steering vectors on user prompts")
    parser.add_argument(
        "--vectors", required=True, type=Path, help="Path to steering_vectors.pt file"
    )
    parser.add_argument(
        "--vector-idx", required=True, type=int, help="Index of steering vector to use"
    )
    parser.add_argument(
        "--prompt", required=True, type=str, help="User prompt to generate from"
    )
    parser.add_argument(
        "--coefficient",
        type=float,
        default=1.0,
        help="Steering coefficient (default: 1.0)",
    )
    parser.add_argument(
        "--max-new-tokens", type=int, default=128, help="Maximum new tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Generation temperature (only used with --do-sample)",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.95,
        help="Generation top-p (only used with --do-sample)",
    )
    parser.add_argument(
        "--do-sample", action="store_true", help="Use sampling for generation"
    )
    parser.add_argument(
        "--device", type=str, default="auto", help="Device to use (auto/cuda/cpu)"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    return parser.parse_args()


def _seed_everything(seed: int) -> None:
    """Seed Python's ``random`` and torch (CPU and, if available, all CUDA devices)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _load_steering_vectors(vectors_path: Path, device: str) -> torch.Tensor:
    """Load the steering-vector bank and check it is a tensor with >= 2 dims.

    NOTE(review): ``torch.load`` unpickles the file. On older torch releases
    (where ``weights_only`` does not default to True) this can execute
    arbitrary code, so only load vector files you trust.

    Raises:
        ValueError: if the file does not hold a tensor of at least two dims.
    """
    logging.info(f"Loading steering vectors from: {vectors_path}")
    steering_vectors = torch.load(vectors_path, map_location=device)
    # Later code reads shape[0] (number of vectors) and shape[1] (vector size).
    if not isinstance(steering_vectors, torch.Tensor) or steering_vectors.dim() < 2:
        raise ValueError(
            f"Expected a (num_vectors, hidden_dim) tensor in {vectors_path}, "
            f"got {type(steering_vectors).__name__}"
            + (
                f" with shape {tuple(steering_vectors.shape)}"
                if isinstance(steering_vectors, torch.Tensor)
                else ""
            )
        )
    return steering_vectors


def main():
    """Print a baseline and a steered generation for one prompt.

    Reads ``config.yaml`` from the directory containing ``--vectors``, loads the
    model and steering vectors, and generates twice from the same prompt: once
    with steering cleared and once with vector ``--vector-idx`` scaled by
    ``--coefficient``. RNGs are re-seeded with ``--seed`` before each generation
    so both runs start from the same random state.

    Raises:
        FileNotFoundError: if the vectors file or its sibling config is missing.
        ValueError: if the vectors file is malformed or ``--vector-idx`` is out
            of range.
    """
    args = _parse_args()

    # Set random seeds for reproducibility
    _seed_everything(args.seed)

    # Validate inputs: the training config must sit next to the vectors file.
    vectors_path = Path(args.vectors)
    if not vectors_path.exists():
        raise FileNotFoundError(f"Steering vectors file not found: {vectors_path}")

    config_path = vectors_path.parent / "config.yaml"
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # Determine device ("auto" prefers CUDA when available)
    device = args.device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info(f"Using device: {device}")

    logging.info(f"Loading config from: {config_path}")
    config = load_config(config_path)

    steering_vectors = _load_steering_vectors(vectors_path, device)
    num_vectors = steering_vectors.shape[0]
    vector_dim = steering_vectors.shape[1]

    if not 0 <= args.vector_idx < num_vectors:
        raise ValueError(
            f"Vector index {args.vector_idx} out of range [0, {num_vectors - 1}]"
        )

    logging.info(f"Loaded {num_vectors} steering vectors of dimension {vector_dim}")
    logging.info(f"Using vector {args.vector_idx} with coefficient {args.coefficient}")

    model, tokenizer = setup_model_and_tokenizer(config, device)
    steering_hook = setup_steering_hook(model, steering_vectors, config)

    try:
        logging.info(f"Tokenizing prompt: {args.prompt}")
        prompt_tokens = tokenize_prompt(args.prompt, tokenizer)

        generation_kwargs = {
            "max_new_tokens": args.max_new_tokens,
            "do_sample": args.do_sample,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "use_cache": True,
        }
        if args.do_sample:
            # temperature/top_p only matter when sampling; recent transformers
            # versions warn if they are set for greedy decoding.
            generation_kwargs["temperature"] = args.temperature
            generation_kwargs["top_p"] = args.top_p

        # A null/empty source_submodule means the whole block is hooked.
        source_submodule = config.get("source_submodule") or "full module"

        print("\n" + "=" * 80)
        print("STEERING VECTOR GENERATION COMPARISON")
        print("=" * 80)
        print(f"Model: {config['model_name']}")
        print(f"Vector Index: {args.vector_idx}")
        print(f"Steering Coefficient: {args.coefficient}")
        print(f"Source Layer: {config['source_layer']} ({source_submodule})")
        print("=" * 80)

        print(f"\nPROMPT:\n{args.prompt}")

        # Generate without steering (baseline). Re-seed before each run so the
        # baseline and steered generations are compared under identical seeding
        # rather than the second run inheriting the RNG state left by the first.
        print("\nBASELINE GENERATION (no steering):")
        print("-" * 50)
        _seed_everything(args.seed)
        baseline_text = generate_with_steering(
            model,
            tokenizer,
            steering_hook,
            prompt_tokens,
            vector_idx=None,
            **generation_kwargs,
        )
        print(baseline_text)

        # Generate with steering
        print(f"\nSTEERED GENERATION (vector {args.vector_idx}, coeff={args.coefficient}):")
        print("-" * 50)
        _seed_everything(args.seed)
        steered_text = generate_with_steering(
            model,
            tokenizer,
            steering_hook,
            prompt_tokens,
            vector_idx=args.vector_idx,
            coefficient=args.coefficient,
            **generation_kwargs,
        )
        print(steered_text)

        print("\n" + "=" * 80)
        print("GENERATION COMPLETE")
        print("=" * 80)
    finally:
        # Always detach the hook, even if tokenization or generation fails.
        steering_hook.clear_hook()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
