#!/usr/bin/env python3
# ==========================================
# 20-샘플 Probe (First-Token Accuracy / Exact-Match)
# - train: param / in_ctx / pert_ctx_orig / pert_ctx_pert
# - unknown: ood_in_ctx
# - 추가: multi_in_ctx / multi_ood_in_ctx
# ------------------------------------------
# 요구사항 반영:
#   - pert 관련 모드(pert_ctx_orig, pert_ctx_pert)는 ATTR_KEYS_PERT만 사용
#   - 그 외 모드(param, in_ctx, ood_in_ctx, multi_in_ctx, multi_ood_in_ctx)는 ATTR_KEYS_ALL 사용
#   - 5개 시드 반복 후 평균/표준편차 저장
# ==========================================
import os, random, re, json, logging, gc
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch, torch.nn.functional as F
from tqdm.auto import tqdm
from datasets import Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# ---------- 로그 설정 ----------
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("probe")

# ---------- 사용자 설정 ----------
MODEL_ROOT = Path("./gpt2_runs_0816")  # 모든 run 디렉터리의 상위 폴더
# (선택) 특정 run에 대해 step 범위를 제한하려면 CKPT_RANGE에 기록하세요. 비우면 모든 step 사용.
# 예: CKPT_RANGE = {"false_concat_10": (1000, 32000)}
CKPT_RANGE: Dict[str, Tuple[int, int]] = {}

# (선택) run 디렉터리명 필터. None이면 전체 사용. (예: "25"만 포함하려면 "25"로 설정)
FILTER_RUN_NAME_SUBSTR: Optional[str] = None

TRAIN_JSON = "../dataset_generation/bioS_train_250806.json"
# 한 파일에 퍼트브가 없다면 아래를 설정하세요 (없으면 None 유지)
TRAIN_PERT_JSON = "../dataset_generation/bioS_train_250806_pert.json"

UNKNOWN_JSON = "../dataset_generation/bioS_unknown_250806.json"

# ----- 속성 키 (요구사항) -----
ATTR_KEYS_PERT = ["birth_date", "university"]
ATTR_KEYS_ALL  = ["birth_city", "birth_date", "major", "university"]

SAMPLE_N_TRAIN   = 200                     # train 샘플 수
SAMPLE_N_UNKNOWN = 200                     # unknown 샘플 수

# 5개 시드로 반복
RANDOM_SEEDS = [0, 1, 2, 3, 4]

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- 유틸리티 ----------
def load_json(path: str):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def list_run_dirs(root: Path) -> List[Path]:
    """checkpoint-* 하위 폴더가 있는 run 디렉터리만 반환"""
    runs = []
    if not root.exists():
        log.warning(f"MODEL_ROOT가 존재하지 않습니다: {root.resolve()}")
        return runs
    for p in root.iterdir():
        if p.is_dir():
            try:
                has_ckpt = any(child.is_dir() and child.name.startswith("checkpoint") for child in p.iterdir())
            except PermissionError:
                has_ckpt = False
            if has_ckpt:
                runs.append(p)
    runs.sort(key=lambda x: x.name)
    return runs

def list_ckpts_in_range(model_dir: Path, lo: int = 0, hi: int = 10**12) -> List[Path]:
    """lo ≤ step ≤ hi & 100의 배수인 checkpoint 디렉터리 정렬 반환"""
    cand = []
    for p in model_dir.iterdir():
        if not p.is_dir() or not p.name.startswith("checkpoint"):
            continue
        m = re.findall(r"\d+", p.name)
        if not m:
            continue
        step = int(m[0])
        if lo <= step <= hi and step % 100 == 0:
            cand.append((step, p))
    return [p for step, p in sorted(cand, key=lambda x: x[0])]

def _get_first_present_key(d: dict, keys: List[str]) -> Optional[str]:
    for k in keys:
        if k in d and d[k] is not None:
            return k
    return None

def split_train_into_o_p_lists(train_raw, fallback_pert_json: Optional[str]) -> Tuple[List[dict], List[dict]]:
    """
    train_raw에서 원본(o)과 퍼트브(p)를 추출.
    - 형태 A: {"orig":[...], "pert":[...]}  -> 바로 사용
    - 형태 B: [ { ..., "pert": {"test_corpus":..., "probes":...}, ... }, ... ] -> 각 엔트리에서 추출
    - 형태 C: [ { "test_corpus":..., "probes":..., "test_corpus_pert":..., "probes_pert":... }, ... ] -> 키로 추출
    - 실패 시: fallback_pert_json을 읽어 pair 구성 (인덱스 매칭)
    """
    # 형태 A: dict with "orig" and "pert"
    if isinstance(train_raw, dict):
        if "orig" in train_raw and "pert" in train_raw:
            o_raw, p_raw = train_raw["orig"], train_raw["pert"]
            assert len(o_raw) == len(p_raw), "orig/pert 길이가 다릅니다."
            return o_raw, p_raw

    # 형태 B/C: list
    if isinstance(train_raw, list):
        o_list, p_list = [], []
        all_have_pert = True
        for e in train_raw:
            # 원본 키
            tc_o_key = _get_first_present_key(e, ["test_corpus", "orig_test_corpus", "ctx", "ctx_o"])
            pr_o_key = _get_first_present_key(e, ["probes", "orig_probes"])

            # 퍼트브: nested 'pert' or flat *_pert
            p_block = None
            if isinstance(e.get("pert"), dict):
                p_block = e["pert"]
            else:
                # flat 키 후보
                tc_p_key = _get_first_present_key(e, ["test_corpus_pert", "pert_test_corpus", "p_test_corpus", "ctx_p"])
                pr_p_key = _get_first_present_key(e, ["probes_pert", "pert_probes", "p_probes"])
                if tc_p_key and pr_p_key:
                    p_block = {"test_corpus": e[tc_p_key], "probes": e[pr_p_key]}

            if tc_o_key and pr_o_key:
                o_entry = {"test_corpus": e[tc_o_key], "probes": e[pr_o_key]}
                o_list.append(o_entry)
            else:
                all_have_pert = False  # 원본조차 없으면 페어링 불가
                break

            if p_block and "test_corpus" in p_block and "probes" in p_block:
                p_list.append({"test_corpus": p_block["test_corpus"], "probes": p_block["probes"]})
            else:
                all_have_pert = False

        if all_have_pert and len(o_list) == len(p_list) and len(o_list) > 0:
            return o_list, p_list
        # else: 낙하하여 fallback 사용

    # fallback: 별도 pert JSON에서 읽기
    if fallback_pert_json is None:
        raise ValueError(
            "train JSON에서 퍼트브 정보를 찾을 수 없습니다. "
            "한 파일에 (orig/pert)가 함께 있거나, TRAIN_PERT_JSON 경로를 지정해야 합니다."
        )
    pert_raw = load_json(fallback_pert_json)
    if isinstance(train_raw, list) and isinstance(pert_raw, list):
        assert len(train_raw) == len(pert_raw), "train/pert 길이가 다릅니다."
        return train_raw, pert_raw
    elif isinstance(train_raw, dict) and isinstance(pert_raw, dict):
        return train_raw.get("orig", []), pert_raw.get("pert", [])
    else:
        raise ValueError("fallback pert 형식이 예상과 다릅니다.")

# ---------- 예시 구성 ----------
def _make_row(tok, mode: str, attr: str, prefix: str, tgt: str, full: Optional[str] = None):
    """
    prefix = (ctx들 concat) + ' ' + prompt  (또는 ' ' + prompt)
    full   = prefix + tgt
    토큰 경계 병합까지 반영해 prompt_len, target_len을 안전하게 계산.
    """
    if full is None:
        full = prefix + tgt

    # 길이 계산은 반드시 prefix와 full의 차이로 (BPE 병합 안전)
    prompt_len = len(tok(prefix)["input_ids"])
    full_len   = len(tok(full)["input_ids"])
    target_len = full_len - prompt_len

    return {
        "mode": mode,
        "attr": attr,
        "text": full,
        "prompt_len": prompt_len,
        "target_len": target_len
    }

def _safe_get_probe(d: dict, attr: str) -> Optional[Tuple[str, str]]:
    """d['probes'][attr]이 존재하면 (prompt, tgt) 반환, 없으면 None"""
    probes = d.get("probes", {})
    pair = probes.get(attr, None)
    if pair is None:
        return None
    # pair 형식 안전 확인
    if isinstance(pair, (list, tuple)) and len(pair) == 2:
        return pair[0], pair[1]
    raise ValueError(f"probes['{attr}'] 형식이 (prompt, tgt)가 아닙니다: {type(pair)}")

def build_examples_train(o_raw: List[dict], p_raw: List[dict], tok,
                         attr_keys_nonpert: List[str], attr_keys_pert: List[str],
                         sample_n: int) -> Dataset:
    """
    train에서 4가지 모드 생성
      - non-pert( param / in_ctx )   : attr_keys_nonpert 사용
      - pert   ( pert_ctx_orig / pert_ctx_pert ) : attr_keys_pert 사용
    """
    n = min(len(o_raw), len(p_raw))
    if n == 0:
        raise ValueError("train 데이터가 비었습니다.")
    k = min(sample_n, n)
    idxs = random.sample(range(n), k)

    rows = []
    for i in idxs:
        o, p = o_raw[i], p_raw[i]
        ctx_o, ctx_p = o["test_corpus"], p["test_corpus"]

        # ----- non-pert 모드 -----
        for attr in attr_keys_nonpert:
            pair = _safe_get_probe(o, attr)
            if pair is None:
                continue
            prompt_o, tgt_o = pair

            # param
            prefix = " " + prompt_o
            rows.append(_make_row(tok, "param", attr, prefix, tgt_o))

            # in_ctx
            prefix = " " + ctx_o + " " + prompt_o
            rows.append(_make_row(tok, "in_ctx", attr, prefix, tgt_o))

        # ----- pert 모드 -----
        for attr in attr_keys_pert:
            pair_o = _safe_get_probe(o, attr)
            pair_p = _safe_get_probe(p, attr)
            if pair_o is None or pair_p is None:
                continue
            prompt_o, tgt_o = pair_o
            prompt_p, tgt_p = pair_p

            # pert_ctx_orig
            prefix = " " + ctx_p + " " + prompt_o
            rows.append(_make_row(tok, "pert_ctx_orig", attr, prefix, tgt_o))

            # pert_ctx_pert
            prefix = " " + ctx_p + " " + prompt_p
            rows.append(_make_row(tok, "pert_ctx_pert", attr, prefix, tgt_p))

    return Dataset.from_list(rows)

def build_examples_ood(u_raw: List[dict], tok, attr_keys: List[str], sample_n: int) -> Dataset:
    """unknown에서 ood_in_ctx 생성 (ATTR_KEYS_ALL 사용, target_len 포함)"""
    n = len(u_raw)
    if n == 0:
        raise ValueError("unknown 데이터가 비었습니다.")
    k = min(sample_n, n)
    idxs = random.sample(range(n), k)

    rows = []
    for i in idxs:
        o = u_raw[i]
        ctx_o = o["test_corpus"]
        for attr in attr_keys:
            pair = _safe_get_probe(o, attr)
            if pair is None:
                continue
            prompt_o, tgt_o = pair
            prefix = " " + ctx_o + " " + prompt_o
            rows.append(_make_row(tok, "ood_in_ctx", attr, prefix, tgt_o))
    return Dataset.from_list(rows)

def _sample_others(n_total: int, self_idx: int, num_others: int) -> List[int]:
    """self_idx를 제외하고 최대 num_others를 샘플. 데이터가 작을 때도 안전."""
    pool = [j for j in range(n_total) if j != self_idx]
    k = min(len(pool), num_others)
    return random.sample(pool, k) if k > 0 else []

def build_examples_train_multi(o_raw: List[dict], tok, attr_keys: List[str],
                               sample_n: int, num_others: int = 2, sep: str = " ") -> Dataset:
    """
    train 기반 multiple in-ctx:
    - 자기 문단 1개 + 다른 사람 문단 num_others개를 섞어서 concat
    - 그 뒤에 prompt_o + tgt_o
    - ATTR_KEYS_ALL 사용
    """
    n = len(o_raw)
    if n == 0:
        raise ValueError("train 데이터가 비었습니다.")
    k = min(sample_n, n)
    idxs = random.sample(range(n), k)

    rows = []
    for i in idxs:
        others = _sample_others(n, i, num_others)
        ctxs = [o_raw[i]["test_corpus"]] + [o_raw[j]["test_corpus"] for j in others]
        random.shuffle(ctxs)
        mixed_ctx = sep.join(ctxs)

        for attr in attr_keys:
            pair = _safe_get_probe(o_raw[i], attr)
            if pair is None:
                continue
            prompt_o, tgt_o = pair
            prefix = " " + mixed_ctx + " " + prompt_o
            rows.append(_make_row(tok, "multi_in_ctx", attr, prefix, tgt_o))
    return Dataset.from_list(rows)

def build_examples_ood_multi(u_raw: List[dict], tok, attr_keys: List[str],
                             sample_n: int, num_others: int = 2, sep: str = " ") -> Dataset:
    """
    unknown 기반 multiple ood-in-ctx:
    - 자기 문단 1개 + 다른 사람 문단 num_others개를 섞어서 concat
    - 그 뒤에 prompt_o + tgt_o
    - ATTR_KEYS_ALL 사용
    """
    n = len(u_raw)
    if n == 0:
        raise ValueError("unknown 데이터가 비었습니다.")
    k = min(sample_n, n)
    idxs = random.sample(range(n), k)

    rows = []
    for i in idxs:
        others = _sample_others(n, i, num_others)
        ctxs = [u_raw[i]["test_corpus"]] + [u_raw[j]["test_corpus"] for j in others]
        random.shuffle(ctxs)
        mixed_ctx = sep.join(ctxs)

        for attr in attr_keys:
            pair = _safe_get_probe(u_raw[i], attr)
            if pair is None:
                continue
            prompt_o, tgt_o = pair
            prefix = " " + mixed_ctx + " " + prompt_o
            rows.append(_make_row(tok, "multi_ood_in_ctx", attr, prefix, tgt_o))
    return Dataset.from_list(rows)

# ---------- 평가 ----------
@torch.inference_mode()
def compute_metrics_for_example(model, tok, ex, device: str) -> Tuple[Optional[int], Optional[int]]:
    """
    returns: (acc1, em)
      - acc1: First-token accuracy (0/1) 또는 None(불가)
      - em  : Exact-match        (0/1) 또는 None(불가)
    """
    max_len = getattr(model.config, "n_ctx", getattr(model.config, "n_positions", 512))
    ids = tok(ex["text"], return_tensors="pt",
              truncation=True, max_length=max_len).input_ids.to(device)
    logits = model(ids).logits.squeeze(0)  # [seq_len, vocab]

    seq_len = ids.shape[1]
    pl = int(ex["prompt_len"])
    tl = int(ex["target_len"])

    acc1 = None
    em   = None

    # First token
    if pl < seq_len and pl - 1 >= 0:
        pred1 = torch.argmax(logits[pl - 1]).item()
        true1 = ids[0, pl].item()
        acc1 = int(pred1 == true1)

    # Exact match: 전체 target 길이가 input에 온전히 포함된 경우만 평가
    if (tl > 0) and (pl + tl <= seq_len) and (pl - 1 >= 0):
        pred_seq = torch.argmax(logits[pl - 1: pl + tl - 1], dim=-1)
        true_seq = ids[0, pl: pl + tl]
        em = int(torch.equal(pred_seq, true_seq))

    return acc1, em

def _init_score_bucket():
    return {"acc1_sum": 0, "acc1_n": 0, "em_sum": 0, "em_n": 0}

def _finalize_scores(buckets: Dict[str, Dict[str, int]]) -> Dict[str, float]:
    out = {}
    for mode, b in buckets.items():
        acc1 = 100.0 * b["acc1_sum"] / b["acc1_n"] if b["acc1_n"] > 0 else float("nan")
        em   = 100.0 * b["em_sum"]   / b["em_n"]   if b["em_n"]   > 0 else float("nan")
        out[f"acc1/{mode}"] = acc1
        out[f"em/{mode}"]   = em
    return out

def probe_ckpt(ckpt_path: Path, tok, eval_ds: Dataset, device: str) -> Dict[str, float]:
    """단일 체크포인트에 대한 모드별 First-Token Accuracy / Exact-Match(%) 계산"""
    log.info(f"▶ 평가 중: {ckpt_path}")
    model = GPT2LMHeadModel.from_pretrained(ckpt_path).to(device).eval()

    buckets: Dict[str, Dict[str, int]] = {}
    for ex in eval_ds:
        mode = ex["mode"]
        if mode not in buckets:
            buckets[mode] = _init_score_bucket()

        acc1, em = compute_metrics_for_example(model, tok, ex, device)
        if acc1 is not None:
            buckets[mode]["acc1_sum"] += acc1
            buckets[mode]["acc1_n"]   += 1
        if em is not None:
            buckets[mode]["em_sum"] += em
            buckets[mode]["em_n"]   += 1

    # 메모리 해제
    del model
    if device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

    return _finalize_scores(buckets)

# ---------- 토크나이저 ----------
tok = GPT2Tokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token

# ---------- 데이터 로드 (고정) ----------
log.info("데이터 로드 중...")
train_raw = load_json(TRAIN_JSON)
o_raw_full, p_raw_full = split_train_into_o_p_lists(train_raw, TRAIN_PERT_JSON)

unknown_raw_full = load_json(UNKNOWN_JSON)
# 리스트 형태만 지원 (필요 시 확장)
if isinstance(unknown_raw_full, dict):
    # 흔한 케이스가 아니므로 가장 그럴듯한 키를 추출
    if "data" in unknown_raw_full:
        unknown_raw_full = unknown_raw_full["data"]
    else:
        raise ValueError("unknown JSON이 dict 형태입니다. 리스트를 기대합니다.")

# ---------- 시드 반복: 평가/수집 ----------
all_records = []

# 평가에 사용할 run 디렉터리 미리 스캔
run_dirs_master = list_run_dirs(MODEL_ROOT)
if FILTER_RUN_NAME_SUBSTR is not None:
    run_dirs_master = [p for p in run_dirs_master if FILTER_RUN_NAME_SUBSTR in p.name]
if not run_dirs_master:
    log.warning(f"{MODEL_ROOT} 아래에서 checkpoint가 포함된 run 디렉터리를 찾지 못했습니다.")

for seed in RANDOM_SEEDS:
    log.info(f"===== Seed {seed} 시작 =====")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # 시드마다 샘플링이 달라지므로, eval 데이터셋을 시드 기준으로 새로 생성
    log.info("평가 샘플 구성 중...")
    eval_train       = build_examples_train(o_raw_full, p_raw_full, tok, ATTR_KEYS_ALL, ATTR_KEYS_PERT, SAMPLE_N_TRAIN)
    eval_ood         = build_examples_ood(unknown_raw_full, tok, ATTR_KEYS_ALL, SAMPLE_N_UNKNOWN)
    eval_train_multi = build_examples_train_multi(o_raw_full, tok, ATTR_KEYS_ALL, SAMPLE_N_TRAIN)
    eval_ood_multi   = build_examples_ood_multi(unknown_raw_full, tok, ATTR_KEYS_ALL, SAMPLE_N_UNKNOWN)

    eval_all = Dataset.from_list(
        [ex for ex in eval_train] +
        [ex for ex in eval_ood] +
        [ex for ex in eval_train_multi] +
        [ex for ex in eval_ood_multi]
    )

    # 이 시드에서의 모든 체크포인트 평가
    for run_dir in run_dirs_master:
        run_name = run_dir.name
        lo, hi = CKPT_RANGE.get(run_name, (0, 10**12))
        ckpts = list_ckpts_in_range(run_dir, lo, hi)
        if not ckpts:
            log.warning(f"{run_name}: 지정 범위({lo},{hi}) 내 checkpoint가 없습니다.")
            continue

        for ckpt in tqdm(ckpts, desc=f"🔍 {run_name} (seed {seed})"):
            step = int(re.findall(r"\d+", ckpt.name)[0])
            rec  = probe_ckpt(ckpt, tok, eval_all, DEVICE)
            rec.update({"model": run_name, "step": step, "seed": seed})
            all_records.append(rec)

# ---------- 결과 집계/저장 ----------
if all_records:
    df = pd.DataFrame(all_records).sort_values(["model", "step", "seed"])
    # 보기 좋게 컬럼 순서 정렬(모드/지표 기준으로)
    metric_cols = sorted([c for c in df.columns if "/" in c])
    df = df[["model", "step", "seed"] + metric_cols]

    out_path_raw = Path("probe_results_multi.csv")
    df.to_csv(out_path_raw, index=False, encoding="utf-8")
    log.info(f"Raw 결과 저장: {out_path_raw.resolve()}")

    # 평균/표준편차 집계 (model × step 단위)
    summary = df.groupby(["model", "step"])[metric_cols].agg(["mean", "std"]).reset_index()
    # MultiIndex -> 평평하게
    summary.columns = ["_".join(col).rstrip("_") for col in summary.columns.values]

    out_path_summary = Path("probe_results_multi_summary.csv")
    summary.to_csv(out_path_summary, index=False, encoding="utf-8")
    log.info(f"Summary 결과 저장: {out_path_summary.resolve()}")

    try:
        from IPython.display import display
        log.info("요약 결과 상위 20행 미리보기:")
        display(summary.head(20))
    except Exception:
        print(summary.head(20))
else:
    log.warning("수집된 결과가 없습니다.")
