import logging
from pathlib import Path
import random
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
import wandb

from ..steering import ActivationSteering
from .generation_utils import extract_continuations
from .config import SteerCLRTrainerConfig
from .loss import (
    diversity_loss,
    magnitude_loss,
    orthogonality_loss,
)

# Configure the root logger when this module is imported so that the
# INFO-level progress messages emitted below are actually shown.
logging.basicConfig(level=logging.INFO)


def get_module(model: PreTrainedModel, layer_idx: int) -> nn.Module:
    """Return the transformer block stored at position ``layer_idx``.

    Two layouts are recognised, checked in this order:

    * ``model.model.layers`` (Llama/Qwen style models)
    * ``model.transformer.h`` (GPT style models)

    Raises:
        ValueError: If ``model`` exposes neither layout.
    """
    known_layouts = (
        ("model", "layers"),  # Llama/Qwen style models
        ("transformer", "h"),  # GPT style models
    )
    for container_name, blocks_name in known_layouts:
        container = getattr(model, container_name, None)
        if container is not None and hasattr(container, blocks_name):
            return getattr(container, blocks_name)[layer_idx]
    raise ValueError(f"Unsupported model architecture: {type(model)}")


def get_submodule(block: nn.Module, submodule_path: str | None) -> nn.Module:
    """Walk a dotted attribute path starting at ``block``.

    ``None`` or the empty string means "the block itself".

    Raises:
        ValueError: If any path component is missing, or if the object the
            path resolves to is not an ``nn.Module``.
    """
    if submodule_path in (None, ""):
        return block

    node: nn.Module = block
    for part in submodule_path.split("."):
        try:
            node = getattr(node, part)
        except AttributeError:
            raise ValueError(
                f"Block {type(block)} does not have submodule '{submodule_path}' (missing '{part}')"
            ) from None

    if isinstance(node, nn.Module):
        return node
    raise ValueError(
        f"Resolved path '{submodule_path}' is not a module: {type(node)}"
    )


class SteerCLRTrainer:
    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        dataloader: DataLoader,
        config: SteerCLRTrainerConfig,
        experiment_dir: Path | None = None,
        val_dataloader: DataLoader | None = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.dataloader = dataloader
        self.config = config
        self.device = next(model.parameters()).device
        self.experiment_dir = (
            Path(experiment_dir) if experiment_dir is not None else None
        )
        self.val_dataloader = val_dataloader

        # Freeze model parameters; we only optimize steering vectors
        for param in self.model.parameters():
            param.requires_grad = False

        # Initialize steering vectors
        hidden_dim = model.config.hidden_size

        # Keep steering vectors in float32 for optimizer stability
        init_vectors = torch.randn(
            config.n_vectors, hidden_dim, device=self.device, dtype=torch.float32
        )
        self.steering_vectors = nn.Parameter(init_vectors)

        # Enforce initial radius constraint (numerically stable)
        with torch.no_grad():
            norms = torch.norm(self.steering_vectors, p=2, dim=-1, keepdim=True)
            self.steering_vectors.data *= self.config.radius / norms

        # Setup optimizer with learning rate scheduling
        self.optimizer = AdamW(
            [self.steering_vectors],
            lr=config.learning_rate,
            betas=(0.9, 0.98),
            weight_decay=config.optimizer_weight_decay,
            amsgrad=True,
        )

        # Learning rate warmup
        warmup_steps = int(config.n_training_steps * config.warmup_steps_ratio)

        def lr_lambda(step):
            if warmup_steps > 0 and step < warmup_steps:
                return step / warmup_steps
            return 1.0

        self.scheduler = LambdaLR(self.optimizer, lr_lambda)

        # Initialize persistent steering hook (archive style)
        self._setup_steering_hook()

        # Initialize wandb
        self._init_wandb()

    def _init_wandb(self):
        """Initialize Weights & Biases logging."""
        run_name = self.config.wandb_run_name
        if run_name is None:
            model_name = self.config.model_name.split("/")[-1]
            run_name = f"{model_name}_steerclr_{self.config.n_vectors}vec"

        self.wandb_run = wandb.init(
            project=self.config.wandb_project,
            entity=self.config.wandb_entity,
            name=run_name,
            tags=self.config.wandb_tags
            + [
                self.config.model_name.split("/")[-1],
                f"n_vectors_{self.config.n_vectors}",
                f"target_layer_{self.config.target_layer}",
                f"source_layer_{self.config.source_layer}",
            ],
            config=self.config.to_dict(),
            reinit=True,
        )

        # Log model configuration
        wandb.config.update(
            {
                "model_hidden_size": self.model.config.hidden_size,
                "model_dtype": str(next(self.model.parameters()).dtype),
                "device": str(self.device),
                "total_parameters": sum(p.numel() for p in self.model.parameters()),
                "trainable_parameters": self.steering_vectors.numel(),
            },
            allow_val_change=True,
        )

        logging.info(f"Initialized wandb run: {self.wandb_run.name}")

    def _setup_steering_hook(self):
        """Set up persistent steering hook (archive style)."""
        source_block = get_module(self.model, self.config.source_layer)
        source_module = get_submodule(source_block, self.config.source_submodule)

        # Create persistent steering hook
        self.steering_hook = ActivationSteering(
            source_module=source_module,
            steering_vector_bank=self.steering_vectors,
        )
        logging.info(
            f"Registered persistent steering hook on layer {self.config.source_layer}"
        )

    def compute_hidden_states(
        self,
        batch: dict[str, torch.Tensor],
        target_layer: int,
        target_tokens: slice = slice(None),
    ) -> torch.Tensor:
        """Compute hidden states for given batch (archive style)."""
        # Use output_hidden_states=True to get all layer activations
        outputs = self.model(
            batch["input_ids"],
            attention_mask=batch.get("attention_mask", None),
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states[
            target_layer
        ]  # (batch, seq_len, hidden_dim)
        return hidden_states[:, target_tokens, :]

    def set_steering_vectors(self, vector_idxs: torch.Tensor):
        """Set which steering vectors to apply (archive style)."""
        self.steering_hook.set_vector_idxs(vector_idxs)

    def clear_steering_vectors(self):
        """Clear steering vectors (archive style)."""
        self.steering_hook.clear_steering()

    def train(self):
        """Main training loop."""
        self.model.eval()
        self.train_iter = iter(self.dataloader)

        pbar = tqdm(range(self.config.n_training_steps), dynamic_ncols=True)
        for step in pbar:
            # Get next batch
            try:
                batch = next(self.train_iter)
            except StopIteration:
                self.train_iter = iter(self.dataloader)
                batch = next(self.train_iter)

            # Move tensor data to device, keep text fields on CPU
            batch = {
                k: v.to(self.device) if k not in ("text", "unformatted_text") else v
                for k, v in batch.items()
            }

            # Sample vector indices for this batch
            batch_vector_idxs = torch.tensor(
                random.sample(
                    range(self.config.n_vectors), self.config.num_vectors_per_batch
                ),
                device=self.device,
            )

            # Repeat batch for each vector
            N = batch["input_ids"].shape[0]
            K = self.config.num_vectors_per_batch

            # Handle tensors and text fields separately
            repeated_batch = {}
            text_fields = {"text", "unformatted_text"}

            for k, v in batch.items():
                if k in text_fields:
                    # For text fields, repeat each string K times
                    if isinstance(v, list):
                        repeated_batch[k] = [item for item in v for _ in range(K)]
                    else:
                        # Should not happen in normal batched data, but handle gracefully
                        repeated_batch[k] = [v] * K
                else:
                    # For tensor fields, use repeat_interleave
                    repeated_batch[k] = v.repeat_interleave(K, dim=0)

            batch = repeated_batch
            vector_idxs = batch_vector_idxs.repeat(N)

            # Get activation shifts
            assert torch.isfinite(self.steering_vectors).all(), (
                "Steering vectors contain non-finite values"
            )
            activation_shifts = self._get_activation_shifts(batch, vector_idxs)

            assert torch.isfinite(activation_shifts).all(), (
                "Activation shifts are not finite"
            )

            # Calculate losses
            if self.config.alpha > 0:
                loss_mag = magnitude_loss(
                    activation_shifts, p=self.config.alpha_p, q=self.config.alpha_q
                )
            else:
                # Create zero tensor that maintains gradient tracking
                loss_mag = activation_shifts.sum() * 0.0

            if self.config.beta > 0:
                loss_div = diversity_loss(
                    activation_shifts,
                    vector_idxs,
                    temperature=self.config.tau,
                    loss_type=self.config.diversity_loss_type,
                )
            else:
                # Create zero tensor that maintains gradient tracking
                loss_div = activation_shifts.sum() * 0.0

            if self.config.lambda_ > 0:
                loss_ortho = orthogonality_loss(
                    self.steering_vectors,
                    eps=self.config.epsilon,
                    style=self.config.orthogonality_style,
                )
            else:
                # Create zero tensor that maintains gradient tracking
                loss_ortho = self.steering_vectors.sum() * 0.0

            if (
                self.config.alpha <= 0
                and self.config.beta <= 0
                and self.config.lambda_ <= 0
            ):
                raise ValueError(
                    "At least one of alpha, beta, or lambda_ must be greater than 0"
                )

            if not torch.isfinite(loss_div):
                print(f"WARNING: Non-finite diversity loss: {loss_div}")
                print(f"Vector idxs: {vector_idxs}")
                print(f"Activation shifts shape: {activation_shifts.shape}")

            if not torch.isfinite(loss_ortho):
                print(f"WARNING: Non-finite orthogonality loss: {loss_ortho}")
                print(
                    f"Steering vectors stats: mean={self.steering_vectors.mean()}, std={self.steering_vectors.std()}, min={self.steering_vectors.min()}, max={self.steering_vectors.max()}"
                )

            total_loss = (
                self.config.alpha * loss_mag
                + self.config.beta * loss_div
                + self.config.lambda_ * loss_ortho
            )

            # Backward pass
            total_loss.backward()

            # Gradient clipping for numerical stability
            torch.nn.utils.clip_grad_norm_(self.steering_vectors, max_norm=1.0)

            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad(set_to_none=True)

            assert torch.isfinite(total_loss), "Total loss is not finite"

            with torch.no_grad():
                norms = torch.norm(self.steering_vectors, p=2, dim=-1, keepdim=True)
                self.steering_vectors.data *= self.config.radius / norms.clamp_min(
                    1e-12
                )

            # Log metrics
            current_lr = self.scheduler.get_last_lr()[0]
            if self.wandb_run:
                log_dict = {
                    "train/total_loss": total_loss.item(),
                    "train/magnitude_loss": loss_mag.item(),
                    "train/diversity_loss": loss_div.item(),
                    "train/orthogonality_loss": loss_ortho.item(),
                    "train/learning_rate": current_lr,
                    "train/step": step,
                }

                # Log vector statistics
                with torch.no_grad():
                    vector_norms = torch.norm(self.steering_vectors, p=2, dim=1)
                    log_dict.update(
                        {
                            "vectors/mean_norm": vector_norms.mean().item(),
                            "vectors/std_norm": vector_norms.std().item(),
                            "vectors/min_norm": vector_norms.min().item(),
                            "vectors/max_norm": vector_norms.max().item(),
                        }
                    )

                wandb.log(log_dict, step=step)

            # Update progress bar
            pbar.set_description(
                f"Loss: {total_loss.item():.4f} "
                f"(Mag: {loss_mag.item():.4f}, "
                f"Div: {loss_div.item():.4f}, "
                f"Ortho: {loss_ortho.item():.4f}) "
                f"LR: {current_lr:.2e}"
            )

            # Periodic validation
            if (step + 1) % self.config.val_frequency == 0:
                try:
                    self.run_validation(step=step + 1)
                except Exception as e:
                    logging.warning(f"Validation at step {step + 1} failed: {e}")

    def _get_activation_shifts(
        self, batch: dict[str, torch.Tensor], vector_idxs: torch.Tensor
    ) -> torch.Tensor:
        """Get activation shifts for the current batch (archive style)."""
        # Determine target tokens based on config
        if isinstance(self.config.token_idxs, int) and self.config.token_idxs < 0:
            target_tokens = slice(self.config.token_idxs, None)  # Last k tokens
        else:
            target_tokens = slice(None)  # All tokens

        # Get unsteered activations (baseline)
        self.clear_steering_vectors()
        unsteered_activations = self.compute_hidden_states(
            batch, self.config.target_layer, target_tokens
        ).float()

        # Get steered activations
        self.set_steering_vectors(vector_idxs)
        steered_activations = self.compute_hidden_states(
            batch, self.config.target_layer, target_tokens
        ).float()

        # Calculate activation shifts
        activation_shifts = steered_activations - unsteered_activations

        # Mean-pool over tokens if we have multiple tokens
        attention_mask = batch.get("attention_mask", None)
        if activation_shifts.shape[1] > 1:
            if attention_mask is not None:
                # Mask-aware mean pooling
                if target_tokens != slice(None):
                    mask = attention_mask[:, target_tokens]
                else:
                    mask = attention_mask
                num = (activation_shifts * mask.unsqueeze(-1)).sum(dim=-2)
                denom = mask.sum(dim=-1, keepdim=True).clamp_min(1).to(num.dtype)
                activation_shifts = num / denom
            else:
                activation_shifts = activation_shifts.mean(dim=-2)
        else:
            activation_shifts = activation_shifts.squeeze(-2)

        return activation_shifts

    def _repeat_batch(
        self, batch: dict[str, torch.Tensor], repeats: int
    ) -> dict[str, torch.Tensor]:
        """Repeat a batch along the batch dimension."""
        assert repeats >= 1
        result = {}
        for k, v in batch.items():
            if k in ("text", "unformatted_text"):
                result[k] = [item for item in v for _ in range(repeats)]
            else:
                result[k] = v.repeat_interleave(repeats, dim=0)
        return result

    def _filter_generation_kwargs(self, batch: dict) -> dict:
        """Filter out non-tensor fields that are not valid for model.generate()."""
        return {k: v for k, v in batch.items() if k not in ("text", "unformatted_text")}

    def _mean_pool_activations(
        self, activations: torch.Tensor, attention_mask: torch.Tensor | None
    ) -> torch.Tensor:
        """Mean-pool activations over sequence length using an attention mask."""
        if activations.shape[1] == 1:
            return activations.squeeze(1)
        if attention_mask is not None:
            # Mask-aware mean pooling
            num = (activations * attention_mask.unsqueeze(-1)).sum(dim=1)
            denom = attention_mask.sum(dim=1, keepdim=True).clamp_min(1.0)
            return num / denom
        else:
            return activations.mean(dim=1)

    def save_vectors(self, output_dir: str | Path):
        """Save the learned steering vectors."""
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)
        torch.save(
            self.steering_vectors.detach().cpu(), output_dir / "steering_vectors.pt"
        )
        logging.info(f"Saved steering vectors to {output_dir / 'steering_vectors.pt'}")

    def finalize_wandb(self):
        """Finalize wandb run."""
        if self.wandb_run:
            wandb.finish()
            logging.info("Finalized wandb run")

    def run_validation(self, step: int | None = None) -> None:
        """Run validation including loss calculation and text generation."""
        if self.val_dataloader is None:
            logging.info("No validation dataloader provided, skipping validation.")
            return

        logging.info(f"Running validation at step {step}")

        # 1. Calculate validation loss
        try:
            val_loss, loss_components = self._calculate_validation_loss()
            logging.info(f"Validation loss: {val_loss:.6f}")
            if self.wandb_run and step is not None:
                wandb.log(loss_components, step=step)
        except Exception as e:
            logging.warning(f"Validation loss calculation failed: {e}", exc_info=True)

        # 2. Run generation and statistics validation
        if self.config.save_generations and self.experiment_dir is not None:
            if (step is None) or self.config.generate_during_training:
                try:
                    self._run_generation_and_stats_validation(step)
                except Exception as e:
                    logging.warning(
                        f"Generation and statistics validation failed: {e}",
                        exc_info=True,
                    )
        logging.info(f"Validation completed at step {step}")

    def _calculate_validation_loss(self) -> tuple[float, dict]:
        """Calculate validation loss on the validation dataset."""
        self.model.eval()
        total_loss_mag = 0.0
        total_loss_div = 0.0
        num_batches = 0

        # Orthogonality loss is independent of the batch, calculate it once
        with torch.no_grad():
            if self.config.lambda_ > 0:
                loss_ortho = orthogonality_loss(
                    self.steering_vectors,
                    eps=self.config.epsilon,
                    style=self.config.orthogonality_style,
                )
            else:
                loss_ortho = torch.tensor(0.0, device=self.device)

        with torch.no_grad():
            for batch_idx, batch in tqdm(
                enumerate(self.val_dataloader), dynamic_ncols=True
            ):
                batch = {
                    k: v.to(self.device) if k not in ("text", "unformatted_text") else v
                    for k, v in batch.items()
                }

                N = batch["input_ids"].shape[0]
                K = self.config.num_vectors_per_batch

                base_start = (batch_idx * K) % self.config.n_vectors
                batch_vector_idxs = (
                    torch.arange(base_start, base_start + K, device=self.device)
                    % self.config.n_vectors
                )
                vector_idxs = batch_vector_idxs.repeat(N)

                rep_batch = {
                    k: v.repeat_interleave(K, dim=0)
                    for k, v in batch.items()
                    if k not in ("text", "unformatted_text")
                }

                try:
                    activation_shifts = self._get_activation_shifts(
                        rep_batch, vector_idxs
                    )

                    loss_mag = (
                        magnitude_loss(
                            activation_shifts,
                            p=self.config.alpha_p,
                            q=self.config.alpha_q,
                        )
                        if self.config.alpha > 0
                        else torch.tensor(0.0, device=self.device)
                    )
                    loss_div = (
                        diversity_loss(
                            activation_shifts,
                            vector_idxs,
                            temperature=self.config.tau,
                            loss_type=self.config.diversity_loss_type,
                        )
                        if self.config.beta > 0
                        else torch.tensor(0.0, device=self.device)
                    )
                    total_loss_mag += loss_mag.item()
                    total_loss_div += loss_div.item()
                    num_batches += 1
                except Exception as e:
                    logging.warning(f"Error calculating val loss for batch: {e}")
                    continue

        avg_loss_mag = total_loss_mag / max(1, num_batches)
        avg_loss_div = total_loss_div / max(1, num_batches)
        loss_ortho_item = loss_ortho.item()

        total_val_loss = (
            self.config.alpha * avg_loss_mag
            + self.config.beta * avg_loss_div
            + self.config.lambda_ * loss_ortho_item
        )
        loss_components = {
            "validation/total_loss": total_val_loss,
            "validation/magnitude_loss": avg_loss_mag,
            "validation/diversity_loss": avg_loss_div,
            "validation/orthogonality_loss": loss_ortho_item,
        }
        return total_val_loss, loss_components

    def _run_generation_and_stats_validation(self, step: int | None = None) -> None:
        """
        Run text generation and statistics validation.
        This method combines generation and statistics calculation in a single pass
        over the validation dataset for efficiency.
        """
        assert self.val_dataloader is not None
        assert self.experiment_dir is not None
        self.model.eval()

        # Setup paths
        step_str = f"{int(step):06d}" if step is not None else "final"
        gen_path = self.experiment_dir / f"validation_generations_{step_str}.jsonl"
        stats_path = self.experiment_dir / f"validation_stats_{step_str}.json"

        # Precompute vector stats
        vector_norms = torch.norm(self.steering_vectors, p=2, dim=-1).tolist()

        # Stats accumulators
        stats_accum = {
            v_idx: {
                "last_layer_l2": 0.0,
                "last_layer_cos": 0.0,
                "logits_l2": 0.0,
                "logits_kl": 0.0,
                "samples": 0,
            }
            for v_idx in range(self.config.n_vectors)
        }

        all_generations = []
        # Generation kwargs
        gen_kwargs = dict(
            max_new_tokens=self.config.val_max_new_tokens,
            do_sample=self.config.val_do_sample,
            temperature=self.config.val_temperature,
            top_p=self.config.val_top_p,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True,
        )

        logging.info(f"Starting generation and stats validation for step {step}")
        for batch_idx, batch in tqdm(
            enumerate(self.val_dataloader),
            total=len(self.val_dataloader),
            desc="Validation",
        ):
            batch = {
                k: v.to(self.device) if k not in ("text", "unformatted_text") else v
                for k, v in batch.items()
            }
            B = batch["input_ids"].shape[0]

            # 1. Baseline (unsteered) pass
            self.clear_steering_vectors()
            with torch.inference_mode():
                # Forward pass for stats
                base_out = self.model(
                    **self._filter_generation_kwargs(batch),
                    output_hidden_states=True,
                )
                unsteered_acts = base_out.hidden_states[
                    self.config.target_layer
                ].float()
                unsteered_logits = base_out.logits.float()

                # Generation
                base_gens = self.model.generate(
                    **self._filter_generation_kwargs(batch), **gen_kwargs
                )
                base_texts = self.tokenizer.batch_decode(
                    extract_continuations(base_gens, batch["input_ids"]),
                    skip_special_tokens=True,
                )

            # Store baseline generations
            for i, text in enumerate(base_texts):
                all_generations.append(
                    {
                        "prompt_idx": batch_idx * B + i,
                        "vector_idx": "baseline",
                        "generated_text": text,
                        "input_text": batch["text"][i],
                    }
                )

            # 2. Steered passes in chunks
            vectors_per_call = max(1, int(self.config.val_vectors_per_call))
            for v_offset in range(0, self.config.n_vectors, vectors_per_call):
                v_end = min(v_offset + vectors_per_call, self.config.n_vectors)
                chunk_v_indices = list(range(v_offset, v_end))
                K = len(chunk_v_indices)

                rep_batch = self._repeat_batch(batch, K)
                v_indices_tensor = (
                    torch.tensor(chunk_v_indices, device=self.device)
                    .unsqueeze(0)
                    .repeat(B, 1)
                    .view(-1)
                )

                self.set_steering_vectors(v_indices_tensor)
                with torch.inference_mode():
                    # Forward pass for stats
                    steered_out = self.model(
                        **self._filter_generation_kwargs(rep_batch),
                        output_hidden_states=True,
                    )
                    steered_acts = steered_out.hidden_states[
                        self.config.target_layer
                    ].float()
                    steered_logits = steered_out.logits.float()

                    # Generation
                    steered_gens = self.model.generate(
                        **self._filter_generation_kwargs(rep_batch), **gen_kwargs
                    )
                    steered_texts = self.tokenizer.batch_decode(
                        extract_continuations(steered_gens, rep_batch["input_ids"]),
                        skip_special_tokens=True,
                    )

                # Compute and accumulate stats for the chunk
                self._accumulate_stats(
                    stats_accum,
                    unsteered_acts,
                    unsteered_logits,
                    steered_acts,
                    steered_logits,
                    batch.get("attention_mask"),
                    chunk_v_indices,
                )

                # Store steered generations
                for i in range(B):
                    for j, v_idx in enumerate(chunk_v_indices):
                        text = steered_texts[i * K + j]
                        all_generations.append(
                            {
                                "prompt_idx": batch_idx * B + i,
                                "vector_idx": v_idx,
                                "generated_text": text,
                                "input_text": batch["text"][i],
                            }
                        )

        # Sort and save generations
        all_generations.sort(
            key=lambda d: (
                d["prompt_idx"],
                -1 if d["vector_idx"] == "baseline" else d["vector_idx"],
            )
        )
        with open(gen_path, "w") as f:
            for item in all_generations:
                f.write(json.dumps(item) + "\n")
        logging.info(f"Saved validation generations to {gen_path}")

        # Finalize and save statistics
        final_stats = self._finalize_stats(stats_accum, vector_norms)
        with open(stats_path, "w") as f:
            json.dump(final_stats, f, indent=2)
        logging.info(f"Saved validation statistics to {stats_path}")

    def _accumulate_stats(
        self,
        stats_accum: dict,
        unsteered_acts: torch.Tensor,
        unsteered_logits: torch.Tensor,
        steered_acts: torch.Tensor,
        steered_logits: torch.Tensor,
        attention_mask: torch.Tensor | None,
        chunk_v_indices: list[int],
    ):
        """Helper to compute and accumulate statistics for a chunk of vectors."""
        B, S, H = unsteered_acts.shape
        V = unsteered_logits.shape[-1]
        K = len(chunk_v_indices)

        # Pool activations
        pooled_unsteered = self._mean_pool_activations(unsteered_acts, attention_mask)

        rep_mask = (
            attention_mask.repeat_interleave(K, 0)
            if attention_mask is not None
            else None
        )
        pooled_steered = self._mean_pool_activations(steered_acts, rep_mask)
        pooled_steered = pooled_steered.view(B, K, H)

        # Last-token logits
        last_token_unsteered = unsteered_logits[:, -1, :]
        last_token_steered = steered_logits.view(B * K, S, V)[:, -1, :].view(B, K, V)

        # Expand unsteered for comparison
        exp_unsteered_acts = pooled_unsteered.unsqueeze(1).expand(-1, K, -1)
        exp_unsteered_logits = last_token_unsteered.unsqueeze(1).expand(-1, K, -1)

        # Activation stats
        diff_acts = pooled_steered - exp_unsteered_acts
        l2_acts = torch.norm(diff_acts, p=2, dim=-1).mean(0)
        cos_acts = F.cosine_similarity(pooled_steered, exp_unsteered_acts, dim=-1).mean(
            0
        )

        # Logit stats
        diff_logits = last_token_steered - exp_unsteered_logits
        l2_logits = torch.norm(diff_logits, p=2, dim=-1).mean(0)

        p = F.softmax(exp_unsteered_logits, dim=-1).clamp(min=1e-9)
        q = F.softmax(last_token_steered, dim=-1).clamp(min=1e-9)
        kl_div = (q * (q.log() - p.log())).sum(dim=-1).mean(0)

        for j, v_idx in enumerate(chunk_v_indices):
            stats_accum[v_idx]["last_layer_l2"] += l2_acts[j].item()
            stats_accum[v_idx]["last_layer_cos"] += cos_acts[j].item()
            stats_accum[v_idx]["logits_l2"] += l2_logits[j].item()
            stats_accum[v_idx]["logits_kl"] += kl_div[j].item()
            stats_accum[v_idx]["samples"] += B

    def _finalize_stats(self, stats_accum: dict, vector_norms: list[float]) -> dict:
        """Finalize and average statistics."""
        final_stats = {"vectors": []}
        for v_idx, norms in enumerate(vector_norms):
            stats = stats_accum[v_idx]
            num_samples = max(1, stats["samples"])
            final_stats["vectors"].append(
                {
                    "vector_idx": v_idx,
                    "vector_norm": norms,
                    "last_layer_mean_l2_dist": stats["last_layer_l2"] / num_samples,
                    "last_layer_mean_cos_sim": stats["last_layer_cos"] / num_samples,
                    "logits_mean_l2_dist": stats["logits_l2"] / num_samples,
                    "logits_mean_kl_div": stats["logits_kl"] / num_samples,
                }
            )
        return final_stats
