import re
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Optional

# Applies the seaborn "white" theme at import time (a process-wide plotting side effect).
sns.set_theme(style="white")
# Matches a run of `_`, `▁`, `Ġ` or `#` anchored at the start of a string.
# Presumably these are the word-boundary / continuation markers emitted by
# common subword tokenizers -- TODO confirm against the tokenizers in use.
_LEAD_RE = re.compile(r'^[_▁Ġ#]+')
# Matches runs of characters OUTSIDE U+0020-U+007E and U+00A0-U+FFFF, i.e.
# C0/C1 control characters, DEL, and every code point above U+FFFF.
# Callers substitute these runs with "" to keep only displayable text.
_PRINTABLE_RE = re.compile(r"[^\x20-\x7E\u00A0-\uFFFF]+")

def clean_tok(s: str) -> str:
    """Return a display-friendly version of a raw vocabulary token.

    Non-string inputs are converted with ``str()``. Leading marker characters
    matched by ``_LEAD_RE`` are removed, every ``Ċ`` is dropped, and an empty
    result is replaced by the placeholder ``·``.
    """
    text = s if isinstance(s, str) else str(s)
    stripped = _LEAD_RE.sub('', text).replace('Ċ', '')
    # An empty label would be invisible in the plot, so use a middle dot.
    return stripped or '·'

def decode_ids_pretty(tokenizer, ids, max_token_chars=14):
    """Decode each id on its own into a short, printable label.

    Args:
        tokenizer: Object exposing ``decode(list_of_ids, skip_special_tokens=...)``.
        ids: Iterable of token ids; each one is passed through ``int()``.
        max_token_chars: Labels longer than this are cut to
            ``max_token_chars - 1`` characters followed by ``…``.

    Returns:
        A list with one label per input id. Ids that fail to decode become
        ``<id:N>``; labels that end up empty become ``·``.
    """
    def _label(token_id):
        try:
            text = tokenizer.decode([int(token_id)], skip_special_tokens=False)
        except Exception:
            # Best effort: keep the plot usable even if one id cannot be decoded.
            text = f"<id:{token_id}>"
        text = text.replace("Ċ", "").replace("▁", " ")
        # Drop characters matched by _PRINTABLE_RE, then trim whitespace.
        text = _PRINTABLE_RE.sub("", text).strip() or "·"
        if len(text) > max_token_chars:
            text = text[:max_token_chars-1] + "…"
        return text

    return [_label(token_id) for token_id in ids]

def safe_ids_to_tokens(tokenizer, ids):
    """Map token ids to cleaned vocabulary tokens without ever raising.

    Args:
        tokenizer: Object exposing ``convert_ids_to_tokens``; ``len(tokenizer)``
            and/or a ``vocab_size`` attribute are used, when available, to
            reject out-of-range ids up front.
        ids: Iterable of token ids (anything ``int()`` accepts).

    Returns:
        A list with one string per input id. Ids that cannot be converted to
        ``int``, are negative, are beyond the known vocabulary, or that the
        tokenizer cannot resolve are rendered as ``<id:N>``; everything else
        is the tokenizer's token passed through ``clean_tok``.
    """
    # Use the largest size the tokenizer reports as the upper bound.
    # On Hugging Face tokenizers ``vocab_size`` excludes added tokens while
    # ``len(tokenizer)`` includes them, so bounding by ``vocab_size`` alone
    # wrongly rejects valid added/special token ids.
    sizes = []
    try:
        sizes.append(int(len(tokenizer)))
    except Exception:
        # No usable __len__; fall back to vocab_size below.
        pass
    try:
        base_size = getattr(tokenizer, "vocab_size", None)
        if base_size:
            sizes.append(int(base_size))
    except Exception:
        # vocab_size may be a property that raises; treat it as unknown.
        pass
    vocab_size = max(sizes) if sizes else None

    toks = []
    for x in ids:
        try:
            xi = int(x)
        except Exception:
            toks.append(f"<id:{x}>")
            continue
        placeholder = f"<id:{xi}>"
        if xi < 0 or (vocab_size is not None and xi >= vocab_size):
            toks.append(placeholder)
            continue
        try:
            tok = tokenizer.convert_ids_to_tokens([xi])
        except Exception:
            toks.append(placeholder)
            continue
        if isinstance(tok, list) and len(tok) == 1:
            tok = tok[0]
        if tok is None:
            # An unresolved id must not be rendered as the literal text "None".
            toks.append(placeholder)
            continue
        toks.append(clean_tok(tok))
    return toks

def _get_lm_head_and_norm(model):
    """Locate the output projection and the final normalisation of ``model``.

    The head is ``model.lm_head``, else ``model.get_output_embeddings()``,
    else a head built from ``model.transformer.wte`` that multiplies by the
    transposed embedding weight. The norm is the first truthy value among
    ``model.final_layer_norm``, ``model.model.norm`` and
    ``model.transformer.ln_f``, or ``None``.

    Returns:
        ``(lm_head, final_norm)``.

    Raises:
        RuntimeError: if no head can be found or built.
    """
    head = getattr(model, "lm_head", None)
    if head is None and hasattr(model, "get_output_embeddings"):
        head = model.get_output_embeddings()

    # Probe the known attribute paths in order; stop at the first truthy hit.
    norm = getattr(model, "final_layer_norm", None)
    if not norm:
        norm = getattr(getattr(model, "model", None), "norm", None)
    if not norm:
        norm = getattr(getattr(model, "transformer", None), "ln_f", None)
    if not norm:
        norm = None

    if head is not None:
        return head, norm

    if not (hasattr(model, "transformer") and hasattr(model.transformer, "wte")):
        raise RuntimeError("无")

    class TiedHead(torch.nn.Module):
        """Projects hidden states with the input embedding matrix."""

        def __init__(self, emb):
            super().__init__()
            self.weight = emb.weight

        def forward(self, x):
            return x @ self.weight.t()

    return TiedHead(model.transformer.wte), norm

def _norm_then_head(h, final_norm, lm_head):
    """Apply ``final_norm`` to ``h`` when one is given, then return ``lm_head`` of the result."""
    normed = h if final_norm is None else final_norm(h)
    return lm_head(normed)

@torch.no_grad()
def compute_logit_bias_with_final_layer(model, input_ids, attention_mask=None, temperature=1.0):
    """Average the model's output log-probabilities into one bias vector.

    Runs ``model`` once, takes ``log_softmax`` of ``out.logits`` over the last
    dimension, and averages over the first two dimensions.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``output_hidden_states`` keywords and returning an object with a
            ``logits`` attribute.
        input_ids: Forwarded to ``model`` unchanged.
        attention_mask: Optional mask forwarded to ``model``. When its shape
            equals the first two dimensions of the logits, positions where it
            is zero are excluded from the average, so padding does not skew
            the bias.
        temperature: Logits are divided by this value when it is not 1.0.

    Returns:
        A tensor over the last (vocabulary) dimension, in the logits' dtype.
    """
    out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=False)
    logits = out.logits
    if temperature != 1.0:
        logits = logits / temperature
    logprobs = logits.log_softmax(dim=-1)

    # Without a usable mask, keep the plain mean over every position.
    if attention_mask is None or tuple(attention_mask.shape) != tuple(logprobs.shape[:2]):
        return logprobs.mean(dim=(0, 1))

    keep = attention_mask.to(logprobs.device).bool()
    n_kept = int(keep.sum())
    if n_kept == 0:
        # Degenerate mask: averaging over nothing would divide by zero.
        return logprobs.mean(dim=(0, 1))

    # masked_fill (rather than multiplying by 0) avoids -inf * 0 = NaN, and
    # summing in float32 avoids overflow for half-precision logits.
    summed = logprobs.float().masked_fill(~keep.unsqueeze(-1), 0.0).sum(dim=(0, 1))
    return (summed / n_kept).to(logprobs.dtype)

@torch.no_grad()
def build_candidate_mask_from_final(H_bt, lm_head, final_norm, topk=512):
    """Mark the vocabulary ids that rank in the top-k of the last layer's logits.

    The last entry of ``H_bt`` along its first dimension is projected with
    ``_norm_then_head``. The top ``min(topk, V)`` ids of every row are
    collected, where ``V`` is the size of the logits' last dimension.

    Returns:
        A boolean tensor of length ``V`` on the logits' device, ``True`` at
        every collected id.
    """
    final_logits = _norm_then_head(H_bt[-1], final_norm, lm_head)
    vocab = final_logits.size(-1)
    k = min(topk, vocab)
    # Union of the per-row top-k ids.
    keep_ids = torch.unique(final_logits.topk(k=k, dim=-1).indices)
    candidates = torch.zeros(vocab, dtype=torch.bool, device=final_logits.device)
    candidates[keep_ids] = True
    return candidates

@torch.no_grad()
def build_ascii_mask(tokenizer, device, cache: Optional[dict] = None):
    """Build a boolean vocabulary mask selecting tokens that decode to printable ASCII.

    Every id in ``range(vocab_size)`` is decoded individually. ``Ċ`` is
    removed, ``▁`` becomes a space, and characters matched by
    ``_PRINTABLE_RE`` are dropped. The id is kept when the remaining text is
    non-empty and consists only of code points 32..126.

    Args:
        tokenizer: Object exposing ``decode`` and either a ``vocab_size``
            attribute or ``get_vocab()``.
        device: Device of the returned mask.
        cache: Optional dict. The mask is stored under ``"ascii_mask"`` and
            reused on later calls that pass the same dict.

    Returns:
        A ``torch.bool`` tensor of length ``vocab_size`` on ``device``.
    """
    if cache is not None and "ascii_mask" in cache:
        # The cached mask may have been built for another device; honour the
        # requested one. ``.to`` returns the same tensor when nothing changes.
        return cache["ascii_mask"].to(device)

    vocab_size = getattr(tokenizer, "vocab_size", None)
    if vocab_size is None:
        vocab_size = len(tokenizer.get_vocab())

    def _is_printable_ascii(token_id):
        text = tokenizer.decode([token_id], skip_special_tokens=False)
        text = text.replace("Ċ", "").replace("▁", " ")
        text = _PRINTABLE_RE.sub("", text)
        return bool(text) and all(32 <= ord(ch) < 127 for ch in text)

    allow = [token_id for token_id in range(vocab_size) if _is_printable_ascii(token_id)]

    mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    if allow:
        mask[torch.tensor(allow, device=device)] = True
    if cache is not None:
        cache["ascii_mask"] = mask
    return mask

def _fit_mask_to_vocab(mask, vocab_size):
    """Return ``mask`` truncated or padded with ``False`` to exactly ``vocab_size`` entries."""
    current = mask.numel()
    if current == vocab_size:
        return mask
    if current > vocab_size:
        return mask[:vocab_size]
    padding = torch.zeros(vocab_size - current, dtype=mask.dtype, device=mask.device)
    return torch.cat([mask, padding])


@torch.no_grad()
def plot_logitlens_heatmap_pretty(
    all_hidden_states: List[torch.Tensor],
    model,
    tokenizer,
    step: Optional[int] = None,
    batch_idx: int = 0,
    token_span: Optional[slice] = None,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    annotate_topk: int = 3,
    temperature: float = 0.7,
    skip_every_n: int = 1,
    max_token_chars: int = 12,
    apply_debias: bool = True,
    use_final_topk_candidates: bool = True,
    final_topk: int = 512,
    use_ascii_whitelist: bool = False,
    save_path: Optional[str] = None
):
    """Draw a logit-lens heatmap: per layer and token position, the top predictions.

    Every layer's hidden state is projected through the model's final norm
    and LM head. Each cell is coloured by the top-1 probability and annotated
    with the ``annotate_topk`` most likely tokens.

    Args:
        all_hidden_states: Sequence indexed by ``step``; the selected entry is
            unpacked as a 4-D tensor ``(L, B, T, C)`` -- layers, batch, token
            positions, and presumably the hidden size.
        model: Model from which the LM head / final norm are resolved, and
            which is run once when de-biasing is enabled.
        tokenizer: Used to decode annotation and axis labels.
        step: Index into ``all_hidden_states``; defaults to the last entry.
        batch_idx: Batch element to plot.
        token_span: Slice of token positions to plot; defaults to all ``T``.
        input_ids: Optional ids used for the x-axis labels and for de-biasing.
        attention_mask: Optional mask forwarded to the de-biasing forward pass.
        annotate_topk: Number of tokens written into every cell.
        temperature: Logits are divided by this value before the softmax.
        skip_every_n: Plot every n-th layer only.
        max_token_chars: Maximum length of an annotation token label.
        apply_debias: Subtract the mean output log-probability (requires
            ``input_ids``).
        use_final_topk_candidates: Restrict predictions to ids that are in
            the final layer's top-``final_topk``.
        final_topk: Size of that candidate set per position.
        use_ascii_whitelist: Additionally restrict predictions to tokens
            that decode to printable ASCII.
        save_path: When given, the figure is also written to this path.

    Raises:
        ValueError: if ``all_hidden_states`` is empty or ``batch_idx`` is out
            of range.
    """
    S = len(all_hidden_states)
    if S == 0:
        raise ValueError("all_hidden_states 为空")
    step = S - 1 if step is None else step
    H = all_hidden_states[step]
    L, B, T, _ = H.shape
    if not (0 <= batch_idx < B):
        raise ValueError(f"batch_idx 超界：0..{B-1}")

    lm_head, final_norm = _get_lm_head_and_norm(model)
    try:
        param = next(lm_head.parameters())
    except StopIteration:
        param = lm_head.weight
    lm_dev, lm_dtype = param.device, param.dtype

    if token_span is None:
        token_span = slice(0, T)
    H_bt = H[:, batch_idx, token_span, :]
    T_sel = H_bt.shape[1]

    logit_bias = None
    if apply_debias and input_ids is not None:
        logit_bias = compute_logit_bias_with_final_layer(
            model, input_ids.to(lm_dev),
            attention_mask=attention_mask.to(lm_dev) if attention_mask is not None else None,
            temperature=1.0
        ).to(lm_dev, dtype=lm_dtype)

    cand_mask = None
    if use_final_topk_candidates:
        # Only the last layer is read, so move just that layer to the LM-head
        # device instead of the whole layer stack.
        cand_mask = build_candidate_mask_from_final(
            H_bt[-1:].to(lm_dev, lm_dtype), lm_head, final_norm, topk=final_topk
        )
    if use_ascii_whitelist:
        ascii_mask = build_ascii_mask(tokenizer, device=lm_dev)
        if cand_mask is None:
            cand_mask = ascii_mask
        else:
            # The ASCII mask is sized by the tokenizer vocabulary, which can
            # differ from the LM head's output size; align before combining.
            cand_mask = cand_mask & _fit_mask_to_vocab(ascii_mask, cand_mask.numel())

    top_probs = torch.zeros(L, T_sel)
    annot_txt = []
    k = max(1, annotate_topk)
    for layer in range(L):
        h = H_bt[layer].to(device=lm_dev, dtype=lm_dtype)
        logits = _norm_then_head(h, final_norm, lm_head)
        if logit_bias is not None:
            logits = logits - logit_bias
        if cand_mask is not None:
            if cand_mask.numel() != logits.size(-1):
                # Reached when only the ASCII whitelist is active; ids the
                # tokenizer does not cover are treated as excluded.
                cand_mask = _fit_mask_to_vocab(cand_mask, logits.size(-1))
            logits = logits.masked_fill(~cand_mask, float("-inf"))
        if temperature != 1.0:
            logits = logits / temperature
        probs = torch.softmax(logits, dim=-1)
        v, i = probs.topk(k=k, dim=-1)
        top_probs[layer] = v[:, 0].detach().float().cpu()
        annot_txt.append([
            " / ".join(
                decode_ids_pretty(tokenizer, ids, max_token_chars=max_token_chars)[:annotate_topk]
            )
            for ids in i.detach().cpu().tolist()
        ])

    if input_ids is not None:
        ids_slice = input_ids[batch_idx, token_span].detach().cpu().tolist()
        xlabels = decode_ids_pretty(tokenizer, ids_slice, max_token_chars=16)
    else:
        xlabels = [str(i) for i in range(T_sel)]

    layer_idx = list(range(L))[::skip_every_n]
    data = top_probs.numpy()[::skip_every_n]
    annot = np.array(annot_txt, dtype=object)[::skip_every_n]

    # Scale the figure with the number of columns and rows so cells stay legible.
    fig_w = max(12, T_sel * 1.2)
    fig_h = max(6, len(layer_idx) * 0.6)
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    hm = sns.heatmap(
        data, annot=annot, fmt='', cmap='YlGnBu',
        xticklabels=xlabels, yticklabels=layer_idx,
        cbar=True, annot_kws={'size': 14, 'fontweight': 'bold'},
        linewidths=0.8, linecolor='white', ax=ax
    )
    # Put layer 0 at the bottom so depth increases upwards.
    hm.invert_yaxis()
    ax.set_xlabel('Tokens', fontsize=18, fontweight='bold')
    ax.set_ylabel('Layer',  fontsize=18, fontweight='bold')
    ax.tick_params(axis='x', rotation=45, labelsize=16)
    ax.tick_params(axis='y', labelsize=16)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
