# autointerp_hf/utils.py
import os
import json
from typing import Dict, Any, Optional, Tuple, List

import torch
import torch.nn as nn


def _pick_config_path(folder: str) -> str:
    """
    Resolve which JSON config file describes the SAE stored in `folder`.

    Lookup order:
      1. cfg.json    (dictionary_learning style)
      2. config.json (older autointerp style)

    When neither file exists, the cfg.json path is still returned so that
    callers report a consistent path in their error messages.
    """
    preferred, legacy = (
        os.path.join(folder, filename) for filename in ("cfg.json", "config.json")
    )
    # The legacy file only wins when the preferred one is absent.
    if not os.path.exists(preferred) and os.path.exists(legacy):
        return legacy
    return preferred


############################################################
# Helper functions to interpret & rebuild locally trained SAEs
#
# Context:
# - During training, each SAE is saved under a folder like:
#       /.../resid_post_layer_1/trainer_0/
#           ae.pt
#           config.json
#
# or in dictionary_learning style:
#       /.../layer1/<run_id>/final_*/ 
#           ae.pt
#           cfg.json
#
# - The "ae.pt" might contain various keys depending on training code:
#   e.g. "decoder.weight", "W_dec", "encoder.weight", ...
#
# Our goal is to re-instantiate a local SAE module which exposes .encode(x),
# returning feature activations for the interpretability pipeline.
############################################################


def _load_raw_state_dict(path: str) -> Dict[str, torch.Tensor]:
    """
    Read an 'ae.pt' checkpoint and return the state dict stored in it.

    Supported layouts, checked in this order:
      1. a wrapper dict whose 'state_dict' entry is itself a dict,
      2. a wrapper dict whose 'model_state_dict' entry is itself a dict,
      3. a flat dict with string keys and at least one tensor value; only
         the tensor-valued entries are returned.

    Wrapped state dicts (cases 1 and 2) are returned as stored, unfiltered.

    Raises:
        ValueError: if the loaded object matches none of the layouts above.
    """
    # NOTE(review): torch.load deserialises via pickle; depending on the torch
    # version's `weights_only` default this can execute arbitrary code, so
    # only point it at trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu")

    if isinstance(checkpoint, dict):
        # Wrapper layouts take precedence over the flat layout.
        for wrapper_key in ("state_dict", "model_state_dict"):
            wrapped = checkpoint.get(wrapper_key)
            if isinstance(wrapped, dict):
                return wrapped

        # Flat layout: keep tensors, drop bookkeeping entries (steps, etc.).
        tensors_only = {
            name: value
            for name, value in checkpoint.items()
            if isinstance(value, torch.Tensor)
        }
        if tensors_only and all(isinstance(name, str) for name in checkpoint):
            return tensors_only

    raise ValueError(f"Could not parse a valid state_dict from ae.pt: {path}")


def _maybe_transpose_to_FxD(w: torch.Tensor) -> torch.Tensor:
    """
    Return a rank-2 SAE weight matrix oriented as [F, D].

    Checkpoints may store a weight as [F, D] or as [D, F] (for example an
    nn.Linear decoder keeps its weight as [out_features, in_features]).
    The orientation is inferred from the shape alone: the larger dimension
    is taken to be the number of features F, because SAE dictionaries are
    overcomplete (F >= D). Treating the smaller dimension as F would swap
    n_features and d_model for every overcomplete SAE.

    NOTE(review): this is still a heuristic; an undercomplete SAE (F < D)
    would be oriented the wrong way round.

    Args:
        w: rank-2 weight tensor.

    Returns:
        `w` itself when it already has at least as many rows as columns
        (square matrices are left untouched), otherwise a contiguous
        transposed copy.

    Raises:
        ValueError: if `w` is not rank-2.
    """
    if w.ndim != 2:
        raise ValueError(f"Decoder weight must be rank-2, got shape={w.shape}")
    n_rows, n_cols = w.shape
    if n_rows >= n_cols:
        # Rows already index the (larger) feature dimension.
        return w
    return w.t().contiguous()


def _pick_decoder_weight_key(state: Dict[str, torch.Tensor]) -> str:
    """
    Choose the state-dict key most likely to hold the decoder weight matrix.

    Every rank-2 tensor is a candidate. Candidates are ranked by a name-based
    score (case-insensitive):
      +100 if the key contains 'decoder'
      +50  if it contains 'w_dec', 'decoder.weight', or both 'dec' and 'weight'
      +10  if it contains 'weight'
    Ties are broken by the larger element count, then by state-dict order.

    Raises:
        ValueError: if `state` contains no rank-2 tensor.
    """

    def _rank(entry):
        key, tensor = entry
        lowered = key.lower()
        has_weight = "weight" in lowered
        points = 0
        if "decoder" in lowered:
            points += 100
        if (
            "w_dec" in lowered
            or "decoder.weight" in lowered
            or ("dec" in lowered and has_weight)
        ):
            points += 50
        if has_weight:
            points += 10
        return points, tensor.numel()

    matrices = [
        (key, value)
        for key, value in state.items()
        if isinstance(value, torch.Tensor) and value.ndim == 2
    ]
    if not matrices:
        raise ValueError(
            f"No 2D tensor candidates for decoder weight found in state_dict keys: {list(state.keys())[:10]}..."
        )

    # max() keeps the first entry among equals, i.e. state-dict order.
    best_key, _ = max(matrices, key=_rank)
    return best_key


def _pick_encoder_weight_key(state: Dict[str, torch.Tensor]) -> Optional[str]:
    """
    Choose the state-dict key most likely to hold the encoder weight matrix.

    Only rank-2 tensors whose key carries an encoder hint are considered
    (case-insensitive): the key contains 'encoder', contains 'w_enc', or
    contains both 'enc' and 'weight'. Keys without any such hint are skipped,
    so a checkpoint that has no encoder matrix yields None instead of an
    unrelated matrix (e.g. the decoder weight), letting the caller fall back
    to tied weights.

    Ranking among the remaining candidates:
      +100 if the key contains 'encoder'
      +50  if it contains 'w_enc', or both 'enc' and 'weight'
      +10  if it contains 'weight'
    Ties are broken by the larger element count, then by state-dict order.

    Returns:
        The selected key, or None when no encoder-like key exists.
    """
    best_key: Optional[str] = None
    best_rank = None
    for key, tensor in state.items():
        if not isinstance(tensor, torch.Tensor) or tensor.ndim != 2:
            continue
        lowered = key.lower()
        has_weight = "weight" in lowered
        named_encoder = "encoder" in lowered
        # 'encoder.weight' is covered by the ('enc' and 'weight') clause.
        encoder_hint = "w_enc" in lowered or ("enc" in lowered and has_weight)
        if not (named_encoder or encoder_hint):
            # Not recognisably an encoder matrix; do not guess.
            continue
        score = 100 * named_encoder + 50 * encoder_hint + 10 * has_weight
        rank = (score, tensor.numel())
        # Strict '>' keeps the earliest key among equally ranked candidates.
        if best_rank is None or rank > best_rank:
            best_key, best_rank = key, rank
    return best_key


def _extract_bias_vector(
    state: Dict[str, torch.Tensor],
    candidate_names: List[str],
    expected_dim: Optional[int] = None,
) -> Optional[torch.Tensor]:
    """
    Look up a 1-D bias-like tensor in `state` and return a copy of it.

    Search strategy:
      1. Walk `candidate_names` in order and take the first key holding a
         1-D tensor (of length `expected_dim`, when that is given).
      2. Otherwise, and only when `expected_dim` is given, take the first
         1-D tensor of that length whose key contains 'bias'
         (case-insensitive), in state-dict order.

    Returns:
        A clone of the matching tensor, or None when nothing matches.
    """

    def _is_match(value: object) -> bool:
        if not isinstance(value, torch.Tensor) or value.ndim != 1:
            return False
        return expected_dim is None or value.shape[0] == expected_dim

    # Pass 1: explicitly named candidates.
    for name in candidate_names:
        named = state.get(name)
        if _is_match(named):
            return named.clone()

    # Pass 2: name-agnostic fallback, meaningful only with a known length.
    if expected_dim is None:
        return None
    for key, value in state.items():
        if _is_match(value) and "bias" in key.lower():
            return value.clone()
    return None


def _extract_thresholds(
    trainer_cfg: Dict[str, Any],
    n_features: int,
) -> Tuple[Optional[float], Optional[torch.Tensor]]:
    """
    Read activation thresholds from the trainer config.

    Recognised entries:
      - 'threshold': a number, returned as a float scalar
      - 'threshold_vector': a list of exactly `n_features` values, returned
        as a float32 tensor
    Entries that are missing or of any other type/length yield None.

    Returns:
        (threshold_scalar, threshold_vector_F)
    """
    raw_scalar = trainer_cfg.get("threshold")
    raw_vector = trainer_cfg.get("threshold_vector")

    threshold_scalar: Optional[float] = (
        float(raw_scalar) if isinstance(raw_scalar, (float, int)) else None
    )

    # A vector of the wrong length is ignored rather than reported.
    threshold_vec: Optional[torch.Tensor] = (
        torch.tensor(raw_vector, dtype=torch.float32)
        if isinstance(raw_vector, list) and len(raw_vector) == n_features
        else None
    )

    return threshold_scalar, threshold_vec


def _extract_topk(trainer_cfg: Dict[str, Any]) -> Optional[int]:
    """
    Read the TopK sparsity level from the trainer config's 'k' entry.

    An int is returned as-is, a float is truncated with int(), and a missing
    entry or any other type yields None.
    """
    k = trainer_cfg.get("k")
    if isinstance(k, int):
        return k
    return int(k) if isinstance(k, float) else None


def _extract_gated_params(
    state: Dict[str, torch.Tensor],
    n_features: int,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Extract the extra per-feature vectors used by gated SAE variants.

    Each vector is looked up by name only, trying the aliases below in order
    and accepting the first key that holds a 1-D tensor of length
    `n_features`:
      - gate_bias: 'gate_bias', 'gate.bias', 'gating_bias', 'gating.bias'
      - r_mag:     'r_mag', 'r_mag.bias', 'mag_scale', 'mag_scale.bias'
      - mag_bias:  'mag_bias', 'mag.bias', 'magnitude_bias'

    There is deliberately no "any key containing 'bias'" fallback here: such
    a fallback cannot tell the three vectors apart and would hand back an
    unrelated tensor (e.g. the encoder bias) for all of them.

    Returns:
        (gate_bias, r_mag, mag_bias); each is a cloned tensor or None when
        no alias matches.
    """

    def _lookup(aliases: Tuple[str, ...]) -> Optional[torch.Tensor]:
        for alias in aliases:
            value = state.get(alias)
            if (
                isinstance(value, torch.Tensor)
                and value.ndim == 1
                and value.shape[0] == n_features
            ):
                return value.clone()
        return None

    gate_bias = _lookup(("gate_bias", "gate.bias", "gating_bias", "gating.bias"))
    r_mag = _lookup(("r_mag", "r_mag.bias", "mag_scale", "mag_scale.bias"))
    mag_bias = _lookup(("mag_bias", "mag.bias", "magnitude_bias"))
    return gate_bias, r_mag, mag_bias


class LocalSAE(nn.Module):
    """
    A minimal, inference-only SAE module for the interpretability pipeline.

    It approximates the encoders of several SAE variants (standard ReLU,
    TopK, JumpReLU, gated) and exposes a single operation:

        encode(x): x of shape [..., D] -> feature activations of shape [..., F]

    Training, decoding and optimizer state are intentionally out of scope.

    Stored state (all tensors are registered as float32 buffers on `device`;
    optional ones are registered only when provided, and their presence is
    tested with hasattr at encode time):
        W_dec            [F, D]  decoder weights (mandatory)
        W_enc            [F, D]  encoder weights; when absent, W_dec is
                                 reused as the encoder (tied weights)
        b_dec            [D]     subtracted from the input before encoding
        b_enc            [F]     added to the pre-activations
        threshold_vector [F]     per-feature JumpReLU threshold
        gate_bias, r_mag, mag_bias [F]  gated-variant extras
    Plain attributes:
        trainer_class_name  lower-cased trainer name; substring matches on
                            'jump', 'topk' and 'gated' select the activation
        threshold_scalar    scalar JumpReLU threshold, used only when no
                            threshold_vector is present
        k_topk              number of activations kept per position for TopK

    Convention: rows of W_dec / W_enc index features, columns index the
    hidden dimension.
    """

    def __init__(
        self,
        W_dec_FD: torch.Tensor,
        W_enc_FD: Optional[torch.Tensor],
        b_dec_D: Optional[torch.Tensor],
        b_enc_F: Optional[torch.Tensor],
        trainer_class_name: str,
        threshold_scalar: Optional[float],
        threshold_vector_F: Optional[torch.Tensor],
        k_topk: Optional[int],
        gate_bias_F: Optional[torch.Tensor],
        r_mag_F: Optional[torch.Tensor],
        mag_bias_F: Optional[torch.Tensor],
        device: str = "cpu",
    ):
        super().__init__()

        # Lower-case once so encode() can do simple substring checks;
        # None/empty becomes "" (plain ReLU behaviour).
        self.trainer_class_name = (trainer_class_name or "").lower()
        self.threshold_scalar = threshold_scalar
        self.k_topk = k_topk

        # Decoder weights are mandatory.
        self.register_buffer("W_dec", W_dec_FD.to(torch.float32).to(device))

        # Optional tensors: a buffer exists only when a value was supplied.
        # The registration order is kept stable because it determines the
        # key order of state_dict().
        optional_buffers = (
            ("W_enc", W_enc_FD),
            ("b_dec", b_dec_D),
            ("b_enc", b_enc_F),
            ("threshold_vector", threshold_vector_F),
            ("gate_bias", gate_bias_F),
            ("r_mag", r_mag_F),
            ("mag_bias", mag_bias_F),
        )
        for buffer_name, tensor in optional_buffers:
            if tensor is not None:
                self.register_buffer(
                    buffer_name, tensor.to(torch.float32).to(device)
                )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute SAE feature activations for input x.

        Args:
            x: [..., D] hidden states

        Returns:
            acts: [..., F]
        """
        # Centre the input on the decoder bias when one is available.
        x_centered = x - self.b_dec if hasattr(self, "b_dec") else x

        # Fall back to the decoder matrix when no encoder was stored (tied).
        W_enc = self.W_enc if hasattr(self, "W_enc") else self.W_dec

        # Pre-activations: a = x_centered @ W_enc^T (+ b_enc)
        a = torch.einsum("...d,fd->...f", x_centered, W_enc)
        if hasattr(self, "b_enc"):
            a = a + self.b_enc

        # Baseline activation for every variant.
        acts = torch.relu(a)

        # JumpReLU: a * 1[a > threshold]. Values above the threshold keep
        # their full magnitude (they are NOT shifted down by the threshold);
        # everything else is zeroed. The per-feature vector takes precedence
        # over the scalar; with neither, plain ReLU is used.
        if "jump" in self.trainer_class_name:
            if hasattr(self, "threshold_vector"):
                thr = self.threshold_vector
            else:
                thr = self.threshold_scalar
            if thr is not None:
                acts = torch.where(a > thr, acts, torch.zeros_like(acts))

        # TopK: keep only the k largest activations along the feature axis.
        if "topk" in self.trainer_class_name and self.k_topk is not None and self.k_topk > 0:
            k = min(self.k_topk, acts.shape[-1])
            topk_vals, topk_idx = torch.topk(acts, k=k, dim=-1)
            kept = torch.zeros_like(acts)
            kept.scatter_(-1, topk_idx, topk_vals)
            acts = kept

        # Gated variants.
        # NOTE(review): this is a heuristic (soft sigmoid gate, relu(r_mag)
        # scaling, bias added after gating) and may not match the encoder of
        # the training code; in particular the trailing mag_bias makes the
        # output dense. Verify against the trainer before relying on it.
        if "gated" in self.trainer_class_name and hasattr(self, "gate_bias"):
            gate = torch.sigmoid(a + self.gate_bias)
            acts = acts * gate
            if hasattr(self, "r_mag"):
                acts = acts * torch.relu(self.r_mag)
            if hasattr(self, "mag_bias"):
                acts = acts + self.mag_bias

        return acts


def _build_local_sae_from_folder(folder: str, device: str) -> LocalSAE:
    """
    Rebuild a LocalSAE from a training output folder.

    Steps:
      - read <folder>/cfg.json (preferred) or <folder>/config.json
      - read <folder>/ae.pt
      - heuristically locate decoder/encoder weights, biases, thresholds,
        top-k and gating vectors
      - return a LocalSAE (nn.Module) placed on `device`

    This is the main constructor called from `load_sae`.

    Args:
        folder: directory containing ae.pt and the JSON config.
        device: device string passed through to LocalSAE.

    Raises:
        FileNotFoundError: if the config file or ae.pt is missing.
        ValueError: if the config is not a JSON object, if its 'trainer'
            entry is not an object, if no usable state dict / decoder weight
            is found, or if the encoder and decoder weights disagree in
            shape.
    """
    cfg_path = _pick_config_path(folder)
    ae_path = os.path.join(folder, "ae.pt")

    if not os.path.exists(cfg_path):
        raise FileNotFoundError(
            f"[build_local_sae_from_folder] Missing cfg.json/config.json at {cfg_path}"
        )
    if not os.path.exists(ae_path):
        raise FileNotFoundError(
            f"[build_local_sae_from_folder] Missing ae.pt at {ae_path}"
        )

    # Load trainer config (JSON saved at training time). JSON is UTF-8 by
    # specification, so do not depend on the platform's locale encoding.
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg_json = json.load(f)
    if not isinstance(cfg_json, dict):
        raise ValueError(
            f"[build_local_sae_from_folder] Expected a JSON object in {cfg_path}, "
            f"got {type(cfg_json).__name__}"
        )

    trainer_cfg = cfg_json.get("trainer", {}) or {}
    if not isinstance(trainer_cfg, dict):
        raise ValueError(
            f"[build_local_sae_from_folder] Expected 'trainer' to be a JSON object in "
            f"{cfg_path}, got {type(trainer_cfg).__name__}"
        )
    trainer_class_name = str(trainer_cfg.get("trainer_class", ""))

    # Load raw SAE state dict from ae.pt
    raw_state = _load_raw_state_dict(ae_path)

    # -------- decoder weights --------
    dec_key = _pick_decoder_weight_key(raw_state)
    W_dec = raw_state[dec_key].clone().to(torch.float32).contiguous()
    W_dec = _maybe_transpose_to_FxD(W_dec)  # shape [F, D] after this

    n_features, d_model = W_dec.shape

    # -------- encoder weights (optional) --------
    enc_key = _pick_encoder_weight_key(raw_state)
    W_enc = None
    if enc_key is not None:
        W_enc = raw_state[enc_key].clone().to(torch.float32).contiguous()
        W_enc = _maybe_transpose_to_FxD(W_enc)
        # Both matrices come from name heuristics; a mismatch means one of
        # them is the wrong tensor, so fail here with both key names rather
        # than inside encode()'s einsum.
        if W_enc.shape != W_dec.shape:
            raise ValueError(
                f"[build_local_sae_from_folder] Encoder weight '{enc_key}' has shape "
                f"{tuple(W_enc.shape)} but decoder weight '{dec_key}' has shape "
                f"{tuple(W_dec.shape)} in {ae_path}"
            )

    # Optional decoder bias in hidden-dim space
    b_dec = _extract_bias_vector(
        raw_state,
        candidate_names=[
            "decoder.bias",
            "decoder_bias",
            "b_dec",
            "b_dec_D",
            "b_dec_vector",
        ],
        expected_dim=d_model,
    )

    # Optional encoder bias in feature space
    b_enc = _extract_bias_vector(
        raw_state,
        candidate_names=[
            "encoder.bias",
            "encoder_bias",
            "b_enc",
            "b_enc_F",
            "b_enc_vector",
        ],
        expected_dim=n_features,
    )

    # Thresholds and top-k (config first)
    threshold_scalar, threshold_vec = _extract_thresholds(trainer_cfg, n_features)
    k_topk = _extract_topk(trainer_cfg)

    # Fallback: when the config carries no threshold at all, use a tensor
    # stored under 'threshold' in the checkpoint itself.
    # NOTE(review): assumes the training code saves a learned threshold under
    # this exact key (as dictionary_learning-style JumpReLU checkpoints appear
    # to) -- confirm against the trainer.
    state_thr = raw_state.get("threshold")
    if (
        threshold_scalar is None
        and threshold_vec is None
        and isinstance(state_thr, torch.Tensor)
    ):
        if state_thr.ndim == 1 and state_thr.shape[0] == n_features:
            threshold_vec = state_thr.detach().clone().to(torch.float32)
        elif state_thr.numel() == 1:
            threshold_scalar = float(state_thr.item())

    # Gating extras (optional)
    gate_bias, r_mag, mag_bias = _extract_gated_params(raw_state, n_features)

    sae = LocalSAE(
        W_dec_FD=W_dec,
        W_enc_FD=W_enc,
        b_dec_D=b_dec,
        b_enc_F=b_enc,
        trainer_class_name=trainer_class_name,
        threshold_scalar=threshold_scalar,
        threshold_vector_F=threshold_vec,
        k_topk=k_topk,
        gate_bias_F=gate_bias,
        r_mag_F=r_mag,
        mag_bias_F=mag_bias,
        device=device,
    )

    return sae


def load_sae(sae_path: str, device: str, dtype: torch.dtype) -> nn.Module:
    """
    Public entry point used by run_llm_eval.py.

    Loads SAE weights and config from an SAE folder and rebuilds a minimal
    SAE object for inference / interpretability.

    Args:
        sae_path: either the SAE folder itself (the directory containing
            ae.pt and cfg.json/config.json) or the path of a file inside
            that folder (typically .../ae.pt).
        device: device to build and place the SAE on.
        dtype: dtype the module's floating-point buffers are cast to.

    Returns:
        A module in eval mode exposing
            .encode(hidden_states_BLD) -> feature_acts_BLF
        which is what the autointerp pipeline expects.

    Raises:
        FileNotFoundError / ValueError: propagated from
            `_build_local_sae_from_folder`.
    """
    # An existing directory is used as-is; taking dirname() of it would
    # silently point at the parent folder. Anything else is treated as a
    # file path inside the SAE folder.
    if os.path.isdir(sae_path):
        sae_dir = sae_path
    else:
        sae_dir = os.path.dirname(sae_path)

    # Build LocalSAE on the requested device
    sae = _build_local_sae_from_folder(sae_dir, device=device)

    # Cast to final device / dtype for inference
    sae = sae.to(device=device, dtype=dtype)
    sae.eval()

    return sae
