import os
import csv
from pathlib import Path
from matplotlib import cm
import numpy as np

def normalize(values):
    """Map signed values onto [0, 1] symmetrically around 0.5.

    Each value is divided by the largest absolute value in the input, so
    zero lands on 0.5, the most extreme value lands on 0.0 or 1.0, and the
    sign of every value is preserved relative to the midpoint.

    Args:
        values: Iterable of numbers. It is materialized once, so one-shot
            iterables such as generators are accepted.

    Returns:
        A list of floats, one per input value. An empty input yields an
        empty list. If every magnitude is below 1e-5 the input is treated
        as all-zero and every entry is 0.5.
    """
    vals = list(values)
    if not vals:
        # max()/min() raise ValueError on an empty sequence.
        return []
    abs_max = max(abs(min(vals)), abs(max(vals)))
    if abs_max < 1e-5:
        # Avoid dividing by (almost) zero.
        return [0.5] * len(vals)
    return [0.5 + 0.5 * (v / abs_max) for v in vals]

def get_token_html(token: str, value: float, norm_value: float, label: str, cmap_name: str = "RdBu") -> str:
    """Render one token as a colored ``<span>`` with a hover tooltip.

    Args:
        token: Raw token text. "Ġ" is shown as a space and "Ċ" as the
            two-character text ``\\n`` (presumably tokenizer markers for
            space/newline; confirm against the tokenizer in use).
        value: Number shown in the tooltip, formatted to four decimals.
        norm_value: Position in the colormap used for the background color
            (callers in this file pass values in [0, 1]).
        label: Name shown before the value in the tooltip.
        cmap_name: Name of the matplotlib colormap to sample.

    Returns:
        The HTML for the span. Token text and label are HTML-escaped so
        characters such as ``<``, ``&`` or ``"`` cannot break the markup.
    """
    import html  # local import keeps this block self-contained

    # cm.get_cmap is absent from newer matplotlib releases; fall back to the
    # colormap registry when the legacy accessor does not exist.
    legacy_get_cmap = getattr(cm, "get_cmap", None)
    if legacy_get_cmap is None:
        import matplotlib
        cmap = matplotlib.colormaps[cmap_name]
    else:
        cmap = legacy_get_cmap(cmap_name)

    rgba = cmap(norm_value)
    rgb = np.array(rgba[:3]) * 255
    bg_color = f"rgb({int(rgb[0])},{int(rgb[1])},{int(rgb[2])})"
    # Weighted channel sum as a brightness estimate; dark backgrounds get
    # white text so the token stays readable.
    luminance = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]
    text_color = "white" if luminance < 140 else "black"
    token = token.replace("Ġ", " ").replace("Ċ", "\\n")
    # Escape after the marker substitution so the displayed text is what
    # gets escaped.
    safe_token = html.escape(token)
    safe_label = html.escape(label, quote=True)
    return f"""<span style="background-color:{bg_color}; color:{text_color}; padding:2px 4px; margin:1px;
                border-radius:4px; font-family:monospace"
                title="{safe_label}={value:.4f}">{safe_token}</span>"""

def save_diff_html(tokens, diffs, output_path, label="Δ", title="Δ Value (TPKD − DPO)"):
    """Write an HTML page showing each token colored by its difference value.

    Differences are normalized with ``normalize`` and rendered through
    ``get_token_html`` using the "RdBu" colormap.

    Args:
        tokens: Token strings, in display order.
        diffs: One numeric difference per token. If the lengths differ,
            the extra items of the longer input are ignored.
        output_path: Destination file; it is overwritten if it exists.
        label: Tooltip label passed to ``get_token_html``.
        title: Heading shown above the tokens (HTML-escaped before use).
    """
    import html as html_lib  # local import keeps this block self-contained

    diffs = list(diffs)
    # Skip normalization for empty input so an empty page is written
    # instead of failing on max()/min() of an empty sequence.
    norm_diffs = normalize(diffs) if diffs else []
    spans = [
        get_token_html(tok, diff, norm, label, "RdBu")
        for tok, diff, norm in zip(tokens, diffs, norm_diffs)
    ]
    body = " ".join(spans)
    safe_title = html_lib.escape(title, quote=False)
    # Declare the charset: the file is written as UTF-8 and contains
    # non-ASCII text, which a browser may otherwise decode incorrectly.
    page = f"""
    <html><head><meta charset="utf-8"></head><body>
    <div style='font-size: 16px;'><b>{safe_title}</b></div>
    <div style='font-size: 16px;'>{body}</div>
    </body></html>
    """
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(page)

def load_csv_data(csv_path):
    """Read the token, logp and delta_v columns from a CSV file.

    Args:
        csv_path: Path to a UTF-8 CSV file whose header row includes the
            columns ``token``, ``logp`` and ``delta_v``.

    Returns:
        A tuple ``(tokens, logps, dvs)`` of three parallel lists: token
        strings, ``logp`` values as floats and ``delta_v`` values as floats.

    Raises:
        KeyError: If a required column is missing.
        ValueError: If a ``logp`` or ``delta_v`` cell is not a valid float.
    """
    tokens, logps, dvs = [], [], []
    # newline="" lets the csv module handle line endings itself, so "\r\n"
    # or "\r" inside a quoted token is not rewritten to "\n".
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            tokens.append(row["token"])
            logps.append(float(row["logp"]))
            dvs.append(float(row["delta_v"]))
    return tokens, logps, dvs

def compare_and_generate_html(tp_dir, dpo_dir, output_dir):
    """Generate per-token difference pages for matching CSV pairs.

    For every ``*_*.csv`` file in ``tp_dir`` whose name ends with
    ``chosen.csv`` or ``rejected.csv``, the file of the same name in
    ``dpo_dir`` is loaded. The element-wise differences (``tp_dir`` value
    minus ``dpo_dir`` value) of ``delta_v`` and ``logp`` are written to
    ``<stem>_dv.html`` and ``<stem>_logp.html`` in ``output_dir``.

    A pair is skipped with a printed message when the counterpart file is
    missing, when the token sequences differ, or when the files contain no
    rows.

    Args:
        tp_dir: Directory with the first set of CSVs (labelled TPKD in the
            generated titles).
        dpo_dir: Directory with the second set of CSVs (labelled DPO).
        output_dir: Directory for the HTML output; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    out_root = Path(output_dir)
    dpo_root = Path(dpo_dir)

    # The directory listing does not depend on the kind, so scan it once.
    candidates = sorted(Path(tp_dir).glob("*_*.csv"))

    for kind in ("chosen", "rejected"):
        suffix = f"{kind}.csv"

        for tp_csv in candidates:
            fname = tp_csv.name
            if not fname.endswith(suffix):
                continue

            dpo_csv = dpo_root / fname
            if not dpo_csv.exists():
                print(f"⚠️ {dpo_csv} not found.")
                continue

            tokens1, logps1, dvs1 = load_csv_data(tp_csv)
            tokens2, logps2, dvs2 = load_csv_data(dpo_csv)

            if tokens1 != tokens2:
                print(f"❌ Token mismatch in {fname}")
                continue

            if not tokens1:
                # Nothing to visualize; also avoids failing on the
                # max()/min() of an empty difference list downstream.
                print(f"⚠️ {fname} has no rows; skipped.")
                continue

            # Use the stem so only the extension is dropped; str.replace
            # would also remove ".csv" occurring elsewhere in the name.
            base = tp_csv.stem
            out_dv = out_root / f"{base}_dv.html"
            out_logp = out_root / f"{base}_logp.html"

            dv_diffs = [a - b for a, b in zip(dvs1, dvs2)]
            logp_diffs = [a - b for a, b in zip(logps1, logps2)]

            save_diff_html(tokens1, dv_diffs, out_dv, label="ΔV", title=f"Δ Soft Value (TPKD − DPO) [{kind}]")
            save_diff_html(tokens1, logp_diffs, out_logp, label="Δlogp", title=f"Δ LogProb (TPKD − DPO) [{kind}]")
            print(f"✅ Saved: {out_dv.name}, {out_logp.name}")

if __name__ == "__main__":
    # Script entry point: compare the TPKD and DPO visualization CSVs and
    # write the difference pages next to them.
    compare_and_generate_html(
        tp_dir="token_visualizations/TPKD",
        dpo_dir="token_visualizations/DPO",
        output_dir="token_visualizations/TPKD-DPO",
    )