import re
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Optional

# Apply seaborn's "white" theme globally as a side effect of importing this module.
sns.set_theme(style="white")
# Matches a leading run of '_', '▁', 'Ġ' or '#' characters
# (presumably tokenizer word-boundary / continuation markers — confirm per tokenizer).
_LEAD_RE = re.compile(r'^[_▁Ġ#]+')
# Matches runs of characters that are neither printable ASCII (U+0020–U+007E)
# nor inside the range U+00A0–U+FFFF; used to strip such characters from decoded text.
_PRINTABLE_RE = re.compile(r"[^\x20-\x7E\u00A0-\uFFFF]+")

def clean_tok(s: str) -> str:
    """Return a display-friendly version of a raw token string.

    Non-string inputs are converted with ``str()``. The leading run matched by
    ``_LEAD_RE`` is removed, every 'Ċ' character is dropped, and an empty
    result is replaced by the placeholder '·'.
    """
    text = s if isinstance(s, str) else str(s)
    stripped = _LEAD_RE.sub('', text)
    stripped = stripped.replace('Ċ', '')
    return stripped or '·'

def decode_ids_pretty(tokenizer, ids, max_token_chars=14):
    """Decode each id on its own and return short, printable labels.

    Args:
        tokenizer: Object exposing ``decode(list_of_ids, skip_special_tokens=...)``.
        ids: Iterable of token ids; each is passed through ``int()``.
        max_token_chars: Labels longer than this are cut and end with '…'.

    Returns:
        A list with one label per id. Ids that cannot be decoded become
        ``<id:...>``; labels that end up empty become '·'.
    """
    def _label(token_id):
        # Best-effort: any failure (bad id, tokenizer error) falls back to a tag.
        try:
            text = tokenizer.decode([int(token_id)], skip_special_tokens=False)
        except Exception:
            text = f"<id:{token_id}>"
        text = text.replace("Ċ", "").replace("▁", " ")
        text = _PRINTABLE_RE.sub("", text).strip()
        if not text:
            text = "·"
        if len(text) > max_token_chars:
            # Keep one slot for the ellipsis so the label stays within the limit.
            text = text[:max_token_chars-1] + "…"
        return text

    return [_label(token_id) for token_id in ids]

def safe_ids_to_tokens(tokenizer, ids):
    """Map ids to cleaned token strings without raising on bad ids.

    Args:
        tokenizer: Object exposing ``convert_ids_to_tokens``; ``len(tokenizer)``
            and/or a ``vocab_size`` attribute are used, when available, to
            bound the valid id range.
        ids: Iterable of values convertible to ``int``.

    Returns:
        A list with one string per input. Ids that are not integers, are
        negative, lie beyond the known vocabulary, or have no token are
        rendered as ``<id:...>``; all others go through ``clean_tok``.
    """
    # The upper bound is the largest size the tokenizer reports. For Hugging
    # Face tokenizers ``vocab_size`` excludes added tokens while ``len()``
    # includes them, so relying on ``vocab_size`` alone would reject valid
    # added/special-token ids.
    sizes = []
    for probe in (lambda: len(tokenizer), lambda: getattr(tokenizer, "vocab_size", None)):
        try:
            value = probe()
        except Exception:
            # Deliberately best-effort: some tokenizers lack __len__ or raise
            # from the vocab_size property; a missing bound is acceptable.
            continue
        if value:
            sizes.append(int(value))
    vocab_size = max(sizes) if sizes else None

    toks = []
    for x in ids:
        try:
            xi = int(x)
        except Exception:
            toks.append(f"<id:{x}>")
            continue
        if xi < 0 or (vocab_size is not None and xi >= vocab_size):
            toks.append(f"<id:{xi}>")
            continue
        try:
            tok = tokenizer.convert_ids_to_tokens([xi])
        except Exception:
            tok = None
        if isinstance(tok, list) and len(tok) == 1:
            tok = tok[0]
        if tok is None:
            # No token for this id: use the same tag as other unresolved ids
            # instead of the literal string "None".
            toks.append(f"<id:{xi}>")
            continue
        toks.append(clean_tok(tok))
    return toks

def _get_lm_head_and_norm(model):
    """Locate the output projection and the final normalisation of ``model``.

    The head is taken from ``model.lm_head``, then from
    ``model.get_output_embeddings()``; if neither yields one, a head tied to
    ``model.transformer.wte`` is built. The norm is the first attribute that
    is not ``None`` among ``model.final_layer_norm``, ``model.model.norm`` and
    ``model.transformer.ln_f``.

    Returns:
        ``(lm_head, final_norm)``; ``final_norm`` is ``None`` when no norm
        attribute is found.

    Raises:
        RuntimeError: If no head can be found or constructed.
    """
    lm_head = getattr(model, "lm_head", None)
    if lm_head is None and hasattr(model, "get_output_embeddings"):
        lm_head = model.get_output_embeddings()

    # Compare against None explicitly: module truthiness follows __len__ when
    # it is defined, so an `or` chain could skip a perfectly valid module.
    final_norm = None
    norm_locations = (
        (model, "final_layer_norm"),
        (getattr(model, "model", None), "norm"),
        (getattr(model, "transformer", None), "ln_f"),
    )
    for owner, attr in norm_locations:
        candidate = getattr(owner, attr, None)
        if candidate is not None:
            final_norm = candidate
            break

    if lm_head is None:
        if hasattr(model, "transformer") and hasattr(model.transformer, "wte"):
            class TiedHead(torch.nn.Module):
                """Projects hidden states with the transposed embedding matrix."""

                def __init__(self, emb):
                    super().__init__()
                    self.weight = emb.weight

                def forward(self, x):
                    return x @ self.weight.t()

            lm_head = TiedHead(model.transformer.wte)
        else:
            raise RuntimeError(
                "Cannot locate an output projection: no `lm_head`, no usable "
                "`get_output_embeddings()`, and no `transformer.wte` to tie to"
            )
    return lm_head, final_norm

def _norm_then_head(h, final_norm, lm_head):
    """Apply ``final_norm`` to ``h`` when one is given, then return ``lm_head`` of the result."""
    normed = h if final_norm is None else final_norm(h)
    return lm_head(normed)

@torch.no_grad()
def compute_logit_bias_with_final_layer(model, input_ids, attention_mask=None, temperature=1.0):
    """Estimate a per-vocabulary bias as the mean log-probability of the model output.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``output_hidden_states`` keywords and returning an object with a
            ``logits`` attribute (assumed to be (batch, seq, vocab) — confirm).
        input_ids: Token ids forwarded to ``model``.
        attention_mask: Optional mask forwarded to ``model``. When its shape
            matches the leading dimensions of the logits, only positions with a
            non-zero mask value contribute to the average.
        temperature: Logits are divided by this value unless it equals 1.0.

    Returns:
        A tensor over the last (vocabulary) dimension holding the averaged
        log-softmax values.
    """
    out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=False)
    logits = out.logits
    if temperature != 1.0:
        logits = logits / temperature
    logprobs = logits.log_softmax(dim=-1)

    if attention_mask is not None:
        valid = attention_mask.to(logprobs.device).bool()
        # Select instead of multiplying by the mask so that -inf log-probs at
        # padded positions cannot turn into NaN. Fall back to the plain mean
        # when the mask has an unexpected shape or selects nothing.
        if valid.shape == logprobs.shape[:-1] and bool(valid.any()):
            return logprobs[valid].mean(dim=0)

    return logprobs.mean(dim=(0, 1))

@torch.no_grad()
def build_candidate_mask_from_final(H_bt, lm_head, final_norm, topk=512):
    """Build a boolean vocabulary mask from the top-k logits of the last entry of ``H_bt``.

    ``H_bt[-1]`` is normalised and projected via ``_norm_then_head``; the ids
    of the ``topk`` largest logits along the last dimension (capped at the
    vocabulary size) are collected over all rows and marked ``True``.

    Returns:
        A 1-D ``torch.bool`` tensor whose length is the logits' last dimension,
        located on the logits' device.
    """
    final_logits = _norm_then_head(H_bt[-1], final_norm, lm_head)
    vocab_dim = final_logits.size(-1)
    k = min(topk, vocab_dim)
    winners = torch.topk(final_logits, k=k, dim=-1).indices
    keep = torch.unique(winners)
    mask = torch.zeros(vocab_dim, dtype=torch.bool, device=final_logits.device)
    mask[keep] = True
    return mask

@torch.no_grad()
def build_ascii_mask(tokenizer, device, cache: Optional[dict] = None):
    """Build a boolean mask of ids whose decoded text is non-empty printable ASCII.

    Args:
        tokenizer: Object exposing ``decode`` and either a ``vocab_size``
            attribute or ``get_vocab()``.
        device: Device for the returned mask.
        cache: Optional dict; the mask is stored under and reused from the key
            ``"ascii_mask"``.

    Returns:
        A 1-D ``torch.bool`` tensor of length ``vocab_size``. An id is ``True``
        when its decoded text, after removing 'Ċ', mapping '▁' to a space and
        stripping characters matched by ``_PRINTABLE_RE``, is non-empty and
        consists only of code points 32..126.
    """
    if cache is not None and "ascii_mask" in cache:
        cached = cache["ascii_mask"]
        # A cache shared between callers may hold a mask built for another
        # device; hand back a copy on the requested one.
        return cached if device is None else cached.to(device)

    vocab_size = getattr(tokenizer, "vocab_size", None)
    if vocab_size is None:
        vocab_size = len(tokenizer.get_vocab())

    allow = []
    for token_id in range(vocab_size):
        text = tokenizer.decode([token_id], skip_special_tokens=False)
        text = text.replace("Ċ", "").replace("▁", " ")
        text = _PRINTABLE_RE.sub("", text)
        if text and all(32 <= ord(ch) < 127 for ch in text):
            allow.append(token_id)

    mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    if allow:
        mask[torch.tensor(allow, device=device)] = True
    if cache is not None:
        cache["ascii_mask"] = mask
    return mask

def _fit_mask_to_vocab(mask, vocab_dim):
    """Return ``mask`` padded with ``False`` or truncated to exactly ``vocab_dim`` entries."""
    current = mask.numel()
    if current == vocab_dim:
        return mask
    if current > vocab_dim:
        return mask[:vocab_dim]
    resized = torch.zeros(vocab_dim, dtype=torch.bool, device=mask.device)
    resized[:current] = mask
    return resized


@torch.no_grad()
def plot_logitlens_heatmap_pretty(
    all_hidden_states: List[torch.Tensor],
    model,
    tokenizer,
    step: Optional[int] = None,
    batch_idx: int = 0,
    token_span: Optional[slice] = None,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    annotate_topk: int = 3,
    temperature: float = 0.7,
    skip_every_n: int = 1,
    max_token_chars: int = 12,
    apply_debias: bool = True,
    use_final_topk_candidates: bool = True,
    final_topk: int = 512,
    use_ascii_whitelist: bool = False,
    save_path: Optional[str] = None
):
    """Draw a logit-lens heatmap: per layer and token position, the top predicted tokens.

    For every layer slice of the selected hidden-state tensor the states are
    passed through the model's final norm and output head; the resulting
    logits are optionally de-biased, restricted to a candidate set,
    temperature-scaled and soft-maxed. Cell colour is the top-1 probability and
    the annotation lists the decoded top-``annotate_topk`` tokens joined by " / ".

    Args:
        all_hidden_states: Sequence of tensors; the selected one is unpacked as
            four dimensions ``(L, B, T, C)`` and iterated over its first.
        model: Used to find the output head / final norm and, for de-biasing,
            called on ``input_ids``.
        tokenizer: Used to decode ids into labels.
        step: Index into ``all_hidden_states``; defaults to the last entry.
        batch_idx: Index along the second dimension.
        token_span: Slice along the third dimension; defaults to all positions.
        input_ids: Enables de-biasing and provides decoded x-axis labels;
            without it, positions are labelled by index.
        attention_mask: Forwarded to the bias computation.
        annotate_topk: Number of tokens per cell annotation.
        temperature: Softmax temperature applied after masking.
        skip_every_n: Plot every n-th layer; must be non-zero.
        max_token_chars: Maximum characters per annotated token.
        apply_debias: Subtract the mean log-probability bias when ``input_ids`` is given.
        use_final_topk_candidates: Restrict candidates to the final layer's top-k ids.
        final_topk: ``k`` for that restriction.
        use_ascii_whitelist: Additionally restrict candidates to printable-ASCII tokens.
        save_path: When set, the figure is saved there at 300 dpi before being shown.

    Raises:
        ValueError: If ``all_hidden_states`` is empty, ``batch_idx`` is out of
            range, or ``skip_every_n`` is zero.
    """
    S = len(all_hidden_states)
    if S == 0:
        raise ValueError("all_hidden_states 为空")
    if skip_every_n == 0:
        # A zero slice step would raise ValueError only after all the model
        # work below; fail early with a clearer message instead.
        raise ValueError("skip_every_n must be non-zero")
    step = S - 1 if step is None else step
    H = all_hidden_states[step]
    L, B, T, C = H.shape
    if not (0 <= batch_idx < B):
        raise ValueError(f"batch_idx 超界：0..{B-1}")

    lm_head, final_norm = _get_lm_head_and_norm(model)
    # Run the projection on the head's own device/dtype.
    try:
        param = next(lm_head.parameters())
    except StopIteration:
        param = lm_head.weight
    lm_dev, lm_dtype = param.device, param.dtype

    if token_span is None:
        token_span = slice(0, T)
    H_bt = H[:, batch_idx, token_span, :]
    T_sel = H_bt.shape[1]

    logit_bias = None
    if apply_debias and input_ids is not None:
        logit_bias = compute_logit_bias_with_final_layer(
            model, input_ids.to(lm_dev),
            attention_mask=attention_mask.to(lm_dev) if attention_mask is not None else None,
            temperature=1.0
        ).to(lm_dev, dtype=lm_dtype)

    cand_mask = None
    if use_final_topk_candidates:
        cand_mask = build_candidate_mask_from_final(H_bt.to(lm_dev, lm_dtype), lm_head, final_norm, topk=final_topk)
    if use_ascii_whitelist:
        ascii_mask = build_ascii_mask(tokenizer, device=lm_dev, cache={})
        # The ASCII mask is sized by the tokenizer vocabulary, while logits are
        # sized by the output head; the two differ when the embedding matrix is
        # padded. Align the mask so `&` and `masked_fill` broadcast correctly.
        if cand_mask is not None:
            vocab_dim = cand_mask.numel()
        else:
            probe = _norm_then_head(
                H_bt[-1, :1].to(device=lm_dev, dtype=lm_dtype), final_norm, lm_head
            )
            vocab_dim = probe.size(-1)
        ascii_mask = _fit_mask_to_vocab(ascii_mask, vocab_dim)
        cand_mask = ascii_mask if cand_mask is None else (cand_mask & ascii_mask)

    top_probs = torch.zeros(L, T_sel)
    annot_txt = [[""]*T_sel for _ in range(L)]
    for l in range(L):
        h = H_bt[l].to(device=lm_dev, dtype=lm_dtype)
        logits = _norm_then_head(h, final_norm, lm_head)
        if logit_bias is not None:
            logits = logits - logit_bias
        if cand_mask is not None:
            logits = logits.masked_fill(~cand_mask, float("-inf"))
        if temperature != 1.0:
            logits = logits / temperature
        probs = torch.softmax(logits, dim=-1)
        # Always request at least one id so the top-1 probability is available
        # even when no annotation is wanted.
        v, i = probs.topk(k=max(1, annotate_topk), dim=-1)
        top_probs[l] = v[:, 0].detach().float().cpu()
        ids_np = i.detach().cpu().numpy()
        row = []
        for tpos in range(T_sel):
            toks = decode_ids_pretty(tokenizer, ids_np[tpos].tolist(), max_token_chars=max_token_chars)
            row.append(" / ".join(toks[:annotate_topk]))
        annot_txt[l] = row

    if input_ids is not None:
        ids_slice = input_ids[batch_idx, token_span].detach().cpu().tolist()
        xlabels = decode_ids_pretty(tokenizer, ids_slice, max_token_chars=16)
    else:
        xlabels = [str(i) for i in range(T_sel)]

    layer_idx = list(range(L))[::skip_every_n]
    data = top_probs.numpy()[::skip_every_n]
    annot = np.array(annot_txt, dtype=object)[::skip_every_n]

    # Scale the canvas with the number of columns/rows so annotations stay legible.
    fig_w = max(12, T_sel * 1.2)
    fig_h = max(6, len(layer_idx) * 0.6)
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    hm = sns.heatmap(
        data, annot=annot, fmt='', cmap='YlGnBu',
        xticklabels=xlabels, yticklabels=layer_idx,
        cbar=True, annot_kws={'size': 14, 'fontweight': 'bold'},
        linewidths=0.8, linecolor='white', ax=ax
    )
    # Put the first layer at the bottom of the plot.
    hm.invert_yaxis()
    ax.set_xlabel('Tokens', fontsize=18, fontweight='bold')
    ax.set_ylabel('Layer',  fontsize=18, fontweight='bold')
    ax.tick_params(axis='x', rotation=45, labelsize=16)
    ax.tick_params(axis='y', labelsize=16)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

import numpy as np
import matplotlib.pyplot as plt

def _shift_map(t, alpha: float):
    """Evaluate ``alpha * t / (1 + (alpha - 1) * t)`` element-wise in float64."""
    t_arr = np.asarray(t, dtype=np.float64)
    numerator = alpha * t_arr
    denominator = 1.0 + (alpha - 1.0) * t_arr
    return numerator / denominator

def _theory_counts(total_mask: int, K: int, alpha: float):
    """Split ``total_mask`` into ``K`` per-step integer counts following the shift map.

    The cumulative target at step k is ``round(total_mask * _shift_map(k / K, alpha))``.
    Each step is forced to at least 1; any resulting surplus is removed starting
    from the last step and moving backwards, never taking a step below 1.

    Returns:
        Integer array of length ``K`` summing to ``total_mask``.

    Raises:
        AssertionError: If the counts cannot be made to sum to ``total_mask``.
    """
    grid = np.arange(1, K + 1, dtype=np.float64) / K
    cumulative = np.round(total_mask * _shift_map(grid, alpha))
    cumulative = np.clip(cumulative, 0, total_mask)

    per_step = np.diff(cumulative, prepend=0.0).astype(int)
    per_step = np.maximum(per_step, 1)

    excess = per_step.sum() - total_mask
    for idx in range(K - 1, -1, -1):
        if excess <= 0:
            break
        take = min(excess, per_step[idx] - 1)
        per_step[idx] -= take
        excess -= take

    assert per_step.sum() == total_mask, f"sum(counts)={per_step.sum()} != total_mask={total_mask}"
    return per_step

def _fit_alpha_from_alpha_bar(alpha_bar_policy, K: int, search=(1e-3, 5.0), grid=2000):
    """Grid-search the ``alpha`` whose shift-map curve best matches ``alpha_bar_policy``.

    Candidates are ``grid`` evenly spaced values between ``search[0]`` and
    ``search[1]``; the score is the mean squared difference between
    ``alpha_bar_policy`` and ``_shift_map(k / K, alpha)`` for k = 1..K.

    Returns:
        ``(best_alpha, best_loss)`` as Python floats; the first candidate
        reaching the lowest loss wins.
    """
    t = np.linspace(1, K, K, dtype=np.float64) / K
    best_alpha = None
    best_loss = float("inf")
    for candidate in np.linspace(search[0], search[1], grid):
        residual = alpha_bar_policy - _shift_map(t, candidate)
        mse = np.mean(residual**2)
        if mse < best_loss:
            best_alpha, best_loss = candidate, mse
    return float(best_alpha), float(best_loss)

def _metrics_and_plots(counts_policy, counts_theory, alpha_bar_policy, alpha_bar_theory):
    """Print agreement metrics and show three comparison figures.

    Prints MAE, RMSE, maximum absolute difference and exact-match percentage of
    the per-step counts, plus the mean absolute difference of the cumulative
    curves. Then shows per-step counts, cumulative fractions, and a bar chart
    of the per-step differences (policy minus theory).
    """
    K = len(counts_policy)
    steps = np.arange(1, K + 1)
    diff = counts_policy - counts_theory
    abs_diff = np.abs(diff)

    mae = float(np.mean(abs_diff))
    rmse = float(np.sqrt(np.mean(diff**2)))
    maxad = int(np.max(abs_diff))
    exact = float(np.mean(diff == 0))
    cum_diff_l1 = float(np.mean(np.abs(alpha_bar_policy - alpha_bar_theory)))

    print(f"[Counts] MAE={mae:.4f}  RMSE={rmse:.4f}  Max|Δ|={maxad}  Exact-match%={100*exact:.2f}%")
    print(f"[Cumulative α] mean |Δ| = {cum_diff_l1:.6f}")

    def _two_line_figure(first, first_label, second, second_label, title, ylabel):
        # Both line figures share the same layout; only data and text differ.
        plt.figure()
        plt.plot(steps, first, label=first_label)
        plt.plot(steps, second, label=second_label)
        plt.title(title)
        plt.xlabel("step k")
        plt.ylabel(ylabel)
        plt.legend()
        plt.show()

    _two_line_figure(
        counts_policy, "policy (returned)",
        counts_theory, "shift theory",
        "Per-step unmask counts: policy vs shift theory",
        "tokens at step k",
    )
    _two_line_figure(
        alpha_bar_policy, "policy cumulative (ᾱ_policy)",
        alpha_bar_theory, "shift theory cumulative",
        "Cumulative unmask fraction (ᾱ): policy vs shift theory",
        "ᾱ up to step k",
    )

    plt.figure()
    plt.bar(steps, diff)
    plt.title("Per-step difference in counts (policy - theory)")
    plt.xlabel("step k")
    plt.ylabel("Δ tokens")
    plt.show()

def validate_shift_without_meta(sched: dict, alpha: float = None, sample_index: int = 0):
    """Compare one sample's unmasking schedule against the shift-map theory.

    Args:
        sched: Dict with ``"counts_per_step"`` (2-D, unpacked as ``(B, K)``)
            and ``"alpha_bar_per_sample"`` (indexed by sample).
        alpha: Shift parameter. When omitted it is fitted from the sample's
            cumulative curve, except when every step count equals 1, in which
            case 1/3 is used without fitting.
        sample_index: Row of the schedule to validate.

    Prints the chosen alpha (except in the all-ones case) and delegates
    metrics and figures to ``_metrics_and_plots``. Returns ``None``.
    """
    counts = np.asarray(sched["counts_per_step"])
    alpha_bar_policy_full = np.asarray(sched["alpha_bar_per_sample"])
    B, K = counts.shape
    assert 0 <= sample_index < B, f"sample_index 0..{B-1}"

    counts_b = counts[sample_index]
    total_mask = int(counts_b.sum())
    alpha_bar_policy = alpha_bar_policy_full[sample_index]

    one_token_per_step = np.all(counts_b == 1)
    if one_token_per_step:
        # Nothing is fitted or printed here; only the default is filled in.
        if alpha is None:
            alpha = 1/3
    elif alpha is None:
        a_hat, loss = _fit_alpha_from_alpha_bar(alpha_bar_policy, K)
        print(f"[alpha] estimated ≈ {a_hat:.6f} (grid L2={loss:.6e})")
        alpha = a_hat
    else:
        print(f"[alpha] using provided = {alpha}")

    counts_theory = _theory_counts(total_mask, K, alpha)
    alpha_bar_theory = np.cumsum(counts_theory) / total_mask
    _metrics_and_plots(counts_b, counts_theory, alpha_bar_policy, alpha_bar_theory)

import numpy as np
import matplotlib.pyplot as plt
import torch
from typing import Optional

def fig1_scatter_Mt_vs_logsnr(
    sched,
    all_hidden_states,
    q=0.999,
    sample_index=0,
    start_index=1,
    show_title=True,
    figsize=(6, 6),
    save_path: Optional[str] = None,
    save_dpi: int = 300,
    label_fontsize: int = 14,
    tick_fontsize: int = 12,
    title_fontsize: int = 14,
    fontweight: str = "bold"
):
    """Scatter the q-quantile of |hidden state| per step against logSNR, with a linear fit.

    For each aligned step, the hidden states of ``sample_index`` are flattened
    by ``_flatten_to_NC`` and the ``q``-quantile of their absolute values is
    taken. A degree-1 polynomial is fitted against logSNR, and
    ``_perm_p_value`` supplies the two statistics shown in the title.
    NOTE(review): ``_flatten_to_NC`` and ``_perm_p_value`` are defined outside
    the visible source; the title labels their outputs as Spearman rho and a
    p-value — confirm against their definitions.

    Returns:
        Dict with keys ``"Mt"``, ``"logsnr"``, ``"slope"``, ``"rho"`` and ``"p_perm"``.
    """
    logsnr, T = _align_logsnr_and_T(sched, all_hidden_states, sample_index, start_index)

    def _step_quantile(t):
        hidden = _to_numpy(all_hidden_states[t])
        # Keep the sample axis (length 1) so the flattening helper sees the same rank.
        hidden = hidden[:, sample_index:sample_index+1, ...]
        flat = _flatten_to_NC(hidden)
        return np.quantile(np.abs(flat), q)

    M = np.asarray([_step_quantile(t) for t in range(start_index, start_index + T)])

    coef = np.polyfit(logsnr, M, 1)
    xfit = np.linspace(logsnr.min(), logsnr.max(), 200)
    yfit = np.polyval(coef, xfit)
    rho, p = _perm_p_value(logsnr, M, n_perm=1000, seed=0)

    plt.figure(figsize=figsize)
    plt.scatter(logsnr, M, s=12)
    plt.plot(xfit, yfit, lw=2)

    if show_title:
        title_text = f"slope={coef[0]:.4g}, Spearman ρ={rho:.3f}, p≈{p:.2e}"
        plt.title(title_text, fontsize=title_fontsize, fontweight=fontweight)

    text_style = {"fontsize": label_fontsize, "fontweight": fontweight}
    plt.xlabel("logSNR(t)", **text_style)
    plt.ylabel(f"P{int(q*1000)/10}(|H_t|)", **text_style)

    tick_style = {"fontsize": tick_fontsize, "fontweight": fontweight}
    plt.xticks(**tick_style)
    plt.yticks(**tick_style)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=save_dpi, bbox_inches='tight')
    plt.show()

    return {
        "Mt": M,
        "logsnr": logsnr,
        "slope": coef[0],
        "rho": rho,
        "p_perm": p,
    }


import numpy as np
import matplotlib.pyplot as plt

# Optional-dependency guard: record whether torch can be imported so that
# _to_numpy only checks for torch.Tensor when torch is available.
# NOTE(review): torch is also imported unconditionally earlier in this file,
# so the fallback branch looks unreachable as the file stands — confirm
# whether this guard is still needed.
try:
    import torch
    _HAS_TORCH = True
except Exception:
    _HAS_TORCH = False

def _to_numpy(x):
    """Convert ``x`` to a NumPy array.

    Torch tensors are cast to float32, detached and moved to the CPU first.
    Anything else goes through ``np.asarray``; float16 arrays are widened to
    float32 and other dtypes are left unchanged.
    """
    is_tensor = _HAS_TORCH and isinstance(x, torch.Tensor)
    if not is_tensor:
        arr = np.asarray(x)
        return arr.astype(np.float32) if arr.dtype == np.float16 else arr
    return x.to(torch.float32).detach().cpu().numpy()

def _align_logsnr_and_T(sched, all_hidden_states, sample_index=0, start_index=1):
    """Return one sample's logSNR series trimmed to the number of usable hidden states.

    ``sched["logsnr_per_sample"]`` must be 2-D; row ``sample_index`` is used.
    The usable length is the smaller of that row's length and
    ``len(all_hidden_states) - start_index``, floored at zero.

    Returns:
        ``(logsnr_row[:n], n)``.
    """
    logsnr_full = _to_numpy(sched["logsnr_per_sample"])
    assert logsnr_full.ndim == 2
    row = logsnr_full[sample_index]
    usable_states = len(all_hidden_states) - start_index
    n = min(len(row), usable_states)
    if n < 0:
        n = 0
    return row[:n], n

def _abs_values_one_step(all_hidden_states, t_aligned, sample_index=0,
                         channels=None, subsample_elems=None,
                         start_index=1, rng=None):
    """Return the flattened absolute hidden-state values of one sample at one step.

    Args:
        all_hidden_states: Indexable collection of per-step hidden states.
        t_aligned: Step offset; the state at ``start_index + t_aligned`` is used.
        sample_index: Index along the second axis.
        channels: Optional channel indices for the last axis (de-duplicated and sorted).
        subsample_elems: When set and smaller than the number of values, that
            many values are drawn without replacement.
        start_index: Offset of the first aligned state.
        rng: NumPy generator for subsampling; by default one seeded with
            ``123 + t_aligned``, so repeated calls give the same subsample.

    Returns:
        1-D array of absolute values.
    """
    if rng is None:
        rng = np.random.default_rng(123 + int(t_aligned))

    hidden = _to_numpy(all_hidden_states[start_index + t_aligned])
    hidden = hidden[:, sample_index:sample_index+1]
    if channels is not None:
        wanted = np.array(sorted(set(channels)), dtype=int)
        hidden = hidden[..., wanted]

    flat = np.abs(hidden).reshape(-1)
    needs_subsample = (subsample_elems is not None) and (flat.size > subsample_elems)
    if not needs_subsample:
        return flat
    picked = rng.choice(flat.size, size=subsample_elems, replace=False)
    return flat[picked]

def _spearman(x, y):
    """Return the Spearman rank correlation of two equally long 1-D sequences.

    Tied values share the average of the ranks they span, so the result does
    not depend on the input order of ties and a constant series yields 0.0.
    A small epsilon in the denominator avoids division by zero.
    """
    def _average_ranks(values):
        flat = np.asarray(values).ravel()
        _, inverse, counts = np.unique(flat, return_inverse=True, return_counts=True)
        # `last` is the 1-based rank of the final member of each tie group;
        # the group's 0-based mean rank is last - (count + 1) / 2.
        last = np.cumsum(counts)
        mean_rank = last - (counts + 1) / 2.0
        return mean_rank[inverse.ravel()]

    rx = _average_ranks(x)
    ry = _average_ranks(y)
    rx = (rx - rx.mean()) / (rx.std() + 1e-12)
    ry = (ry - ry.mean()) / (ry.std() + 1e-12)
    return float(np.mean(rx * ry))

def exp1_slope_vs_quantile(sched, all_hidden_states, qs=(0.95,0.97,0.98,0.99,0.995,0.999),
                           sample_index=0, start_index=1,
                           channels=None,
                           subsample_elems=200_000,
                           show=True):
    """Measure how the quantile-vs-logSNR trend depends on the quantile level.

    For each level in ``qs`` the per-step quantile of the absolute hidden-state
    values is computed, a line is fitted against logSNR, and the Spearman
    correlation is taken.

    Args:
        sched: Schedule dict read by ``_align_logsnr_and_T``.
        all_hidden_states: Per-step hidden states.
        qs: Quantile levels to evaluate.
        sample_index, start_index, channels, subsample_elems: Forwarded to
            ``_abs_values_one_step``.
        show: Whether to draw the twin-axis summary figure.

    Returns:
        ``(qs, slopes, rhos)`` as NumPy arrays.
    """
    logsnr, T = _align_logsnr_and_T(sched, all_hidden_states, sample_index, start_index)
    q_levels = list(qs)

    # The extracted values depend only on the step (the default subsampling rng
    # is seeded per step), so extract once per step and evaluate every level on
    # the same array instead of re-extracting for each level.
    per_step = []
    if q_levels:
        for t in range(T):
            vals = _abs_values_one_step(all_hidden_states, t, sample_index,
                                        channels, subsample_elems, start_index)
            per_step.append([np.quantile(vals, q) for q in q_levels])

    slopes, rhos = [], []
    for j in range(len(q_levels)):
        Mt = np.asarray([row[j] for row in per_step])
        slopes.append(np.polyfit(logsnr, Mt, 1)[0])
        rhos.append(_spearman(logsnr, Mt))
    slopes = np.asarray(slopes)
    rhos = np.asarray(rhos)

    if show:
        fig, ax1 = plt.subplots()
        ax1.plot(qs, slopes, marker='o')
        ax1.set_xlabel("quantile q")
        ax1.set_ylabel("slope w.r.t logSNR", color='tab:blue')
        ax1.tick_params(axis='y', labelcolor='tab:blue')
        # Second y-axis: the correlations live on a different scale than the slopes.
        ax2 = ax1.twinx()
        ax2.plot(qs, rhos, marker='s', linestyle='--', color='tab:orange')
        ax2.set_ylabel("Spearman rho", color='tab:orange')
        ax2.tick_params(axis='y', labelcolor='tab:orange')
        plt.title("Experiment 1: slope & Spearman vs quantile q")
        plt.show()
    return np.asarray(qs), slopes, rhos

def exp2_quantile_and_ccdf(sched, all_hidden_states,
                           step_indices=None,
                           n_points=200,
                           sample_index=0, start_index=1,
                           channels=None,
                           subsample_elems=300_000):
    """Plot high-quantile curves and CCDF tails of |hidden state| at selected steps.

    Args:
        sched: Schedule dict read by ``_align_logsnr_and_T``.
        all_hidden_states: Per-step hidden states.
        step_indices: Aligned step indices to plot; defaults to roughly 1/8,
            1/2 and 7/8 of the aligned length.
        n_points: Number of grid points per curve.
        sample_index, start_index, channels, subsample_elems: Forwarded to
            ``_abs_values_one_step``.

    Returns:
        Dict with ``"p_grid"``, ``"Q_curves"``, ``"steps"`` and ``"logsnr"``.
    """
    logsnr, T = _align_logsnr_and_T(sched, all_hidden_states, sample_index, start_index)
    if step_indices is None:
        step_indices = (max(0, T//8), T//2, max(0, 7*T//8))
    labels = [f"step {i} (logSNR={logsnr[i]:.2f})" for i in step_indices]
    p_grid = np.linspace(0.90, 0.999, n_points)

    # Extract each step's values once and derive both curves from them; only
    # the small curve arrays are kept, so peak memory stays at one step.
    Q = []
    tails = []
    for i in step_indices:
        vals = _abs_values_one_step(all_hidden_states, i, sample_index,
                                    channels, subsample_elems, start_index)
        Q.append(np.asarray(np.quantile(vals, p_grid)))
        vals = np.sort(vals)
        x_grid = np.linspace(np.quantile(vals, 0.9), np.max(vals), n_points)
        # Empirical P(|H| > x): one minus the fraction of values <= x.
        ccdf = 1.0 - np.searchsorted(vals, x_grid, side='right') / vals.size
        tails.append((x_grid, ccdf))

    plt.figure()
    for qi, lab in zip(Q, labels):
        plt.plot(p_grid, qi, lw=2, label=lab)
    plt.xlabel("quantile p")
    plt.ylabel("Q_p(|H|)")
    plt.title("Experiment 2a: High-quantile curves (early/mid/late)")
    plt.legend()
    plt.show()

    plt.figure()
    for (x_grid, ccdf), lab in zip(tails, labels):
        plt.plot(x_grid, ccdf, lw=2, label=lab)
    plt.yscale("log")
    plt.xlabel("x")
    plt.ylabel("P(|H| > x)  (log scale)")
    plt.title("Experiment 2b: CCDF tails (early/mid/late)")
    plt.legend()
    plt.show()
    return {"p_grid": p_grid, "Q_curves": Q, "steps": step_indices, "logsnr": logsnr}
