import argparse
import csv
import html
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from matplotlib import cm
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Disable LaTeX text rendering and make mathtext use the regular text font.
# NOTE(review): as far as this file shows, matplotlib is only used for colormap
# lookups, so these settings currently have no visible effect.
mpl.rcParams['text.usetex'] = False
mpl.rcParams['mathtext.default'] = 'regular'

def normalize(values):
    """Map values onto [0, 1] symmetrically around 0.5 for colormap lookup.

    Every value is divided by the largest absolute value in the input, so 0
    maps to 0.5, ``-abs_max`` to 0.0 and ``+abs_max`` to 1.0.

    Args:
        values: iterable of real numbers.

    Returns:
        A list with one normalized value per input. All entries are 0.5 when
        every magnitude is below 1e-5. An empty input gives an empty list.
    """
    values = list(values)
    if not values:
        # max() of an empty sequence raises; there is simply nothing to scale.
        return []
    abs_max = max(abs(v) for v in values)
    if abs_max < 1e-5:
        # Avoid dividing by (near) zero; treat the whole row as neutral.
        return [0.5] * len(values)
    return [0.5 + 0.5 * (v / abs_max) for v in values]

def compute_soft_values(logits, beta=1.0):
    """Return the soft (log-sum-exp) value of every position in ``logits``.

    For each position ``t`` this computes
    ``V_t = beta * logsumexp(logits[t] / beta)`` over the last dimension.

    Args:
        logits: tensor whose first dimension indexes positions and whose last
            dimension is reduced, e.g. ``(seq_len, vocab)``.
        beta: non-zero temperature.

    Returns:
        A ``list[float]`` with one value per position.

    Raises:
        ValueError: if ``beta`` is zero (the division would yield inf/NaN).
    """
    if beta == 0:
        raise ValueError("beta must be non-zero")
    # A half-precision logsumexp result can be too coarse to resolve the
    # small V_t - V_{t-1} differences computed downstream, so upcast first.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()
    # One vectorized reduction and one device->host transfer, instead of a
    # .item() synchronization per position.
    return (beta * torch.logsumexp(logits / beta, dim=-1)).tolist()

def get_token_html(token: str, value: float, norm_value: float, label: str, cmap_name: str = "Reds") -> str:
    """Render one token as a colored, hover-annotated HTML ``<span>``.

    Args:
        token: raw tokenizer token. Word/space markers are made readable:
            "Ġ" and "▁" become a space, and "Ċ" becomes a visible ``\\n``.
        value: raw score shown in the tooltip as ``label=value``.
        norm_value: number in [0, 1] used to pick the background color.
        label: score name shown in the tooltip.
        cmap_name: name of a registered matplotlib colormap.

    Returns:
        An HTML snippet in which the token text and tooltip are escaped.
    """
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; use the registry.
    rgba = mpl.colormaps[cmap_name](norm_value)
    rgb = np.array(rgba[:3]) * 255
    bg_color = f"rgb({int(rgb[0])},{int(rgb[1])},{int(rgb[2])})"
    # Weighted RGB luminance decides whether light or dark text is readable.
    luminance = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]
    text_color = "white" if luminance < 140 else "black"
    # Byte-level BPE (GPT-2 style) spells space as "Ġ" and newline as "Ċ";
    # SentencePiece marks word starts with "▁".
    display = token.replace("Ġ", " ").replace("▁", " ").replace("Ċ", "\\n")
    # Tokens can contain "<", ">", "&" or quotes; escape them so they render as text.
    title = html.escape(f"{label}={value:.4f}", quote=True)
    return f"""<span style="background-color:{bg_color}; color:{text_color}; padding:2px 4px; margin:1px;\n                border-radius:4px; font-family:monospace"\n                title="{title}">{html.escape(display)}</span>"""

def save_html(tokens, logps, dvs, out_path):
    """Write an HTML page showing one answer's tokens colored by two scores.

    The first row colors each token by its log-probability. The second row
    colors it by its soft-value change ``ΔV_t = V_t - V_{t-1}``. Both rows use
    the "Reds" colormap after ``normalize`` and show the raw value on hover.

    Args:
        tokens: token strings, one per position.
        logps: per-token log-probabilities aligned with ``tokens``.
        dvs: per-token soft-value deltas aligned with ``tokens``.
        out_path: destination file; it is overwritten and written as UTF-8.
    """
    # normalize() takes max/min of its input, so do not call it on empty rows.
    norm_logp = normalize(logps) if len(logps) else []
    norm_dv = normalize(dvs) if len(dvs) else []

    logp_row = " ".join(
        get_token_html(tok, val, norm, "logp", "Reds")
        for tok, val, norm in zip(tokens, logps, norm_logp)
    )
    dv_row = " ".join(
        get_token_html(tok, val, norm, "ΔV", "Reds")
        for tok, val, norm in zip(tokens, dvs, norm_dv)
    )

    # The page contains non-ASCII text (Δ, subscripts, token text) and is
    # written as UTF-8. Declare that, or browsers may guess a legacy charset
    # for local files and garble it.
    html_output = f"""
    <html><head><meta charset="utf-8"></head><body>
    <div style='font-size: 16px;'><b>Token LogProb</b></div>
    <div style='font-size: 16px; margin-bottom: 16px;'>{logp_row}</div>
    <div style='font-size: 16px;'><b>Δ Soft Value (ΔVₜ = Vₜ - Vₜ₋₁)</b></div>
    <div style='font-size: 16px;'>{dv_row}</div>
    </body></html>
    """

    with open(out_path, "w", encoding="utf-8") as f:
        f.write(html_output)

def save_csv(tokens, logps, dvs, sample_id, kind, output_dir):
    """Write per-token scores for one answer as ``token_vis_<id>_<kind>.csv``.

    The file is created in ``output_dir`` (``sample_id`` zero-padded to three
    digits). It has the header ``token_index, token, logp, delta_v`` and one
    row per position, stopping at the shortest of ``tokens``, ``logps`` and
    ``dvs``.
    """
    target = os.path.join(output_dir, f"token_vis_{sample_id:03d}_{kind}.csv")
    with open(target, "w", encoding="utf-8", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(["token_index", "token", "logp", "delta_v"])
        out.writerows(
            [position, token, logp, delta_v]
            for position, (token, logp, delta_v) in enumerate(zip(tokens, logps, dvs))
        )

import os
import torch
import argparse
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from matplotlib import cm
import matplotlib.pyplot as plt
import matplotlib as mpl
import csv
from tqdm.auto import tqdm

# Disable LaTeX text rendering and make mathtext use the regular text font.
# NOTE(review): as far as this file shows, matplotlib is only used for colormap
# lookups, so these settings currently have no visible effect.
mpl.rcParams['text.usetex'] = False
mpl.rcParams['mathtext.default'] = 'regular'

def normalize(values):
    """Map values onto [0, 1] symmetrically around 0.5 for colormap lookup.

    Every value is divided by the largest absolute value in the input, so 0
    maps to 0.5, ``-abs_max`` to 0.0 and ``+abs_max`` to 1.0.

    Args:
        values: iterable of real numbers.

    Returns:
        A list with one normalized value per input. All entries are 0.5 when
        every magnitude is below 1e-5. An empty input gives an empty list.
    """
    values = list(values)
    if not values:
        # max() of an empty sequence raises; there is simply nothing to scale.
        return []
    abs_max = max(abs(v) for v in values)
    if abs_max < 1e-5:
        # Avoid dividing by (near) zero; treat the whole row as neutral.
        return [0.5] * len(values)
    return [0.5 + 0.5 * (v / abs_max) for v in values]

def compute_soft_values(logits, beta=1.0):
    """Return the soft (log-sum-exp) value of every position in ``logits``.

    For each position ``t`` this computes
    ``V_t = beta * logsumexp(logits[t] / beta)`` over the last dimension.

    Args:
        logits: tensor whose first dimension indexes positions and whose last
            dimension is reduced, e.g. ``(seq_len, vocab)``.
        beta: non-zero temperature.

    Returns:
        A ``list[float]`` with one value per position.

    Raises:
        ValueError: if ``beta`` is zero (the division would yield inf/NaN).
    """
    if beta == 0:
        raise ValueError("beta must be non-zero")
    # A half-precision logsumexp result can be too coarse to resolve the
    # small V_t - V_{t-1} differences computed downstream, so upcast first.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()
    # One vectorized reduction and one device->host transfer, instead of a
    # .item() synchronization per position.
    return (beta * torch.logsumexp(logits / beta, dim=-1)).tolist()

def get_token_html(token: str, value: float, norm_value: float, label: str, cmap_name: str = "Reds") -> str:
    """Render one token as a colored, hover-annotated HTML ``<span>``.

    Args:
        token: raw tokenizer token. Word/space markers are made readable:
            "Ġ" and "▁" become a space, and "Ċ" becomes a visible ``\\n``.
        value: raw score shown in the tooltip as ``label=value``.
        norm_value: number in [0, 1] used to pick the background color.
        label: score name shown in the tooltip.
        cmap_name: name of a registered matplotlib colormap.

    Returns:
        An HTML snippet in which the token text and tooltip are escaped.
    """
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; use the registry.
    rgba = mpl.colormaps[cmap_name](norm_value)
    rgb = np.array(rgba[:3]) * 255
    bg_color = f"rgb({int(rgb[0])},{int(rgb[1])},{int(rgb[2])})"
    # Weighted RGB luminance decides whether light or dark text is readable.
    luminance = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]
    text_color = "white" if luminance < 140 else "black"
    # Byte-level BPE (GPT-2 style) spells space as "Ġ" and newline as "Ċ";
    # SentencePiece marks word starts with "▁".
    display = token.replace("Ġ", " ").replace("▁", " ").replace("Ċ", "\\n")
    # Tokens can contain "<", ">", "&" or quotes; escape them so they render as text.
    title = html.escape(f"{label}={value:.4f}", quote=True)
    return f"""<span style="background-color:{bg_color}; color:{text_color}; padding:2px 4px; margin:1px;\n                border-radius:4px; font-family:monospace"\n                title="{title}">{html.escape(display)}</span>"""

def save_html(tokens, logps, dvs, out_path):
    """Write an HTML page showing one answer's tokens colored by two scores.

    The first row colors each token by its log-probability. The second row
    colors it by its soft-value change ``ΔV_t = V_t - V_{t-1}``. Both rows use
    the "Reds" colormap after ``normalize`` and show the raw value on hover.

    Args:
        tokens: token strings, one per position.
        logps: per-token log-probabilities aligned with ``tokens``.
        dvs: per-token soft-value deltas aligned with ``tokens``.
        out_path: destination file; it is overwritten and written as UTF-8.
    """
    # normalize() takes max/min of its input, so do not call it on empty rows.
    norm_logp = normalize(logps) if len(logps) else []
    norm_dv = normalize(dvs) if len(dvs) else []

    logp_row = " ".join(
        get_token_html(tok, val, norm, "logp", "Reds")
        for tok, val, norm in zip(tokens, logps, norm_logp)
    )
    dv_row = " ".join(
        get_token_html(tok, val, norm, "ΔV", "Reds")
        for tok, val, norm in zip(tokens, dvs, norm_dv)
    )

    # The page contains non-ASCII text (Δ, subscripts, token text) and is
    # written as UTF-8. Declare that, or browsers may guess a legacy charset
    # for local files and garble it.
    html_output = f"""
    <html><head><meta charset="utf-8"></head><body>
    <div style='font-size: 16px;'><b>Token LogProb</b></div>
    <div style='font-size: 16px; margin-bottom: 16px;'>{logp_row}</div>
    <div style='font-size: 16px;'><b>Δ Soft Value (ΔVₜ = Vₜ - Vₜ₋₁)</b></div>
    <div style='font-size: 16px;'>{dv_row}</div>
    </body></html>
    """

    with open(out_path, "w", encoding="utf-8") as f:
        f.write(html_output)

def save_csv(tokens, logps, dvs, sample_id, kind, output_dir):
    """Write per-token scores for one answer as ``token_vis_<id>_<kind>.csv``.

    The file is created in ``output_dir`` (``sample_id`` zero-padded to three
    digits). It has the header ``token_index, token, logp, delta_v`` and one
    row per position, stopping at the shortest of ``tokens``, ``logps`` and
    ``dvs``.
    """
    target = os.path.join(output_dir, f"token_vis_{sample_id:03d}_{kind}.csv")
    with open(target, "w", encoding="utf-8", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(["token_index", "token", "logp", "delta_v"])
        out.writerows(
            [position, token, logp, delta_v]
            for position, (token, logp, delta_v) in enumerate(zip(tokens, logps, dvs))
        )

def process_batch(batch_indices, batch_samples, model, tokenizer, output_dir, beta=1.0):
    """Score the chosen and rejected answers of a batch of preference samples.

    For each sample and each of its ``"chosen"`` / ``"rejected"`` dialogs, the
    first user turn and first assistant turn are joined as ``"<q> A: <a>"``.
    Dialogs where either turn exceeds 700 characters are skipped. One forward
    pass scores all prompts. For every prompt, this function:

    * computes ``vt``, the soft value ``beta * logsumexp(logits / beta)`` at
      the last real (non-padding) position;
    * computes ``avg_logp``, the mean teacher-forced log-probability of every
      real token after the first;
    * writes HTML and CSV visualizations of the answer tokens to ``output_dir``.

    Args:
        batch_indices: dataset indices aligned with ``batch_samples``; used as
            sample ids in output file names.
        batch_samples: records whose ``"chosen"`` and ``"rejected"`` entries
            are lists of ``{"role", "content"}`` turns.
        model: causal LM whose output exposes ``.logits``.
        tokenizer: tokenizer for ``model``; it must have a pad token.
        output_dir: directory that receives the visualization files.
        beta: temperature of the soft value.

    Returns:
        ``(vt_stats, logp_stats, vt_collection)``:

        * ``vt_stats`` and ``logp_stats`` are ``{"match", "total"}`` counts,
          over samples where both answers were scored, of how often the chosen
          answer beats the rejected one.
        * ``vt_collection`` lists each scored answer's ``vt``, keyed by kind.
    """
    prompts, kinds, answer_starts = [], [], []
    idx_map = {}
    vt_stats = {"match": 0, "total": 0}
    logp_stats = {"match": 0, "total": 0}
    vt_collection = {"chosen": [], "rejected": []}

    for idx, sample in zip(batch_indices, batch_samples):
        idx_map[idx] = {}
        for kind in ("chosen", "rejected"):
            dialog = sample[kind]
            # Only the first user turn and the first assistant turn are used.
            q = next((t["content"] for t in dialog if t["role"] == "user"), "")
            a = next((t["content"] for t in dialog if t["role"] == "assistant"), "")
            if len(q) > 700 or len(a) > 700:
                continue
            prefix = q.strip() + " A:"
            prompts.append(prefix + " " + a.strip())
            kinds.append((idx, kind))
            # Token count of everything before the answer. Assumes the prefix
            # tokenizes the same way inside the full prompt (no merge across
            # the boundary) -- TODO confirm for the tokenizer in use.
            answer_starts.append(len(tokenizer(prefix).input_ids))

    if not prompts:
        return vt_stats, logp_stats, vt_collection

    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # Pure inference: without no_grad the forward pass records an autograd
    # graph and keeps every activation alive while the logits are in use.
    with torch.no_grad():
        logits = model(**inputs).logits

    for i, (idx, kind) in enumerate(kinds):
        # Keep only real tokens (works for left or right padding). Reading the
        # padded row directly would take V_T at a pad position and average
        # pad-token log-probs into avg_logp.
        real = attention_mask[i].bool()
        seq = input_ids[i][real]
        seq_logits = logits[i][real].float()
        if seq.numel() < 2:
            continue  # no next-token prediction to score

        # log p(seq[t] | seq[:t]) for t = 1 .. len(seq) - 1, gathered in one op.
        logp = torch.log_softmax(seq_logits, dim=-1)
        token_logps = logp[:-1].gather(-1, seq[1:].unsqueeze(-1)).squeeze(-1).tolist()
        avg_logp = sum(token_logps) / len(token_logps)

        soft_vs = compute_soft_values(seq_logits, beta)
        vt = soft_vs[-1]
        idx_map[idx][kind] = {"vt": vt, "avg_logp": avg_logp}
        vt_collection[kind].append(vt)

        # Entry k of tokens / token_logps / delta_vs describes seq[k + 1], so the
        # answer (which starts at seq[answer_start]) starts at k = answer_start - 1.
        delta_vs = [soft_vs[t] - soft_vs[t - 1] for t in range(1, len(soft_vs))]
        tokens = tokenizer.convert_ids_to_tokens(seq[1:].tolist())
        a_start = max(answer_starts[i] - 1, 0)
        tokens_answer = tokens[a_start:]
        if not tokens_answer:
            continue  # answer truncated away; nothing to visualize
        token_logps_answer = token_logps[a_start:]
        delta_vs_answer = delta_vs[a_start:]

        html_path = os.path.join(output_dir, f"token_vis_{idx:03d}_{kind}.html")
        save_html(tokens_answer, token_logps_answer, delta_vs_answer, html_path)
        save_csv(tokens_answer, token_logps_answer, delta_vs_answer, idx, kind, output_dir)

    # Only samples where both answers were scored count toward accuracy.
    for scores in idx_map.values():
        if "chosen" not in scores or "rejected" not in scores:
            continue
        chosen, rejected = scores["chosen"], scores["rejected"]
        vt_stats["total"] += 1
        logp_stats["total"] += 1
        if chosen["vt"] > rejected["vt"]:
            vt_stats["match"] += 1
        if chosen["avg_logp"] > rejected["avg_logp"]:
            logp_stats["match"] += 1

    return vt_stats, logp_stats, vt_collection


def main():
    """Compare V_T and average log-prob as preference predictors on dpo-mix-7k.

    Steps:

    * Parse CLI arguments and load the model and tokenizer.
    * Load the ``test`` split of ``argilla/dpo-mix-7k``.
    * Score at most ``--max_samples`` samples in batches of ``--batch_size``,
      writing per-answer visualizations under
      ``./token_visualizations/<output_dir>``.
    * Print how often the chosen answer outscores the rejected one under
      each metric.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=16)
    args = parser.parse_args()
    # range() rejects a zero step; beta == 0 would make every soft value NaN.
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")
    if args.beta == 0:
        parser.error("--beta must be non-zero")

    output_dir = os.path.join("./token_visualizations", args.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).eval()

    dataset = load_dataset("argilla/dpo-mix-7k", split="test")
    # Cap every batch at this limit so the last batch never runs past --max_samples.
    limit = min(len(dataset), args.max_samples)

    total_vt, match_vt = 0, 0
    total_logp, match_logp = 0, 0

    for start in tqdm(range(0, limit, args.batch_size)):
        stop = min(start + args.batch_size, limit)
        batch_samples = [dataset[j] for j in range(start, stop)]
        vt_stats, logp_stats, _ = process_batch(
            range(start, stop), batch_samples, model, tokenizer, output_dir, beta=args.beta
        )

        total_vt += vt_stats["total"]
        match_vt += vt_stats["match"]

        total_logp += logp_stats["total"]
        match_logp += logp_stats["match"]

    if total_vt > 0:
        print(f"[V_T Accuracy] Match: {match_vt}/{total_vt} = {match_vt / total_vt * 100:.2f}%")
    else:
        print("No valid V_T comparisons found.")

    if total_logp > 0:
        print(f"[Avg LogP Accuracy] Match: {match_logp}/{total_logp} = {match_logp / total_logp * 100:.2f}%")
    else:
        print("No valid LogP comparisons found.")


if __name__ == "__main__":
    main()