from __future__ import annotations

from typing import Optional, List, Dict, Sequence

import torch
from torch import nn, Tensor


def guess_hook_module_path(model: nn.Module, layer: int) -> str:
    """Probe a fixed list of layer-container paths and return the first that resolves.

    Each candidate is ``"<prefix>.<layer>"`` and is checked with
    ``get_module_by_path``; any exception raised while resolving a candidate
    is treated as "this path does not exist" and the next one is tried.

    Args:
        model: Root module to probe.
        layer: Index appended to each candidate prefix.

    Returns:
        The first candidate path that resolves to an ``nn.Module``.

    Raises:
        RuntimeError: If none of the candidates resolves.
    """
    prefixes = (
        "model.layers",
        "model.model.layers",
        "transformer.h",
        "gpt_neox.layers",
        "model.decoder.layers",
    )
    candidates = [f"{prefix}.{layer}" for prefix in prefixes]

    for candidate in candidates:
        try:
            get_module_by_path(model, candidate)
        except Exception:
            # Best-effort probe: a failure only means this layout is not present.
            continue
        else:
            return candidate

    raise RuntimeError(
        f"Cannot auto-resolve hook module path for layer={layer}. "
        f"Tried: {candidates}. Please pass --hook_module_path explicitly."
    )


def get_module_by_path(root: nn.Module, path: str) -> nn.Module:
    """Walk a dotted path from ``root`` and return the module it names.

    Path segments made only of digits are applied as integer subscripts
    (``node[int(segment)]``); every other segment is an attribute lookup.

    Args:
        root: Module the walk starts from.
        path: Dot-separated path, e.g. ``"model.layers.3"``.

    Returns:
        The object reached at the end of the path.

    Raises:
        ValueError: If the object reached is not an ``nn.Module``.
    """
    node: object = root
    for segment in path.split("."):
        node = (
            node[int(segment)]  # type: ignore[index]
            if segment.isdigit()
            else getattr(node, segment)
        )

    if isinstance(node, nn.Module):
        return node
    raise ValueError(f"Path '{path}' did not resolve to an nn.Module.")


def _infer_hidden_dim(model: nn.Module, sae: nn.Module) -> int:
    """Infer the hidden dimension D for SAE.encode(x).

    Prefer SAE parameter shapes (e.g., W_dec.shape[1]) because they are authoritative.
    Fall back to model.config.hidden_size if available.
    """
    W_dec = getattr(sae, "W_dec", None)
    if isinstance(W_dec, torch.Tensor) and W_dec.ndim == 2:
        return int(W_dec.shape[1])

    hidden_size = getattr(getattr(model, "config", None), "hidden_size", None)
    if isinstance(hidden_size, int) and hidden_size > 0:
        return hidden_size

    raise RuntimeError(
        "Unable to infer SAE hidden dimension (D). "
        "Provide an SAE with W_dec or a model with config.hidden_size."
    )


def capture_module_activations(
    model: nn.Module,
    input_ids: Tensor,
    attention_mask: Optional[Tensor],
    hook_module_path: str,
    batch_size: int = 32,
) -> Tensor:
    """Run ``model`` over ``input_ids`` in batches and return a submodule's outputs.

    A forward hook is registered on the module found at ``hook_module_path``.
    Each time that module runs, its output (or the first element when the
    output is a tuple or list) is detached and stored. The stored tensors are
    concatenated along dim 0 and returned.

    The model is put into eval mode and is left in eval mode afterwards.

    Args:
        model: Model that is called as ``model(input_ids=..., attention_mask=...)``.
        input_ids: Batched inputs; sliced along dim 0 in chunks of ``batch_size``.
        attention_mask: Optional mask sliced the same way as ``input_ids`` and
            forwarded to the model; ``None`` is forwarded as-is.
        hook_module_path: Dotted path of the submodule to hook.
        batch_size: Number of rows of ``input_ids`` per forward pass; must be > 0.

    Returns:
        The captured outputs concatenated along dim 0.

    Raises:
        ValueError: If ``batch_size`` is not positive.
        RuntimeError: If the hook never fired (for example ``input_ids`` has
            no rows), so there is nothing to concatenate.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    submodule = get_module_by_path(model, hook_module_path)

    captured: List[Tensor] = []

    def hook_fn(_module, _inp, out):
        # Some modules return (hidden_states, ...); keep only the first element.
        if isinstance(out, (tuple, list)):
            out = out[0]
        captured.append(out.detach())

    handle = submodule.register_forward_hook(hook_fn)
    try:
        model.eval()
        with torch.no_grad():
            for start in range(0, input_ids.size(0), batch_size):
                end = start + batch_size
                batch_ids = input_ids[start:end]
                batch_mask = attention_mask[start:end] if attention_mask is not None else None

                model(
                    input_ids=batch_ids,
                    attention_mask=batch_mask,
                )
    finally:
        # Always detach the hook, even if a forward pass raised; otherwise it
        # would stay on the model and keep appending to `captured`.
        handle.remove()

    if not captured:
        raise RuntimeError(
            f"No activations were captured from '{hook_module_path}'; "
            "input_ids has no rows or the hooked module was never executed."
        )
    return torch.cat(captured, dim=0)


def collect_sae_activations_hf(
    model: nn.Module,
    input_ids: Tensor,
    attention_mask: Optional[Tensor],
    hook_module_path: str,
    sae: nn.Module,
    batch_size: int = 32,
) -> Tensor:
    """Encode the hooked module's outputs with the SAE in a single call.

    The outputs of the module at ``hook_module_path`` are gathered with
    ``capture_module_activations`` and then passed, as one tensor, to
    ``sae.encode`` under ``torch.no_grad()``.

    Args:
        model: Model to run.
        input_ids: Inputs forwarded to ``capture_module_activations``.
        attention_mask: Optional mask forwarded to ``capture_module_activations``.
        hook_module_path: Dotted path of the submodule whose output is encoded.
        sae: Object exposing an ``encode`` attribute.
        batch_size: Batch size forwarded to ``capture_module_activations``.

    Returns:
        Whatever ``sae.encode`` returns for the captured activations.

    Raises:
        RuntimeError: If ``sae`` has no ``encode`` attribute.
    """
    if not hasattr(sae, "encode"):
        raise RuntimeError("SAE module does not implement .encode(x).")

    capture_kwargs = dict(
        model=model,
        input_ids=input_ids,
        attention_mask=attention_mask,
        hook_module_path=hook_module_path,
        batch_size=batch_size,
    )
    hidden = capture_module_activations(**capture_kwargs)

    with torch.no_grad():
        return sae.encode(hidden)


def get_feature_activation_sparsity_hf(
    model: nn.Module,
    input_ids: Tensor,
    attention_mask: Optional[Tensor],
    hook_module_path: str,
    sae: nn.Module,
    batch_size: int = 32,
    tokenizer=None,
) -> Tensor:
    """Estimate, per SAE feature, the fraction of positions where it is active.

    The model is run over ``input_ids`` in batches with a forward hook on the
    module at ``hook_module_path``. Inside the hook the module output (first
    element when it is a tuple or list) is passed to ``sae.encode``; a feature
    counts as active at a position when its encoded value is ``> 0``.

    The number of features is ``sae.W_dec.shape[0]`` when ``W_dec`` is a 2-D
    tensor; otherwise it is read from the last dimension of ``sae.encode``
    applied to a zero tensor of shape ``(1, 1, D)``.

    Notes:
        * ``attention_mask`` is only forwarded to the model; it is not used to
          exclude positions, so every position in every batch is counted.
        * ``tokenizer`` is accepted but not used.
        * The model is put into eval mode and left there.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=...)``.
        input_ids: Batched inputs; sliced along dim 0 in chunks of ``batch_size``.
        attention_mask: Optional mask sliced like ``input_ids``.
        hook_module_path: Dotted path of the submodule to hook.
        sae: Object exposing ``encode`` (and optionally a 2-D ``W_dec``).
        batch_size: Rows of ``input_ids`` per forward pass; must be > 0.
        tokenizer: Unused.

    Returns:
        A float32 tensor of shape ``(n_features,)`` on the device of the
        model's first parameter. All zeros if no position was observed.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    first_param = next(model.parameters())
    device = first_param.device

    W_dec = getattr(sae, "W_dec", None)
    if isinstance(W_dec, torch.Tensor) and W_dec.ndim == 2:
        n_features = int(W_dec.shape[0])
    else:
        # No usable decoder matrix: probe the encoder with a dummy input.
        d_model = _infer_hidden_dim(model, sae)
        dummy = torch.zeros(1, 1, d_model, device=device, dtype=first_param.dtype)
        with torch.no_grad():
            n_features = int(sae.encode(dummy).shape[-1])

    counts = torch.zeros(n_features, device=device, dtype=torch.float32)
    total = 0

    submodule = get_module_by_path(model, hook_module_path)

    def hook_fn(_module, _inp, out):
        nonlocal total
        if isinstance(out, (tuple, list)):
            out = out[0]
        with torch.no_grad():
            encoded = sae.encode(out.detach())
            # Collapse all leading dims so each row is one position.
            flat = encoded.reshape(-1, encoded.shape[-1])
            total += flat.shape[0]
            active = (flat > 0).sum(dim=0)
            # The hooked layer may live on a different device than `counts`.
            counts.add_(active.to(device=counts.device, dtype=counts.dtype))

    handle = submodule.register_forward_hook(hook_fn)
    try:
        model.eval()
        with torch.no_grad():
            for start in range(0, input_ids.size(0), batch_size):
                end = start + batch_size
                batch_ids = input_ids[start:end]
                batch_mask = attention_mask[start:end] if attention_mask is not None else None
                model(input_ids=batch_ids, attention_mask=batch_mask)
    finally:
        # Always detach the hook, even when a forward pass raised.
        handle.remove()

    if total == 0:
        return counts
    return counts / float(total)
