#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
train_calib_em_calibshape.py (rank-push + robust correctness)

无标签目标域下的“分布微调”工具：保留 calibration head 的排序与数值基调，
同时通过形状调整缓解中段堆积、鼓励两侧靠近 0/1。

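Modes:
  --shape_mode none        no distribution adjustment
  --shape_mode hist_eq     empirical CDF (histogram equalization)
  --shape_mode rank_push   rank-based U-shaped stretch (push both ends apart)
  --shape_mode rank_random ablation: same stretch applied to randomly permuted ranks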

保留原预测：
  p_final = λ * p_orig + (1-λ) * p_shaped
  通过 --mix_lambda 控制保留强度（越大越保守）

并打印“调整前/后”的 10-bin 统计（count/mean_prob/target；若记录含正确率字段，也打印 per-bin acc）。

"""

import os
import pickle
import argparse
from typing import List, Dict, Any, Optional

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

# ========================= 置信度头（两种结构的兼容） =========================

class ConfidenceHeadExternalLN(nn.Module):
    """Scalar confidence head with the LayerNorm kept outside the MLP stack.

    Parameters are registered under ``norm.*`` (LayerNorm) and ``mlp.0.*`` /
    ``mlp.3.*`` (the two Linear layers), which is the key layout the checkpoint
    loader in this file looks for.

    Args:
        in_dim: Size of the last dimension of the input features.
        mid: Width of the hidden Linear layer.
        p: Dropout probability applied after the GELU.
    """

    def __init__(self, in_dim: int, mid: int, p: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(in_dim)
        # Built in the same order as before so parameter names and init order match.
        stack = [
            nn.Linear(in_dim, mid),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(mid, 1),
        ]
        self.mlp = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and return one logit per row (last dim of size 1)."""
        normalized = self.norm(x)
        return self.mlp(normalized)

class ConfidenceHeadInlineLN(nn.Module):
    """Scalar confidence head with the LayerNorm as the first MLP layer.

    Parameters are registered under ``mlp.0.*`` (LayerNorm), ``mlp.1.*`` and
    ``mlp.4.*`` (the two Linear layers).

    Args:
        in_dim: Size of the last dimension of the input features.
        mid: Width of the hidden Linear layer.
        p: Dropout probability applied after the GELU.
    """

    def __init__(self, in_dim: int, mid: int, p: float = 0.1):
        super().__init__()
        # Layer order fixes the state_dict indices (0=LN, 1=Linear, 4=Linear).
        stack = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, mid),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(mid, 1),
        ]
        self.mlp = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one logit per row (last dim of size 1)."""
        logits = self.mlp(x)
        return logits

def _strip_prefix(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    new_state = {}
    for k, v in state.items():
        if k.startswith("module."): new_state[k[len("module."):]] = v
        else: new_state[k] = v
    return new_state

def load_conf_head_compat(head_path: str, in_dim: int, device: str, p: float = 0.1) -> nn.Module:
    """Load a confidence-head checkpoint into whichever head class matches its keys.

    Args:
        head_path: Path of the saved state_dict.
        in_dim: Input feature size used to construct the head.
        device: Passed to ``torch.load(map_location=...)`` and ``.to(...)``.
        p: Dropout probability used to construct the head.

    Returns:
        The constructed head with the checkpoint loaded (``strict=True``).

    Raises:
        RuntimeError: If neither layout is detected and no 2-D ``*.weight`` with
            ``in_dim`` columns exists to infer the hidden width from.
    """
    # torch.load versions without the ``weights_only`` keyword raise TypeError.
    try:
        raw_state = torch.load(head_path, map_location=device, weights_only=True)
    except TypeError:
        raw_state = torch.load(head_path, map_location=device)
    state = _strip_prefix(raw_state)
    names = set(state.keys())

    def _instantiate(head_cls, mid: int) -> nn.Module:
        built = head_cls(in_dim=in_dim, mid=mid, p=p).to(device)
        built.load_state_dict(state, strict=True)
        return built

    external_ln = "norm.weight" in names and "norm.bias" in names
    # A 1-D ``mlp.0.weight`` means slot 0 is a LayerNorm rather than a Linear.
    inline_ln = "mlp.0.weight" in names and state["mlp.0.weight"].ndim == 1

    if inline_ln:
        return _instantiate(ConfidenceHeadInlineLN, int(state["mlp.1.weight"].shape[0]))
    if external_ln:
        return _instantiate(ConfidenceHeadExternalLN, int(state["mlp.0.weight"].shape[0]))

    # Fallback: take the hidden width from the first Linear that consumes
    # ``in_dim`` features, then try both layouts.
    mid = next(
        (
            int(weight.shape[0])
            for name, weight in state.items()
            if name.endswith(".weight") and weight.ndim == 2 and weight.shape[1] == in_dim
        ),
        None,
    )
    if mid is None:
        raise RuntimeError("Cannot infer mid from state_dict.")
    try:
        return _instantiate(ConfidenceHeadInlineLN, mid)
    except Exception:
        return _instantiate(ConfidenceHeadExternalLN, mid)

# ========================= 基座加载 & LoRA 合并 =========================

def load_base_and_head(base_model: str, adapter_dir: str, head_path: str, device: str):
    """Load the tokenizer, the base model (optionally LoRA-merged) and the confidence head.

    Args:
        base_model: Name or path given to ``AutoTokenizer`` / ``AutoModel``.
        adapter_dir: Adapter directory, or a file inside it; empty to skip LoRA.
        head_path: Path of the confidence-head state_dict.
        device: Device the base model and the head are moved to.

    Returns:
        ``(tok, base, head)``; ``base`` and ``head`` are in eval mode.

    LoRA loading is best-effort: a failure is reported with a ``[warn]`` line and
    the un-adapted base model is used. A missing ``adapter_config.json`` is also
    reported, so a mistyped ``adapter_dir`` is not silently ignored.
    """
    tok = AutoTokenizer.from_pretrained(base_model, use_fast=True)
    if tok.pad_token_id is None and tok.eos_token_id is not None:
        # Padding is requested later; reuse EOS when no pad token is defined.
        tok.pad_token_id = tok.eos_token_id
    base = AutoModel.from_pretrained(base_model).to(device).eval()

    if adapter_dir:
        if os.path.isfile(adapter_dir):
            # dirname() of a bare file name is "", which is the current directory.
            adapter_root = os.path.dirname(adapter_dir) or "."
        else:
            adapter_root = adapter_dir
        config_path = os.path.join(adapter_root, "adapter_config.json")
        if os.path.isdir(adapter_root) and os.path.exists(config_path):
            try:
                from peft import PeftModel
                base = PeftModel.from_pretrained(base, adapter_root).merge_and_unload().to(device).eval()
                print(f"[info] loaded & merged LoRA from: {adapter_root}")
            except Exception as e:
                print(f"[warn] failed to load LoRA from {adapter_root}: {e}")
        else:
            print(f"[warn] no adapter_config.json found under {adapter_root}; "
                  f"continuing with the base model only")

    # Different configs name the hidden width differently; 768 is the last resort.
    hidden = getattr(base.config, "hidden_size", getattr(base.config, "n_embd", 768))
    head = load_conf_head_compat(head_path, in_dim=hidden, device=device, p=0.1).eval()
    return tok, base, head

# ========================= 文本规整 & 前向 =========================

def to_text_for_encoding(x: Any) -> str:
    """Convert a record field to text, unwrapping single-element containers first.

    A one-element ``list`` is replaced by its element; a one-element numpy array
    or torch tensor is then replaced by its Python scalar. Anything else is
    passed to ``str`` unchanged.
    """
    if isinstance(x, list) and len(x) == 1:
        x = x[0]

    # (container type, "holds exactly one element" predicate), tried in this order.
    unwrappers = (
        (np.ndarray, lambda arr: arr.size == 1),
        (torch.Tensor, lambda ten: ten.numel() == 1),
    )
    for container_type, holds_one in unwrappers:
        try:
            if isinstance(x, container_type) and holds_one(x):
                x = x.item()
        except Exception:
            # Best-effort: fall back to str() of whatever we currently have.
            pass
    return str(x)

@torch.no_grad()
def prob_answer_mean(tok, base, head, q: str, a: str, device: str, max_length: int,
                     prefix_tmpl: str, suffix_tmpl: str, layer_index: int,
                     temp: float, bias: float) -> float:
    """Score one (question, answer) pair with the confidence head.

    The text ``prefix + " " + answer + suffix`` is encoded, the hidden states of
    layer ``layer_index`` are mean-pooled over the estimated answer token span,
    and the pooled vector is passed through ``head``. The result is
    ``sigmoid((logit + bias) / max(1e-6, temp))``.

    Args:
        tok: Tokenizer; called like a Hugging Face tokenizer.
        base: Model called with ``output_hidden_states=True``.
        head: Module mapping pooled features to a single logit.
        q: Question text substituted into ``prefix_tmpl`` as ``{question}``.
        a: Answer text; a leading space is added when it is non-empty.
        device: Device the encoded tensors are moved to.
        max_length: Padding/truncation length for the full text.
        prefix_tmpl: ``str.format`` template for the prefix (``None`` -> "").
        suffix_tmpl: Literal suffix appended after the answer (``None`` -> "").
        layer_index: Index into ``out.hidden_states``.
        temp: Temperature; floored at 1e-6 before dividing.
        bias: Added to the logit before temperature scaling.

    Returns:
        The probability as a Python float.
    """
    prefix = (prefix_tmpl or "").format(question=q)
    suffix = (suffix_tmpl or "")
    ans = (" " + a) if len(a) > 0 else ""
    # Answer length is estimated as the token-count difference between
    # "prefix + answer" and "prefix", both tokenized without special tokens.
    pre_ids = tok(prefix, add_special_tokens=False)["input_ids"]
    pre_ans_ids = tok(prefix + ans, add_special_tokens=False)["input_ids"]
    ans_len = max(1, len(pre_ans_ids) - len(pre_ids))
    full = prefix + ans + suffix
    # NOTE(review): this call does not pass add_special_tokens=False, unlike the
    # two above. If the tokenizer prepends special tokens (or pads on the left),
    # the answer span computed from len(pre_ids) would be shifted — confirm for
    # the tokenizer in use.
    enc = tok(full, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
    input_ids = enc["input_ids"].to(device); attn_mask = enc["attention_mask"].to(device)
    out = base(input_ids=input_ids, attention_mask=attn_mask, output_hidden_states=True)
    # Both arms of the conditional yield layer_index, so idx == layer_index.
    hs = out.hidden_states; idx = layer_index if layer_index < 0 else layer_index
    # Presumably shaped (1, seq_len, hidden), given the indexing below — TODO confirm.
    last = hs[idx]
    # Clamp the span to the encoded sequence length (the text may have been truncated).
    ans_start = min(len(pre_ids), last.shape[1] - 1)
    ans_end = min(ans_start + ans_len, last.shape[1])
    if ans_end <= ans_start: ans_end = min(ans_start + 1, last.shape[1])
    attn = attn_mask[0]
    if ans_start >= last.shape[1] - 1:
        # The span start was clamped to the final position: use the hidden state
        # at the last position whose attention mask is non-zero instead.
        valid_idx = int(attn.nonzero()[-1]); feats = last[0, valid_idx:valid_idx + 1].mean(dim=0)
    else:
        # Mean-pool over positions [ans_start, ans_end) using a 0/1 mask.
        mask = torch.zeros_like(attn, dtype=last.dtype); mask[ans_start:ans_end] = 1.0
        mask = mask.unsqueeze(0).unsqueeze(-1)
        feats = (last * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-6); feats = feats.squeeze(0)
    # Bias is applied before the temperature division.
    logit = head(feats.unsqueeze(0)).squeeze(-1); logit = (logit + bias) / max(1e-6, temp)
    return float(torch.sigmoid(logit).item())

# ========================= 分布调整 =========================

def empirical_ranks(probs: np.ndarray) -> np.ndarray:
    """Return mid-point empirical CDF ranks ``(position + 0.5) / N`` for ``probs``.

    ``position`` is the 0-based index of each element in ascending order. The
    sort is stable, so tied values are ranked in their original order. The
    result is a float array aligned with ``probs``.
    """
    n_items = probs.shape[0]
    ascending = np.argsort(probs, kind="mergesort")
    midpoints = (np.arange(n_items) + 0.5) / float(n_items)
    rank_of = np.empty(ascending.shape, dtype=float)
    # Scatter: the element sorted into slot k receives the k-th mid-point.
    rank_of[ascending] = midpoints
    return rank_of

def push_to_ends_from_ranks(ranks: np.ndarray, alpha: float = 2.0) -> np.ndarray:
    """Apply a power curve, mirrored around 0.5, to rank values.

    For ``r < 0.5`` the output is ``0.5 * (2r) ** alpha``; otherwise it is
    ``1 - 0.5 * (2(1 - r)) ** alpha``. ``alpha == 1`` is the identity.
    """
    values = np.asarray(ranks, dtype=float)
    lower_half = values < 0.5
    upper_half = ~lower_half

    stretched = np.empty_like(values)
    # Each half is evaluated only on its own elements.
    stretched[lower_half] = 0.5 * np.power(2.0 * values[lower_half], alpha)
    stretched[upper_half] = 1.0 - 0.5 * np.power(2.0 * (1.0 - values[upper_half]), alpha)
    return stretched

def shape_probs(probs: np.ndarray, mode: str, lam: Optional[float], alpha: float) -> np.ndarray:
    """Reshape the distribution of ``probs`` and optionally blend with the original.

    Args:
        probs: Original probabilities.
        mode: ``"none"`` (return a copy), ``"hist_eq"`` (empirical ranks),
            ``"rank_push"`` (ranks pushed towards 0/1) or ``"rank_random"``
            (ablation: the push applied to randomly permuted ranks).
        lam: Weight of the original probabilities in the blend
            ``lam * probs + (1 - lam) * shaped``; ``None`` returns ``shaped`` only.
        alpha: Exponent forwarded to ``push_to_ends_from_ranks``.

    Returns:
        The result clipped to ``[0, 1]`` (``"none"`` returns an unclipped copy).

    Raises:
        ValueError: For an unknown ``mode``.
    """
    if mode == "none":
        return probs.copy()

    true_ranks = empirical_ranks(probs)
    if mode == "rank_random":
        # Ablation: replace the real ranks with a random permutation.
        n_items = len(probs)
        shuffled = np.random.permutation(np.arange(n_items))
        shuffled = (shuffled + 0.5) / float(n_items)
        shaped = push_to_ends_from_ranks(shuffled, alpha=alpha)
    elif mode == "rank_push":
        shaped = push_to_ends_from_ranks(true_ranks, alpha=alpha)
    elif mode == "hist_eq":
        shaped = true_ranks
    else:
        raise ValueError(f"Unknown shape_mode: {mode}")

    if lam is None:
        blended = shaped
    else:
        blended = float(lam) * probs + (1.0 - float(lam)) * shaped
    return np.clip(blended, 0.0, 1.0)

# ========================= 分箱统计 & 打印 =========================

def bin_stats(probs: np.ndarray, num_bins: int, correctness: Optional[np.ndarray] = None):
    """Compute equal-width bin statistics for ``probs``.

    Args:
        probs: Probabilities; element ``p`` goes to bin ``int(p * num_bins)``,
            clipped to ``[0, num_bins - 1]``.
        num_bins: Number of bins.
        correctness: Optional per-sample array; non-finite entries are excluded
            from the accuracy statistics.

    Returns:
        ``(bins, counts, means, accs, valid_ns)`` where ``bins`` is the per-sample
        bin index array, ``counts`` the per-bin sizes, ``means`` the per-bin mean
        probability with empty bins filled from the nearest lower non-empty bin
        (else the nearest higher one; all 0.5 if every bin is empty), and
        ``accs`` / ``valid_ns`` the per-bin mean of the finite correctness values
        and their count (both ``None`` when ``correctness`` is ``None``).
    """
    bins = np.clip((probs * num_bins).astype(int), 0, num_bins - 1)
    track_acc = correctness is not None

    counts = [0] * num_bins
    bin_means = np.full(num_bins, np.nan, dtype=float)
    accs = [np.nan] * num_bins if track_acc else None
    valid_ns = [0] * num_bins if track_acc else None

    for b in range(num_bins):
        members = bins == b
        if not members.any():
            continue
        counts[b] = int(members.sum())
        bin_means[b] = float(probs[members].mean())
        if not track_acc:
            continue
        labels = correctness[members]
        finite = np.isfinite(labels)
        if finite.any():
            accs[b] = float(np.nanmean(labels[finite]))
            valid_ns[b] = int(finite.sum())

    # Fill empty bins from their neighbours: forward sweep, then backward sweep.
    if np.all(np.isnan(bin_means)):
        bin_means = np.array([0.5] * num_bins, dtype=float)
    else:
        for sweep in (range(num_bins), range(num_bins - 1, -1, -1)):
            carry = np.nan
            for i in sweep:
                if np.isnan(bin_means[i]) and not np.isnan(carry):
                    bin_means[i] = carry
                else:
                    carry = bin_means[i]
    return bins, counts, bin_means.tolist(), accs, valid_ns

def print_bin_report(title: str, counts, means, accs, valid_ns):
    """Print the per-bin table, plus an accuracy table when ``accs``/``valid_ns`` are given.

    The first table prints the bin mean twice (as ``mean_prob`` and ``target``);
    non-finite means are shown as ``nan``. In the accuracy table, bins whose
    accuracy is ``None`` or non-finite are shown as ``NA``.
    """
    print(f"\n{title}")
    print("  bin | count | mean_prob | target")
    print("  ----+-------+-----------+--------")
    for b, n_in_bin in enumerate(counts):
        mean_b = means[b]
        if not np.isfinite(mean_b):
            mean_b = float("nan")
        print(f"  {b:>3d} | {n_in_bin:>5d} | {mean_b:>9.4f} | {mean_b:>6.4f}")

    if accs is None or valid_ns is None:
        return

    print("\n  [bins] empirical accuracy by correctness (if available)")
    print("  bin | count | acc(valid) | valid_n")
    print("  ----+-------+------------+--------")
    for b, n_in_bin in enumerate(counts):
        acc_b = accs[b]
        if acc_b is not None and np.isfinite(acc_b):
            acc_s = f"{acc_b:>10.4f}"
        else:
            acc_s = "         NA"
        vn = 0 if valid_ns[b] is None else valid_ns[b]
        print(f"  {b:>3d} | {n_in_bin:>5d} | {acc_s} | {vn:>6d}")

# ========================= 解析 correctness =========================

def parse_correctness(value: Any) -> float:
    """Parse a correctness field into ``0.0``, ``1.0`` or ``nan``.

    Accepted inputs (after unwrapping single-element lists/tuples, torch
    tensors, numpy arrays and numpy scalars):

    * ``bool``: ``True`` -> 1.0, ``False`` -> 0.0
    * numbers within 1e-8 of 0 or 1
    * strings, case-insensitive and stripped: ``true/t/yes/y/1``,
      ``false/f/no/n/0``, printed tensors such as ``"tensor([True])"``, and
      numeric strings such as ``"1.0"`` / ``"0.0"``

    Anything else, including multi-element tensors/arrays and ``None``,
    yields ``nan``.
    """
    nan = float("nan")

    # Unwrap single-element containers.
    while isinstance(value, (list, tuple)) and len(value) == 1:
        value = value[0]

    # torch tensor: only a single element is meaningful.
    if isinstance(value, torch.Tensor):
        if value.numel() != 1:
            return nan
        try:
            value = value.item()
        except (RuntimeError, ValueError):
            # Best-effort: a tensor whose value cannot be read is "unknown".
            return nan

    # numpy array: same rule as for tensors.
    if isinstance(value, np.ndarray):
        if value.size != 1:
            return nan
        value = value.item()

    # numpy scalar -> Python scalar.
    if isinstance(value, np.generic):
        value = value.item()

    # bool must be tested before int, because bool is a subclass of int.
    if isinstance(value, bool):
        return 1.0 if value else 0.0

    if isinstance(value, str):
        s = value.strip().lower()
        if s in ("true", "t", "yes", "y", "1"):
            return 1.0
        if s in ("false", "f", "no", "n", "0"):
            return 0.0
        # Handles printed tensors such as "tensor([True])" / "tensor([False])".
        if "tensor" in s and "true" in s:
            return 1.0
        if "tensor" in s and "false" in s:
            return 0.0
        # Numeric strings fall through to the numeric check below.
        try:
            value = float(s)
        except ValueError:
            return nan

    if isinstance(value, (int, float)):
        v = float(value)
        if abs(v - 0.0) < 1e-8:
            return 0.0
        if abs(v - 1.0) < 1e-8:
            return 1.0
        return nan

    return nan

# ========================= 主流程 =========================

def main():
    """Command-line entry point.

    Steps:
      1. Load the records (a non-empty ``list[dict]``) from ``--target_pkl``.
      2. Score every record with the base model and the confidence head.
      3. Print the per-bin report before shaping.
      4. Reshape the probabilities (``--shape_mode``, ``--mix_lambda``,
         ``--rank_alpha``) and print the per-bin report after shaping.
      5. Write one output record per input to ``--out_pkl``; ``hd_label`` and
         ``hd_target`` both hold the (filled) mean of the sample's post-shaping
         bin, and ``hd_bin`` holds the bin index.

    Raises:
        ValueError: If the input pickle is not a non-empty list.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base_model", required=True)
    ap.add_argument("--adapter_dir", default="")
    ap.add_argument("--head_path", required=True)
    ap.add_argument("--target_pkl", required=True)
    ap.add_argument("--out_pkl", required=True)
    ap.add_argument("--num_bins", type=int, default=10)
    ap.add_argument("--max_length", type=int, default=512)
    ap.add_argument("--prefix", type=str, default="Q: {question}\nA:")
    ap.add_argument("--suffix", type=str, default="")
    ap.add_argument("--layer_index", type=int, default=-1)
    ap.add_argument("--temp", type=float, default=1.0)
    ap.add_argument("--bias", type=float, default=0.0)
    ap.add_argument("--shape_mode", type=str, default="rank_push",
                    choices=["none", "hist_eq", "rank_push", "rank_random"])
    ap.add_argument("--mix_lambda", type=float, default=0.7)
    ap.add_argument("--rank_alpha", type=float, default=2.0)
    ap.add_argument("--correctness_field", type=str, default="correctness",
                    help="records 里用于统计准确率的字段名（若不存在或不可解析则不统计）")
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    # Fail early with a usage error instead of an IndexError after all scoring is done.
    if args.num_bins < 1:
        ap.error("--num_bins must be >= 1")

    # Load the input records.
    # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
    with open(args.target_pkl, "rb") as f:
        records: List[Dict[str, Any]] = pickle.load(f)
    if not isinstance(records, list) or not records:
        raise ValueError(f"{args.target_pkl} should be a non-empty list[dict].")

    tok, base, head = load_base_and_head(args.base_model, args.adapter_dir, args.head_path, device=args.device)

    # Score every record with the unmodified head.
    probs: List[float] = []
    for rec in records:
        q_enc = to_text_for_encoding(rec.get("question", ""))
        a_enc = to_text_for_encoding(rec.get("model_answer", ""))
        p = prob_answer_mean(tok, base, head,
                             q=q_enc, a=a_enc, device=args.device, max_length=args.max_length,
                             prefix_tmpl=args.prefix, suffix_tmpl=args.suffix,
                             layer_index=args.layer_index, temp=args.temp, bias=args.bias)
        probs.append(p)
    probs_np = np.array(probs, dtype=float)

    # Parse correctness; it is only used for the accuracy tables in the report.
    field = args.correctness_field
    corr_np = np.array([parse_correctness(rec.get(field, None)) for rec in records], dtype=float)
    has_corr = np.isfinite(corr_np).any()
    corr_for_report = corr_np if has_corr else None

    # The bin description follows --num_bins (identical text for the default of 10).
    bin_desc = f"fixed {args.num_bins} bins, width={1.0 / args.num_bins:g}"

    # Bin statistics before shaping.
    bins_b, counts_b, means_b, accs_b, valid_ns_b = bin_stats(
        probs_np, num_bins=args.num_bins, correctness=corr_for_report
    )
    print_bin_report(f"[bins BEFORE] distribution ({bin_desc})",
                     counts_b, means_b, accs_b, valid_ns_b)

    # Distribution shaping.
    shaped_np = shape_probs(probs_np, mode=args.shape_mode, lam=args.mix_lambda, alpha=args.rank_alpha)

    # Bin statistics after shaping.
    bins_a, counts_a, means_a, accs_a, valid_ns_a = bin_stats(
        shaped_np, num_bins=args.num_bins, correctness=corr_for_report
    )
    print_bin_report(f"[bins AFTER ] distribution ({bin_desc})",
                     counts_a, means_a, accs_a, valid_ns_a)

    # Build the labels: each sample gets the mean of its post-shaping bin.
    hd_target = [float(means_a[int(b)]) for b in bins_a]
    out_records: List[Dict[str, Any]] = []
    for rec, b, t in zip(records, bins_a, hd_target):
        out_records.append({
            "id": rec.get("id", ""),
            "question": rec.get("question", ""),
            "model_answer": rec.get("model_answer", ""),
            field: rec.get(field, None),
            "hd_label": t,
            "hd_target": t,
            "hd_bin": int(b),
        })

    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    out_dir = os.path.dirname(args.out_pkl)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out_pkl, "wb") as f:
        pickle.dump(out_records, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(f"\n[done] wrote {args.out_pkl} with {len(out_records)} samples")
    if out_records:
        print("keys:", list(out_records[0].keys()))

# Run the command-line pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
