#!/usr/bin/env python3
"""
token_entropy_highlight.py
—————————
• Calculates per‑token entropy.
• Picks spaced high‑entropy positions.
• <mark>Highlights</mark> those tokens in the output.
• Saves the raw text, the highlighted text, and N+1
  sequential slices (generated_step1 … generated_step{N+1}).
"""

import json
import numpy as np
from tqdm import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

from data_loader import load_data          # unchanged
from parser import parse_question          # unchanged


# ────────────────────────────────────────────────────────────
# Utility functions
# ────────────────────────────────────────────────────────────
def calculate_entropy(logprobs_dict, temperature: float = 1.0) -> float:
    """Return the Shannon entropy (in nats) of a set of candidate log-probs.

    The log-probabilities are divided by ``temperature`` and renormalised
    with a softmax over *only the supplied entries*, so the result is the
    entropy of that renormalised distribution.

    Args:
        logprobs_dict: Mapping whose values expose a ``.logprob`` attribute.
            The keys are not inspected.
        temperature: Positive scaling factor applied to the log-probs before
            normalisation.

    Returns:
        A non-negative entropy value; ``0.0`` for an empty mapping.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    logprobs = np.array(
        [lp.logprob for lp in logprobs_dict.values()], dtype=float
    )
    if logprobs.size == 0:
        return 0.0

    # Log-softmax: subtracting the max keeps exp() from overflowing.
    scaled = logprobs / temperature
    shifted = scaled - np.max(scaled)
    weights = np.exp(shifted)
    total = np.sum(weights)
    probs = weights / total
    log_probs = shifted - np.log(total)

    # Use exact log-probs instead of log(p + eps): the epsilon biases the
    # result and makes a one-candidate distribution slightly negative.
    # Zero-probability entries contribute nothing (0 * -inf would be nan).
    support = probs > 0
    entropy = -np.sum(probs[support] * log_probs[support])

    # Adding 0.0 turns a possible -0.0 into +0.0.
    return float(entropy) + 0.0


def select_spaced_high_entropy_positions(
    entropies: list[dict],
    min_spacing: int = 10,
    n_positions: int = 5
) -> list[dict]:
    """Pick up to ``n_positions`` entries, favouring high entropy with spacing.

    Two regimes are used:

    * Fewer than 50 entries: entries are taken at evenly spaced indices
      across the list (entropy values and ``min_spacing`` are not consulted).
    * Otherwise: entries are visited in descending ``'entropy'`` order and
      accepted greedily; accepting an entry blocks every position within
      ``min_spacing`` of its ``'position'``.

    Args:
        entropies: Entries in sequence order. The greedy regime reads the
            ``'entropy'`` and ``'position'`` keys of each entry.
        min_spacing: Minimum distance (exclusive) between accepted positions
            in the greedy regime.
        n_positions: Maximum number of entries to return.

    Returns:
        The chosen entries (the same dict objects), ordered by position.
        Empty when ``entropies`` is empty or ``n_positions`` is not positive.
    """
    # Guard up front so both regimes agree: previously the greedy branch
    # returned one entry for n_positions <= 0 while the short branch did not.
    if not entropies or n_positions <= 0:
        return []

    if len(entropies) < 50:
        count = min(n_positions, len(entropies))
        idxs = np.linspace(0, len(entropies) - 1, count, dtype=int)
        return [entropies[i] for i in idxs]

    ranked = sorted(entropies, key=lambda e: e['entropy'], reverse=True)
    selected: list[dict] = []
    blocked: set[int] = set()

    for entry in ranked:
        if len(selected) >= n_positions:
            break
        pos = entry['position']
        if pos in blocked:
            continue
        selected.append(entry)
        # Block the accepted position and its neighbourhood on both sides.
        blocked.update(range(pos - min_spacing, pos + min_spacing + 1))

    return sorted(selected, key=lambda e: e['position'])


# ────────────────────────────────────────────────────────────
# Main
# ────────────────────────────────────────────────────────────
def main():
    """Generate answers, score per-token entropy, and dump results to JSON.

    For every loaded example this builds a chat prompt, samples one
    completion, computes an entropy value for each generated token, selects
    a handful of positions with ``select_spaced_high_entropy_positions``,
    wraps the tokens at those positions in ``<mark>`` tags, and cuts the
    token sequence into consecutive slices at those positions. All results
    are written to ``tokenentropy_highlight.json`` and a short summary is
    printed.
    """
    # 1. SETUP ───────────────────────────────────────────────
    model_path = "Qwen/Qwen2.5-3B-Instruct"
    data_dir = "./data"
    num_samples = 10
    temperature = 0.7          # keeps some entropy variation

    system_prompt = (
        "Respond in the following format, with only the final answer between the <answer> tags "
        "and always put your answer in boxed:\n"
        "<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>"
    )

    print("Loading model & tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Fall back to the EOS token when the tokenizer defines no pad token.
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
    llm = LLM(model_path, tensor_parallel_size=1, max_logprobs=40)

    print("Loading MATH dataset …")
    examples = load_data("math", "test", data_dir)[:num_samples]

    # Sampling settings do not depend on the example, so build them once.
    params = SamplingParams(
        temperature=temperature,
        top_p=0.95,
        max_tokens=1024,
        logprobs=40,
        stop=["</s>"]
    )

    # 2. PROCESS LOOP ────────────────────────────────────────
    results = []
    # Report what was actually loaded; it may be fewer than num_samples.
    print(f"Processing {len(examples)} questions …")
    for example in tqdm(examples):
        question = parse_question(example, "math")

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question}
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        output = llm.generate([prompt], params)[0].outputs[0]

        token_ids = output.token_ids
        # Each token id is decoded on its own so tokens can be highlighted
        # and sliced individually.
        tokens = [tokenizer.decode([tid]) for tid in token_ids]

        # 2-A. ENTROPY PER TOKEN
        entropies = []
        if output.logprobs:
            for pos, lp_dict in enumerate(output.logprobs):
                if lp_dict is None:
                    continue
                entropies.append({
                    'position': pos,
                    'token': tokens[pos],
                    'token_id': token_ids[pos],
                    'entropy': calculate_entropy(lp_dict, temperature)
                })

        # 2-B. PICK POSITIONS
        spaced_high_entropy = select_spaced_high_entropy_positions(entropies)
        high_positions = {e['position'] for e in spaced_high_entropy}

        # 2-C. BUILD HIGHLIGHTED TEXT
        highlighted_text = "".join(
            f"<mark>{tok}</mark>" if i in high_positions else tok
            for i, tok in enumerate(tokens)
        )

        # 2-D. SPLIT INTO STEP-WISE SLICES
        # Every selected position starts a new slice. If position 0 is
        # selected, the first slice is the empty string.
        boundaries = [0] + sorted(high_positions) + [len(tokens)]
        step_slices: list[str] = [
            "".join(tokens[start:end])
            for start, end in zip(boundaries[:-1], boundaries[1:])
        ]

        # 2-E. BUILD RESULT ENTRY
        entropy_values = [e['entropy'] for e in entropies]
        result_entry = {
            'question_idx': example['idx'],
            'question': question[:200] + '...' if len(question) > 200 else question,
            'generated_text': output.text,              # raw
            'generated_text_highlighted': highlighted_text,
            'total_tokens': len(tokens),
            'spaced_high_entropy_positions': spaced_high_entropy,
            'average_entropy': float(np.mean(entropy_values)) if entropy_values else 0.0,
            'max_entropy': float(np.max(entropy_values)) if entropy_values else 0.0,
            'min_entropy': float(np.min(entropy_values)) if entropy_values else 0.0
        }

        # attach step-wise pieces as generated_step1, 2, …
        for i, piece in enumerate(step_slices, start=1):
            result_entry[f'generated_step{i}'] = piece

        results.append(result_entry)

    # 3. WRITE JSON ──────────────────────────────────────────
    output_data = {
        'model': model_path,
        'temperature': temperature,
        'num_samples': num_samples,
        'results': results
    }

    out_path = 'tokenentropy_highlight.json'
    # Explicit encoding so the file does not depend on the platform default.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2)

    # 4. SUMMARY ─────────────────────────────────────────────
    print("\n" + "=" * 60)
    print(f"Results saved to {out_path}")
    print(f"Processed {len(results)} questions")

    # np.max raises on an empty list, so skip the stats when nothing ran.
    if not results:
        print("\nNo results to summarise.")
        return

    avg_entropies = [r['average_entropy'] for r in results]
    max_entropies = [r['max_entropy'] for r in results]

    print("\nOverall entropy stats:")
    print(f"  Mean of average entropies: {np.mean(avg_entropies):.4f}")
    print(f"  Mean of max entropies: {np.mean(max_entropies):.4f}")
    print(f"  Highest single‑token entropy: {np.max(max_entropies):.4f}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
