"""Model-agnostic utilities for steering."""

import torch
import torch.nn as nn

from typing import Callable


class ActivationSteering:
    """Applies steering vectors persistently to a module's activations."""

    def __init__(
        self,
        source_module: nn.Module,
        steering_vector_bank: torch.Tensor,
        token_idxs: int | list[int] | None = None,
    ):
        self.source_module = source_module
        self.steering_vector_bank = steering_vector_bank  # (num_vectors, hidden_dim)
        self.vector_idxs: torch.Tensor | None = None
        self.hook: torch.utils.hooks.RemovableHandle | None = None
        self.steering_coefficient = 1.0
        self.token_idxs = token_idxs  # Token positions to apply steering to

        # Register persistent hook (archive style)
        self.register_hook()

    def clear_steering(self) -> None:
        """Clear current steering vectors (archive-style)."""
        self.vector_idxs = None

    def _get_steering_hook(self) -> Callable:
        """Create a hook function to apply steering (archive-style: simple and effective)."""

        def hook(module, input, output):
            if self.vector_idxs is None:
                return output

            # Extract the actual tensor from output
            if isinstance(output, tuple):
                hidden_states = output[0]
                is_tuple_output = True
            else:
                hidden_states = output
                is_tuple_output = False

            # Apply steering to specific token positions
            batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1]
            if batch_size != self.vector_idxs.shape[0]:
                return output  # Batch size mismatch, skip steering

            # Select steering vectors for this batch
            selected_vectors = self.steering_vector_bank[
                self.vector_idxs
            ]  # (batch, hidden_dim)

            # Cast to match hidden states dtype/device
            selected_vectors = selected_vectors.to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )

            # Create a copy to modify
            steered_states = hidden_states.clone()

            # Apply steering only to specified token positions
            if self.token_idxs is None:
                # Apply to all positions (original behavior)
                steering_addition = (
                    self.steering_coefficient * selected_vectors.unsqueeze(1)
                )
                steered_states = hidden_states + steering_addition
            else:
                # Apply to specific token positions
                token_positions = (
                    self.token_idxs
                    if isinstance(self.token_idxs, list)
                    else [self.token_idxs]
                )

                for token_idx in token_positions:
                    # Handle negative indices
                    if token_idx < 0:
                        token_idx = seq_len + token_idx

                    # Apply only if the token position exists
                    if 0 <= token_idx < seq_len:
                        steering_addition = (
                            self.steering_coefficient * selected_vectors
                        )  # (batch, hidden_dim)
                        steered_states[:, token_idx, :] += steering_addition

            if is_tuple_output:
                return (steered_states,) + output[1:]
            else:
                return steered_states

        return hook

    def set_vector_idxs(self, vector_idxs: torch.Tensor):
        """Set the vector indices to apply steering to."""
        self.vector_idxs = vector_idxs

    def register_hook(self):
        """Register the persistent hook to apply steering."""
        self.clear_hook()
        hook = self._get_steering_hook()
        self.hook = self.source_module.register_forward_hook(hook)

    def clear_hook(self):
        """Remove the registered hook."""
        if self.hook:
            self.hook.remove()
            self.hook = None

    def set_coefficient(self, coefficient: float):
        """Set the steering coefficient."""
        self.steering_coefficient = coefficient
