import gc
import html
import os

import numpy as np
import torch
from matplotlib import cm
from transformers import AutoModelForCausalLM, AutoTokenizer

def normalize(values):
    """Min-max scale *values* into the range [0, 1].

    Args:
        values: Any iterable of numbers (it is materialized once, so
            generators are accepted).

    Returns:
        A list of floats, one per input value. An empty input gives an empty
        list. When the spread ``max - min`` is below 1e-5, every value maps
        to 0.5 (the colormap midpoint) instead of dividing by ~zero.
    """
    values = list(values)
    if not values:
        return []
    vmin, vmax = min(values), max(values)
    span = vmax - vmin
    if span < 1e-5:
        return [0.5] * len(values)
    return [(v - vmin) / span for v in values]

def get_token_html(token: str, value: float, norm_value: float, label: str, cmap_name: str = "Reds") -> str:
    """Render one token as a colored, HTML-escaped ``<span>``.

    Args:
        token: Raw tokenizer token. Tokenizer markers are made readable first:
            ``Ġ`` and ``▁`` become a space, and ``Ċ`` is shown as a literal
            backslash-n.
        value: Raw score, shown in the hover tooltip with 4 decimals.
        norm_value: Score scaled to [0, 1]; it selects the background color.
        label: Metric name shown in the tooltip (e.g. ``"logp"``).
        cmap_name: Name of a registered matplotlib colormap.

    Returns:
        The ``<span>`` markup as a string.
    """
    try:
        from matplotlib import colormaps  # registry API, matplotlib >= 3.5
        cmap = colormaps[cmap_name]
    except ImportError:
        cmap = cm.get_cmap(cmap_name)  # older matplotlib (removed in 3.9)
    rgb = np.array(cmap(norm_value)[:3]) * 255
    bg_color = f"rgb({int(rgb[0])},{int(rgb[1])},{int(rgb[2])})"
    # Weighted brightness of the background decides a readable text color.
    luminance = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]
    text_color = "white" if luminance < 140 else "black"
    token = token.replace("Ġ", " ").replace("▁", " ").replace("Ċ", "\\n")
    # Escape after the marker substitution so tokens such as "<", "&" or '"'
    # cannot break the surrounding markup or the title attribute.
    token = html.escape(token)
    safe_label = html.escape(label)
    return f"""<span style="background-color:{bg_color}; color:{text_color}; padding:2px 4px; margin:1px;
                border-radius:4px; font-family:monospace"
                title="{safe_label}={value:.4f}">{token}</span>"""

def compute_soft_values(logits, beta=1.0):
    """Return the soft (log-sum-exp) value of each row of *logits*.

    For every row t: ``V_t = beta * logsumexp(logits[t] / beta)``.

    Args:
        logits: Tensor whose last dimension is reduced. Expected to be 2-D
            (positions x vocabulary) so that one float per position comes back.
        beta: Temperature; must be positive.

    Returns:
        A list of Python floats, one per row of ``logits``.

    Raises:
        ValueError: If ``beta`` is not positive (``beta == 0`` would divide
            by zero).
    """
    if beta <= 0:
        raise ValueError(f"beta must be positive, got {beta}")
    # Upcast first. float16 has only ~3 significant decimal digits, which
    # would swamp the small differences between consecutive V_t that callers
    # take.
    scaled = logits.float() / beta
    return (beta * torch.logsumexp(scaled, dim=-1)).tolist()

def _common_prefix_len(a, b):
    """Return how many leading elements the sequences *a* and *b* share."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def analyze_model(model_path, question, answer, beta=1.0):
    """Score the answer tokens under one model and render them as two HTML rows.

    The model is run on ``"<question> A: <answer>"`` (both parts stripped).
    Two scores are reported for every answer token:
      * logp: the log-probability the model assigned to that token;
      * ΔV:   ``V_t - V_{t-1}``, where ``V_t`` is ``compute_soft_values``
              applied to the logits at position t.

    Args:
        model_path: Path or id passed to ``from_pretrained``.
        question: Question text.
        answer: Answer text whose tokens are visualized.
        beta: Temperature forwarded to ``compute_soft_values``.

    Returns:
        ``(logp_html, delta_html)``: space-joined token ``<span>`` strings
        colored with the "Reds" and "Blues" colormaps respectively.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).eval()

    question = question.strip()
    prompt = question + " A: " + answer.strip()
    # Place the inputs on the model's (first) device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    full_ids = inputs.input_ids[0].tolist()

    # The answer begins where the tokenization of "<question> A:" stops
    # matching the full prompt's tokenization. Using the common prefix stays
    # correct even if a tokenizer merges tokens across the boundary.
    # Position 0 has no predecessor and therefore no logp, so clamp to >= 1.
    prefix_ids = tokenizer(question + " A:").input_ids
    answer_start = max(_common_prefix_len(full_ids, prefix_ids), 1)

    with torch.no_grad():
        outputs = model(**inputs)
    # Upcast before log_softmax / logsumexp; float16 would lose precision.
    logits = outputs.logits[0].float().cpu()

    # Release this model's GPU memory before the caller loads the next one.
    del outputs, model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Row t-1 of the logits scores the token at position t.
    ids = torch.tensor(full_ids)
    log_probs = torch.log_softmax(logits, dim=-1)
    token_logps = log_probs[:-1].gather(1, ids[1:].unsqueeze(1)).squeeze(1).tolist()

    soft_vs = compute_soft_values(logits, beta=beta)
    delta_vs = [cur - prev for prev, cur in zip(soft_vs, soft_vs[1:])]

    # tokens, token_logps and delta_vs are all indexed by (position - 1).
    tokens = tokenizer.convert_ids_to_tokens(full_ids[1:])
    offset = answer_start - 1
    tokens_answer = tokens[offset:]
    token_logps_answer = token_logps[offset:]
    delta_vs_answer = delta_vs[offset:]

    norm_logp = normalize(token_logps_answer)
    norm_dv = normalize(delta_vs_answer)

    logp_row = [
        get_token_html(tok, val, norm, "logp", "Reds")
        for tok, val, norm in zip(tokens_answer, token_logps_answer, norm_logp)
    ]
    delta_row = [
        get_token_html(tok, val, norm, "ΔV", "Blues")
        for tok, val, norm in zip(tokens_answer, delta_vs_answer, norm_dv)
    ]
    return " ".join(logp_row), " ".join(delta_row)

def main():
    """Compare several models on one Q/A pair and write an HTML report.

    Each model in ``model_paths`` is scored with ``analyze_model``. One table
    row per model (its logp row and its ΔV row) is written to
    ``output_file`` as UTF-8 HTML.
    """
    # analyze_model strips both strings, so surrounding whitespace is irrelevant.
    question = "Who was Marie Curie and what was her contribution to science?"
    answer = (
        "Marie Curie was a physicist and chemist who discovered the elements "
        "polonium and radium, and conducted pioneering research on radioactivity."
    )
    beta = 1.0
    model_paths = {
        "TPKD": "./model/TPKD_entropy",
        "DPO-long": "./model/dpo_long",
        "DPO-teacher": "./model/dpo_teacher",
        "Ref-teacher": "./model/ref_teacher",
    }
    output_file = "token_vis_custom/curie_compare_token_vis.html"

    row_parts = []
    for name, path in model_paths.items():
        print(f"🔍 Processing {name} ...")
        logp_html, delta_html = analyze_model(path, question, answer, beta)
        row_parts.append(f"""
        <tr>
            <td style="vertical-align:top; font-weight:bold">{name}</td>
            <td>{logp_html}</td>
            <td>{delta_html}</td>
        </tr>
        """)
    rows = "".join(row_parts)

    # The explicit charset makes browsers decode the UTF-8 file correctly
    # (Δ, subscript t and the token text are non-ASCII).
    html_output = f"""
    <html><head>
    <meta charset="utf-8">
    <style>
        table {{ border-collapse: collapse; width: 100%; }}
        th, td {{ border: 1px solid #999; padding: 8px; text-align: left; }}
        th {{ background-color: #eee; }}
    </style>
    </head><body>
    <h2>Token-level Visualization Across Models</h2>
    <table>
        <tr>
            <th>Model</th>
            <th>Token LogProb (logp)</th>
            <th>Δ Soft Value (ΔVₜ = Vₜ - Vₜ₋₁)</th>
        </tr>
        {rows}
    </table>
    </body></html>
    """

    out_dir = os.path.dirname(output_file)
    if out_dir:  # os.makedirs("") raises, so skip when writing to the CWD
        os.makedirs(out_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(html_output)
    print(f"✅ All done. Saved to: {output_file}")

# Run the multi-model comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()