import json
import os
import re
import torch
import torch.nn as nn


def get_features_by_layers(features_file):
    """
    Load a JSON mapping of layer -> feature indices, converting everything to int.

    The file must contain a JSON object whose values are lists, e.g.
    ``{"3": ["1", 2]}`` which is returned as ``{3: [1, 2]}``.

    Args:
        features_file: path to the UTF-8 encoded JSON file.

    Returns:
        dict[int, list[int]] mapping each layer to its feature indices.
    """
    # JSON text is UTF-8; don't depend on the platform's default locale encoding.
    with open(features_file, "r", encoding="utf-8") as f:
        raw = json.load(f)
    return {int(layer): [int(idx) for idx in indices] for layer, indices in raw.items()}


# =========================
# Local DL-SAE lightweight loader & wrappers
# =========================

def _flatten_state_dict(sd: dict) -> dict:
    """
    Some checkpoints save weights under a 'state_dict' key.
    Normalize to a flat dict of tensors.
    """
    if "state_dict" in sd and isinstance(sd["state_dict"], dict):
        sd = sd["state_dict"]
    return sd


def _find_tensor_keys(state_dict: dict, pattern_list: list[str]) -> list[str]:
    """
    Return keys (in order of appearance) that contain ANY of the substrings in pattern_list.
    """
    out = []
    for k in state_dict.keys():
        kl = k.lower()
        if any(p in kl for p in pattern_list):
            out.append(k)
    return out


def _pick_best_decoder_key(state_dict: dict) -> str:
    """
    Heuristically pick the decoder weight key.
    Prefer names containing 'decoder' or 'w_dec', and with largest numel.
    """
    candidates = []
    for k, v in state_dict.items():
        if not isinstance(v, torch.Tensor):
            continue
        if v.ndim != 2:
            continue
        name = k.lower()
        score = 0
        if "decoder" in name:
            score += 100
        if "w_dec" in name or "decoder.weight" in name or ("dec" in name and "weight" in name):
            score += 50
        if "weight" in name:
            score += 10
        candidates.append((score, v.numel(), k))
    if not candidates:
        raise ValueError("No 2D decoder-like tensor found in state dict.")
    candidates.sort(reverse=True)
    return candidates[0][2]


def _pick_best_encoder_key(state_dict: dict) -> str | None:
    """
    Heuristically pick the encoder weight key, if exists.
    """
    candidates = []
    for k, v in state_dict.items():
        if not isinstance(v, torch.Tensor):
            continue
        if v.ndim != 2:
            continue
        name = k.lower()
        score = 0
        if "encoder" in name:
            score += 100
        if "w_enc" in name or "encoder.weight" in name or ("enc" in name and "weight" in name):
            score += 50
        if "weight" in name:
            score += 10
        # Penalize if also looks like decoder
        if "decoder" in name or "w_dec" in name:
            score -= 200
        candidates.append((score, v.numel(), k))
    if not candidates:
        return None
    candidates.sort(reverse=True)
    if candidates[0][0] <= 0:
        # No convincing encoder key found
        return None
    return candidates[0][2]


def _extract_bias_vector(state_dict: dict, names: list[str]) -> torch.Tensor | None:
    """
    Try to find a bias vector among provided candidate names.
    """
    for n in names:
        if n in state_dict and isinstance(state_dict[n], torch.Tensor):
            v = state_dict[n]
            if v.ndim == 1:
                return v
    # try fuzzy search
    for k, v in state_dict.items():
        if not isinstance(v, torch.Tensor) or v.ndim != 1:
            continue
        kl = k.lower()
        if any(n in kl for n in names):
            return v
    return None


def _ensure_FxD(mat: torch.Tensor) -> torch.Tensor:
    """
    Ensure decoder matrix is [F, D]. If [D, F], transpose it.
    """
    if mat.ndim != 2:
        raise ValueError("Decoder/Encoder matrix must be 2D.")
    f, d = mat.shape
    # Heuristic: decoder should be wider in features dimension
    # But we cannot rely on f>d always; use the convention "we want F x D"
    # If we assume model hidden dim D is typically <= 8192 and features often >= 8k/16k
    # We'll treat "the larger dimension" as F. If ambiguous, check name in caller.
    # Here, implement a simple rule: If f < d, transpose.
    if f < d:
        mat = mat.t()
    return mat


def _extract_decoder_FxD(state_dict: dict, decoder_key: str) -> torch.Tensor:
    W = state_dict[decoder_key]
    if not isinstance(W, torch.Tensor) or W.ndim != 2:
        raise ValueError(f"Decoder weight at {decoder_key} must be a 2D tensor.")
    W = W.to(torch.float32).contiguous()
    W = _ensure_FxD(W)
    return W


def _extract_encoder_FxD(state_dict: dict, encoder_key: str) -> torch.Tensor:
    W = state_dict[encoder_key]
    if not isinstance(W, torch.Tensor) or W.ndim != 2:
        raise ValueError(f"Encoder weight at {encoder_key} must be a 2D tensor.")
    W = W.to(torch.float32).contiguous()
    W = _ensure_FxD(W)
    return W


def _get_k_from_config(cfg: dict) -> int | None:
    tr = cfg.get("trainer", {}) if isinstance(cfg.get("trainer"), dict) else {}
    k = tr.get("k", None)
    try:
        return int(k) if k is not None else None
    except Exception:
        return None


def _get_scalar_threshold_from_config(cfg: dict) -> float | None:
    tr = cfg.get("trainer", {}) if isinstance(cfg.get("trainer"), dict) else {}
    thr = tr.get("threshold", None)
    try:
        if thr is None:
            return None
        if isinstance(thr, (int, float)):
            return float(thr)
        # Sometimes saved as string
        return float(str(thr))
    except Exception:
        return None


class LocalSAE(nn.Module):
    """
    A minimal local SAE wrapper that supports common DL-SAE variants:
    - Standard (ReLU): a = ReLU(x W_enc^T + b_enc)
    - TopK / BatchTopK: same as ReLU then keep top-k; optionally subtract scalar/vector threshold before ReLU
    - JumpReLU: a = ReLU((x W_enc^T + b_enc) - threshold_vector)
    - Gated: ReLU(pre) * sigmoid(r_mag * pre + mag_bias) (if gate params exist)

    decode: x_hat = a @ W_dec + b_dec (when b_dec present)
    NOTE: W_dec must be [F, D]; W_enc must be [F, D].
    """
    def __init__(
        self,
        W_dec_FD: torch.Tensor,
        W_enc_FD: torch.Tensor | None,
        b_dec_D: torch.Tensor | None,
        b_enc_F: torch.Tensor | None,
        trainer_class_name: str,
        threshold_scalar: float | None,
        threshold_vector_F: torch.Tensor | None,
        k_topk: int | None,
        gate_bias_F: torch.Tensor | None,
        r_mag_F: torch.Tensor | None,
        mag_bias_F: torch.Tensor | None,
        device: str = "cpu",
    ):
        super().__init__()
        self.trainer_class_name = (trainer_class_name or "").lower()

        # Always register decoder
        self.register_buffer("W_dec", W_dec_FD.to(torch.float32).to(device))

        # Conditionally register optional buffers (不要先写 None 再 register)
        if W_enc_FD is not None:
            self.register_buffer("W_enc", W_enc_FD.to(torch.float32).to(device))

        if b_dec_D is not None:
            self.register_buffer("b_dec", b_dec_D.to(torch.float32).to(device))

        if b_enc_F is not None:
            self.register_buffer("b_enc", b_enc_F.to(torch.float32).to(device))

        if threshold_vector_F is not None:
            self.register_buffer("threshold_vector", threshold_vector_F.to(torch.float32).to(device))

        if gate_bias_F is not None:
            self.register_buffer("gate_bias", gate_bias_F.to(torch.float32).to(device))

        if r_mag_F is not None:
            self.register_buffer("r_mag", r_mag_F.to(torch.float32).to(device))

        if mag_bias_F is not None:
            self.register_buffer("mag_bias", mag_bias_F.to(torch.float32).to(device))

        # Scalars / ints 不用 register_buffer
        self.threshold_scalar = threshold_scalar
        self.k_topk = k_topk

    def _preact(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute linear pre-activation in feature space: [B,T,F].
        If encoder weights are missing, fall back to using decoder as encoder (weight tying).
        """
        W_enc = getattr(self, "W_enc", None)
        if W_enc is None:
            W_enc = self.W_dec

        #  pre = torch.einsum("btd,fd->btf", x.to(W_enc.dtype), W_enc)
        x = x.to(W_enc.dtype)
        pre = torch.einsum("...d,fd->...f", x, W_enc)

    
        if self.trainer_class_name.startswith("gated"):
            gate_bias = getattr(self, "gate_bias", None)
            if gate_bias is not None:
                pre = pre + gate_bias
        else:
            b_enc = getattr(self, "b_enc", None)
            if b_enc is not None:
                pre = pre + b_enc
        return pre

    def _apply_activation(self, pre: torch.Tensor) -> torch.Tensor:
        name = self.trainer_class_name

        # JumpReLU
        if "jumprelu" in name:
            thr_vec = getattr(self, "threshold_vector", None)
            thr = thr_vec if thr_vec is not None else (self.threshold_scalar if self.threshold_scalar is not None else 0.0)
            return torch.relu(pre - thr)

        # TopK / BatchTopK
        if ("topk" in name) or ("batchtopk" in name):
            thr_vec = getattr(self, "threshold_vector", None)
            out = pre
            if thr_vec is not None:
                out = out - thr_vec
            elif self.threshold_scalar is not None:
                out = out - self.threshold_scalar
            out = torch.relu(out)
            if self.k_topk is not None and 0 < self.k_topk < out.shape[-1]:
                topk_vals, topk_idx = torch.topk(out, k=self.k_topk, dim=-1)
                zeros = torch.zeros_like(out)
                out = zeros.scatter(-1, topk_idx, topk_vals)
            return out

        # Gated
        if "gated" in name:
            base = torch.relu(pre)
            r_mag = getattr(self, "r_mag", None)
            mag_bias = getattr(self, "mag_bias", None)
            if (r_mag is not None) or (mag_bias is not None):
                r = r_mag if r_mag is not None else 1.0
                m = mag_bias if mag_bias is not None else 0.0
                gate = torch.sigmoid(r * pre + m)
                base = base * gate
            return base

        # Default: Standard ReLU
        return torch.relu(pre)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pre = self._preact(x)
        a = self._apply_activation(pre)
        return a

    # def decode(self, a: torch.Tensor) -> torch.Tensor:
    #     x_hat = torch.einsum("btf,fd->btd", a.to(self.W_dec.dtype), self.W_dec)
    #     b_dec = getattr(self, "b_dec", None)
    #     if b_dec is not None:
    #         x_hat = x_hat + b_dec
    #     return x_hat
    def decode(self, a: torch.Tensor) -> torch.Tensor:
        z = a.to(self.W_dec.dtype)
        # auto-broadcast for any leading shape
        x_hat = torch.einsum("...f,fd->...d", z, self.W_dec)
        b_dec = getattr(self, "b_dec", None)
        if b_dec is not None:
            x_hat = x_hat + b_dec  
        return x_hat



def _length_checked(vec: torch.Tensor | None, expected_len: int, label: str) -> torch.Tensor | None:
    """
    Return ``vec`` if its length is ``expected_len`` (or 1, which broadcasts), else None.

    Guards against the substring matching in ``_extract_bias_vector`` picking a
    vector that belongs to another role (e.g. 'encoder.bias' matching the
    pattern 'bias' when looking for the decoder bias). A rejected vector
    triggers a warning instead of failing later inside encode()/decode().
    """
    if vec is None or vec.shape[0] in (expected_len, 1):
        return vec
    import warnings

    warnings.warn(
        f"Ignoring {label}: got length {vec.shape[0]}, expected {expected_len}.",
        stacklevel=3,
    )
    return None


def _load_dictionary_learning_sae_from_folder(folder: str, device: str):
    """
    Pure local loader: read config.json + ae.pt and build a LocalSAE.

    Encoder/decoder matrices, biases, thresholds and gate parameters are located
    heuristically by name in the checkpoint; top-k size and a scalar threshold
    come from the config's "trainer" section. Each optional vector is kept only
    if its length matches the decoder shape (D for b_dec, F for the rest).

    Args:
        folder: directory containing config.json and ae.pt.
        device: device the LocalSAE buffers are placed on.

    Returns:
        A LocalSAE built from the checkpoint.

    Raises:
        FileNotFoundError: if config.json or ae.pt is missing.
        ValueError: if the checkpoint holds no 2D tensor usable as a decoder.
    """
    cfg_path = os.path.join(folder, "config.json")
    wt_path = os.path.join(folder, "ae.pt")
    if not os.path.exists(cfg_path) or not os.path.exists(wt_path):
        raise FileNotFoundError(f"Missing config.json or ae.pt under {folder}")

    with open(cfg_path, encoding="utf-8") as f:
        cfg = json.load(f)
    trainer_cfg = cfg.get("trainer")
    # A non-dict "trainer" entry is treated like a missing one.
    trainer_class = trainer_cfg.get("trainer_class", "") if isinstance(trainer_cfg, dict) else ""

    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources (consider weights_only=True if the installed torch supports it).
    ckpt = torch.load(wt_path, map_location="cpu")
    sd = _flatten_state_dict(ckpt)

    # ----- decoder -----
    dec_key = _pick_best_decoder_key(sd)
    W_dec_FD = _extract_decoder_FxD(sd, dec_key)
    # The decoder fixes the reference sizes used to validate every optional vector.
    n_features, d_model = W_dec_FD.shape

    # decoder bias (optional) lives in model space: length D.
    b_dec = _length_checked(
        _extract_bias_vector(sd, names=["decoder.bias", "decoder_bias", "b_dec", "bias", "dec_bias"]),
        d_model,
        "decoder bias",
    )

    # ----- encoder -----
    enc_key = _pick_best_encoder_key(sd)
    W_enc_FD = _extract_encoder_FxD(sd, enc_key) if enc_key is not None else None

    # Feature-space vectors (length F).
    # Standard/TopK/BatchTopK often: encoder.bias / b_enc
    b_enc = _length_checked(
        _extract_bias_vector(sd, names=["encoder.bias", "b_enc", "enc_bias", "bias_enc"]),
        n_features,
        "encoder bias",
    )
    # Gated: gate_bias
    gate_bias = _length_checked(_extract_bias_vector(sd, names=["gate_bias"]), n_features, "gate bias")
    # JumpReLU often has 'threshold' as a vector of length F
    threshold_vec = _length_checked(
        _extract_bias_vector(sd, names=["threshold", "thresholds"]),
        n_features,
        "threshold vector",
    )
    # Gated magnitude-path params (optional)
    r_mag = _length_checked(_extract_bias_vector(sd, names=["r_mag"]), n_features, "r_mag")
    mag_bias = _length_checked(_extract_bias_vector(sd, names=["mag_bias"]), n_features, "mag_bias")

    # Scalar threshold and top-k size come from the config, not the checkpoint.
    threshold_scalar = _get_scalar_threshold_from_config(cfg)
    k_topk = _get_k_from_config(cfg)

    return LocalSAE(
        W_dec_FD=W_dec_FD,
        W_enc_FD=W_enc_FD,
        b_dec_D=b_dec,
        b_enc_F=b_enc,
        trainer_class_name=trainer_class,
        threshold_scalar=threshold_scalar,
        threshold_vector_F=threshold_vec,
        k_topk=k_topk,
        gate_bias_F=gate_bias,
        r_mag_F=r_mag,
        mag_bias_F=mag_bias,
        device=device,
    )


def _swap_layer_number_in_path(base_dir: str, layer: int) -> str:
    """
    If the given directory encodes a layer number (e.g. 'resid_post_layer_20' or 'layer_20'),
    replace it with the requested 'layer'. If it does not, return the base_dir unchanged.
    """
    p1 = re.compile(r"(resid_post_layer_)(\d+)")
    p2 = re.compile(r"(layer_)(\d+)")
    if p1.search(base_dir):
        return p1.sub(rf"\g<1>{layer}", base_dir)
    if p2.search(base_dir):
        return p2.sub(rf"\g<1>{layer}", base_dir)
    return base_dir


def _get_decoder_weights_any(sae_obj):
    """
    Return decoder weight matrix [F, D] for any LocalSAE-like object.
    """
    if hasattr(sae_obj, "W_dec"):
        return sae_obj.W_dec
    if hasattr(sae_obj, "decoder") and hasattr(sae_obj.decoder, "weight"):
        W = sae_obj.decoder.weight
        # unify to [F, D]
        if W.shape[0] < W.shape[1]:
            W = W.t()
        return W
    if hasattr(sae_obj, "W_out"):
        W = sae_obj.W_out
        if W.shape[0] < W.shape[1]:
            W = W.t()
        return W
    raise AttributeError("Decoder weights not found on SAE object.")


# =========================
# Public API used by output_score.py
# =========================

def get_sae(
    model_type,
    layer,
    saes,
    backend: str = "dl_local",
    dl_local_dir: str | None = None,
    device: str = "cpu",
):
    """
    Only 'dl_local' backend is supported here (no sae_lens / no HF download).
    We accept one folder path; if it contains a layer number, it will be swapped to the requested 'layer'.
    """
    if layer in saes:
        return saes[layer]

    if backend != "dl_local":
        raise ValueError("Only 'dl_local' backend is supported in this setup.")

    assert dl_local_dir is not None, "dl_local_dir must be provided for local DL SAEs."
    layer_dir = _swap_layer_number_in_path(dl_local_dir, layer)
    sae = _load_dictionary_learning_sae_from_folder(layer_dir, device=device)
    saes[layer] = sae
    return sae


def _module_param_device(module) -> torch.device | None:
    """Return the device of ``module``'s first parameter, or None if it exposes none."""
    parameters = getattr(module, "parameters", None)
    if not callable(parameters):
        return None
    first = next(iter(parameters()), None)
    return None if first is None else first.device


def cache_logit_lens(
    layer,
    saes,
    model_type,
    final_layer_norm,
    lm_head,
    k,
    backend: str = "dl_local",
    dl_local_dir: str | None = None,
    device: str = "cpu",
):
    """
    Build logit-lens tokens for the decoder weights of the local DL SAE at ``layer``.

    Each decoder row (one per feature) is cast to float32 on CPU, passed through
    ``final_layer_norm`` and ``lm_head``, and softmax-normalized over the vocabulary.
    The computation runs on CPU. ``final_layer_norm`` and ``lm_head`` are moved
    back afterwards to the device of their first parameter. This assumes each
    module lives on a single device (TODO confirm for sharded models). The
    modules must accept float32 input.

    Args:
        layer, saes, model_type, backend, dl_local_dir, device: forwarded to ``get_sae``.
        final_layer_norm: module applied to the [F, D] decoder rows.
        lm_head: module mapping normalized rows to vocabulary logits.
        k: number of top tokens to keep per feature.

    Returns:
        (topk, confidence, logits): the ``torch.topk`` result ([F, k] values and
        indices) over ``confidence``, the [F, |V|] softmax probabilities, and the
        [F, |V|] logits. They are computed without autograd history.
    """
    sae = get_sae(
        model_type=model_type,
        layer=layer,
        saes=saes,
        backend=backend,
        dl_local_dir=dl_local_dir,
        device=device,
    )

    decoder_weights = _get_decoder_weights_any(sae).detach().cpu().to(torch.float32)  # [F, D]

    # nn.Module.cpu() moves parameters in place, i.e. it relocates the caller's
    # modules. Remember where they live so they can be moved back afterwards.
    norm_home = _module_param_device(final_layer_norm)
    head_home = _module_param_device(lm_head)
    try:
        norm_cpu = final_layer_norm.cpu()
        head_cpu = lm_head.cpu()
        # Inference only: without no_grad the [F, |V|] logits would keep an autograd graph alive.
        with torch.no_grad():
            normed = norm_cpu(decoder_weights)  # LN on feature vectors
            logits = head_cpu(normed)  # [F, |V|]
            confidence = torch.softmax(logits, dim=1).detach().cpu()
        topk = torch.topk(confidence, dim=1, k=k)
    finally:
        for module, home in ((final_layer_norm, norm_home), (lm_head, head_home)):
            if home is not None and home.type != "cpu":
                module.to(home)
    return topk, confidence, logits
