import os
import re
import numpy as np
import pandas as pd
from pathlib import Path
from bs4 import BeautifulSoup
from scipy.special import softmax
from scipy.special import rel_entr

def parse_dv_from_html(html_path):
    """Extract the ΔV values stored in ``<span title="...">`` attributes.

    Args:
        html_path: Path of a UTF-8 encoded HTML file.

    Returns:
        list[float]: One value for every ``<span>`` whose ``title`` contains
        a ``ΔV=<number>`` token, in the order the spans are returned by
        ``find_all``. Spans without a title or without such a token are
        skipped.
    """
    # The optional exponent group makes sure a value written in scientific
    # notation (e.g. "ΔV=1.5e-05") is read in full instead of being cut off
    # after the mantissa. Compiled once, outside the per-span loop.
    dv_pattern = re.compile(r"ΔV=([-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?)")

    with open(html_path, "r", encoding="utf-8") as f:
        soup = BeautifulSoup(f, "html.parser")

    dv_values = []
    for span in soup.find_all("span"):
        match = dv_pattern.search(span.get("title", ""))
        if match:
            dv_values.append(float(match.group(1)))
    return dv_values

def kl_divergence(p, q):
    """Return the sum of the element-wise relative entropy ``rel_entr(p, q)``.

    Both inputs are converted to float64 arrays before the computation.

    Args:
        p: Array-like of values for the first distribution.
        q: Array-like of values for the second (reference) distribution.

    Returns:
        The summed relative entropy as a NumPy float.
    """
    p_arr, q_arr = (np.asarray(dist, dtype=np.float64) for dist in (p, q))
    pointwise = rel_entr(p_arr, q_arr)
    return np.sum(pointwise)

def compute_kl_scores(tp_dpo_dir, dpo_ref_dir, topk=5, mink=5, min_length=150):
    """Score matching ``*_dv.html`` files of two directories by KL-divergence.

    Every ``*_dv.html`` file in ``tp_dpo_dir`` is paired with the file of the
    same name in ``dpo_ref_dir``. The ΔV sequences of both files are parsed,
    each sequence is turned into a distribution with a softmax, and the pair
    is scored with ``kl_divergence(tp, ref)``. The mean score and the
    highest/lowest scoring files are printed.

    Files are skipped when the reference file is missing, when the two
    sequences differ in length, or when the sequence is shorter than
    ``min_length``.

    Args:
        tp_dpo_dir: Directory scanned for ``*_dv.html`` files.
        dpo_ref_dir: Directory holding the same-named reference files.
        topk: Number of highest-KL rows to print.
        mink: Number of lowest-KL rows to print.
        min_length: Minimum number of ΔV values a file must contain.

    Returns:
        pandas.DataFrame: One row per scored file with the columns ``file``,
        ``kl_divergence``, ``length``, ``tp_path`` and ``ref_path``. The
        frame is empty (but still has these columns) when nothing was scored.
    """
    columns = ["file", "kl_divergence", "length", "tp_path", "ref_path"]
    result = []
    tp_files = sorted(Path(tp_dpo_dir).glob("*_dv.html"))

    for tp_path in tp_files:
        fname = tp_path.name
        ref_path = Path(dpo_ref_dir) / fname

        if not ref_path.exists():
            print(f"⚠️ Missing: {ref_path}")
            continue

        tp_dv = parse_dv_from_html(tp_path)
        ref_dv = parse_dv_from_html(ref_path)

        if len(tp_dv) != len(ref_dv):
            print(f"❌ Length mismatch in {fname}")
            continue

        if len(tp_dv) < min_length:
            continue

        # Softmax over the whole sequence turns the raw ΔV values into a
        # distribution across positions so the two files can be compared.
        tp_dist = softmax(tp_dv)
        ref_dist = softmax(ref_dv)

        result.append({
            "file": fname,
            "kl_divergence": kl_divergence(tp_dist, ref_dist),
            "length": len(tp_dv),
            "tp_path": str(tp_path),
            "ref_path": str(ref_path),
        })

    # Explicit columns keep the schema stable even when no file was scored;
    # without them an empty frame has no "kl_divergence" column and the
    # summary below would fail with a KeyError.
    df = pd.DataFrame(result, columns=columns)

    if df.empty:
        print(f"⚠️ No samples to score (len >= {min_length}) in {tp_dpo_dir}")
        return df

    # Print the mean KL
    mean_kl = df["kl_divergence"].mean()
    print(f"\n📊 Mean KL-divergence (len >= {min_length}): {mean_kl:.6f}\n")

    # Print the Top-K
    topk_df = df.sort_values("kl_divergence", ascending=False).head(topk)
    print(f"🔥 Top-{topk} samples by KL-divergence (len >= {min_length}):")
    print(topk_df[["file", "kl_divergence", "length"]])

    # Print the Bottom-K
    mink_df = df.sort_values("kl_divergence", ascending=True).head(mink)
    print(f"\n❄️ Bottom-{mink} samples by KL-divergence (len >= {min_length}):")
    print(mink_df[["file", "kl_divergence", "length"]])

    return df

# Example run: score the two visualization directories and dump the full table.
if __name__ == "__main__":
    df = compute_kl_scores(
        tp_dpo_dir="token_visualizations/advantage_TPKD",
        dpo_ref_dir="token_visualizations/advantage_teacher",
    )
    print(df)