#!/usr/bin/env python3
"""Phase 0: Novelty Signal Analysis for GSPO Novelty-on-SFT Reward.

Analyzes lm_eval pass@k JSONL outputs to determine whether there is meaningful
variance in SFT perplexity among correct puzzle rollouts — the prerequisite
for the novelty bonus to have signal during GSPO training.

Two stages:
  Stage 1 (fast): Correctness scoring + diversity metrics (lengths, reasoning text)
  Stage 2 (GPU): SFT perplexity computation via vLLM prompt_logprobs

Usage:
  # Stage 1 only (no GPU needed):
  python scripts/analysis/novelty_signal_analysis.py \
      --results_dir results/novelty_signal/sft_v2 \
      --stage 1

  # Both stages:
  python scripts/analysis/novelty_signal_analysis.py \
      --results_dir results/novelty_signal/sft_v2 \
      --sft_model checkpoints/olmo3_7b_multi_puzzle_dsr_v2/merged_ep5_fp32 \
      --stage 2
"""

import argparse
import json
import os
import re
import sys
from collections import defaultdict
from pathlib import Path

import numpy as np

# Put the project root (three directory levels above this file) on sys.path
# so that project-local modules can be imported when this file runs as a script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

# `scoring` is imported as a top-level module, so the directory that contains
# it must itself be on sys.path before the import below runs.
sys.path.insert(0, str(PROJECT_ROOT / "evaluate" / "custom_tasks" / "puzzle" / "_common"))
from scoring import make_scorer, _normalize_solution


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def detect_puzzle_type(task_name: str) -> str:
    """Map an lm_eval task name to a known puzzle family.

    The task name is lower-cased and checked for each known family name as a
    substring, in a fixed priority order; the first hit wins. Names matching
    no family yield "generic".
    """
    lowered = task_name.lower()
    known_families = ("bridges", "galaxies", "pattern", "undead", "loopy", "singles")
    return next((family for family in known_families if family in lowered), "generic")


def extract_reasoning(response: str) -> str | None:
    """Extract <reasoning>...</reasoning> section from a response."""
    m = re.search(r"<reasoning>(.*?)</reasoning>", response, re.DOTALL)
    return m.group(1).strip() if m else None


def load_jsonl_results(results_dir: str) -> list[dict]:
    """Load and score all lm_eval samples JSONL files under a results directory.

    Every ``samples_*.jsonl`` file found recursively under ``results_dir`` is
    read line by line. Each line's ``resps[0]`` (the list of repeats) is scored
    against the line's ``doc`` with the scorer returned by ``make_scorer`` for
    the puzzle type inferred from the file name.

    Args:
        results_dir: Directory searched recursively for ``samples_*.jsonl``.

    Returns:
        One dict per JSONL line with keys ``task``, ``puzzle_type``,
        ``doc_id``, ``doc``, ``responses`` and ``scores`` (one score per
        response, in the same order).

    Raises:
        SystemExit: With exit code 1 when no matching file exists.
    """
    jsonl_files = sorted(Path(results_dir).rglob("samples_*.jsonl"))
    if not jsonl_files:
        print(f"ERROR: No samples_*.jsonl files found in {results_dir}")
        sys.exit(1)

    records = []
    for fpath in jsonl_files:
        # File names look like samples_TASKNAME_TIMESTAMP.jsonl, e.g.
        # samples_bridges_5x5dm_pass8_train_2026-04-20T... The task name is
        # everything between the "samples" prefix and the first "_"-separated
        # part that starts with "202" (the timestamp).
        parts = fpath.stem.split("_")
        ts_idx = next((i for i, p in enumerate(parts) if p.startswith("202")), len(parts))
        task_name = "_".join(parts[1:ts_idx])

        puzzle_type = detect_puzzle_type(task_name)
        # make_scorer returns a pair; only the second element (per-response
        # scorer, called as score_single(resp, doc)) is needed here.
        _, score_single = make_scorer(puzzle_type)

        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(fpath, encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                d = json.loads(line)
                responses = d["resps"][0]  # list of all repeats
                doc = d["doc"]

                records.append({
                    "task": task_name,
                    "puzzle_type": puzzle_type,
                    "doc_id": d["doc_id"],
                    "doc": doc,
                    "responses": responses,
                    "scores": [score_single(resp, doc) for resp in responses],
                })

    return records


# ---------------------------------------------------------------------------
# Stage 1: Diversity metrics (no GPU)
# ---------------------------------------------------------------------------

def stage1_diversity_analysis(records: list[dict]) -> dict:
    """Compute diversity metrics across correct responses for each prompt.

    Records are grouped by their ``task`` key. For every task the function
    reports, and prints, how many responses per prompt are correct
    (score == 1), how many distinct response strings there are, how often a
    <reasoning> block is present, and the spread of response lengths among
    correct responses for prompts with at least two correct responses.

    Args:
        records: Dicts with at least ``task``, ``responses`` (list of
            response strings) and ``scores`` (parallel list of scores).

    Returns:
        ``{"per_task": {task: metrics}, "overall": summary}``. Per-task
        ``pass_at_1`` is the mean fraction of correct samples per prompt;
        ``pass_at_k`` is the fraction of prompts with at least one correct
        sample among its repeats. ``n_repeats`` is the repeat count of the
        last record seen for the task.
    """
    # Std of reasoning length (in chars) above which a prompt is counted as
    # having diverse reasoning in the overall summary.
    reasoning_std_threshold = 500

    results = {
        "per_task": {},
        "overall": {},
    }

    task_groups = defaultdict(list)
    for r in records:
        task_groups[r["task"]].append(r)

    all_within_prompt_stds = []

    for task, task_records in sorted(task_groups.items()):
        n_problems = len(task_records)
        n_repeats = 0
        n_correct_per_prompt = []
        frac_correct_per_prompt = []
        unique_per_prompt = []
        length_stds = []
        multi_correct_count = 0
        reasoning_present = 0
        total_responses = 0

        for rec in task_records:
            resps = rec["responses"]
            scores = rec["scores"]
            n_repeats = len(resps)
            total_responses += n_repeats

            correct_resps = [r for r, s in zip(resps, scores) if s == 1]
            n_correct = len(correct_resps)
            n_correct_per_prompt.append(n_correct)
            # A prompt with no samples contributes 0 rather than dividing by zero.
            frac_correct_per_prompt.append(n_correct / n_repeats if n_repeats else 0.0)
            unique_per_prompt.append(len(set(resps)))

            # Count responses that contain a complete <reasoning> block.
            reasoning_present += sum(extract_reasoning(r) is not None for r in resps)

            if n_correct < 2:
                continue

            multi_correct_count += 1
            # Length diversity among correct responses
            length_stds.append(np.std([len(r) for r in correct_resps]))

            # Reasoning length diversity; empty or missing reasoning is ignored.
            reasoning_lengths = []
            for r in correct_resps:
                reasoning = extract_reasoning(r)
                if reasoning:
                    reasoning_lengths.append(len(reasoning))
            if len(reasoning_lengths) >= 2:
                all_within_prompt_stds.append(np.std(reasoning_lengths))

        n_correct_arr = np.array(n_correct_per_prompt)
        unique_arr = np.array(unique_per_prompt)

        task_result = {
            "n_problems": n_problems,
            "n_repeats": n_repeats,
            "total_responses": total_responses,
            "reasoning_tag_rate": reasoning_present / total_responses if total_responses > 0 else 0,
            "mean_correct_per_prompt": float(np.mean(n_correct_arr)),
            "mean_unique_per_prompt": float(np.mean(unique_arr)),
            "multi_correct_prompts": multi_correct_count,
            "multi_correct_pct": multi_correct_count / n_problems if n_problems > 0 else 0,
            # pass@1: expected success of a single sample, averaged over prompts.
            "pass_at_1": float(np.mean(frac_correct_per_prompt)),
            # pass@k: at least one of the repeats for the prompt is correct.
            "pass_at_k": float(np.mean(n_correct_arr > 0)),
            "mean_length_std_correct": float(np.mean(length_stds)) if length_stds else 0,
        }
        results["per_task"][task] = task_result

        print(f"\n{'='*60}")
        print(f"Task: {task}")
        print(f"{'='*60}")
        print(f"  Problems: {n_problems}, Repeats: {n_repeats}")
        print(f"  <reasoning> tag rate: {task_result['reasoning_tag_rate']:.1%}")
        print(f"  Mean correct/prompt: {task_result['mean_correct_per_prompt']:.1f}/{n_repeats}")
        print(f"  Mean unique/prompt: {task_result['mean_unique_per_prompt']:.1f}/{n_repeats}")
        print(f"  Prompts with ≥2 correct: {multi_correct_count}/{n_problems} ({task_result['multi_correct_pct']:.0%})")
        print(f"  Pass@1: {task_result['pass_at_1']:.1%}")
        print(f"  Pass@k (≥1 correct among repeats): {task_result['pass_at_k']:.1%}")
        if length_stds:
            print(f"  Mean response length std (correct only): {task_result['mean_length_std_correct']:.0f} chars")

    # Overall summary
    results["overall"] = {
        "total_problems": len(records),
        "total_responses": sum(len(r["responses"]) for r in records),
        "mean_within_prompt_reasoning_len_std": float(np.mean(all_within_prompt_stds)) if all_within_prompt_stds else 0,
        "median_within_prompt_reasoning_len_std": float(np.median(all_within_prompt_stds)) if all_within_prompt_stds else 0,
        "pct_prompts_with_reasoning_diversity": (
            sum(1 for s in all_within_prompt_stds if s > reasoning_std_threshold) / len(all_within_prompt_stds)
            if all_within_prompt_stds else 0
        ),
    }

    print(f"\n{'='*60}")
    print("OVERALL SUMMARY")
    print(f"{'='*60}")
    print(f"  Total problems: {results['overall']['total_problems']}")
    print(f"  Total responses: {results['overall']['total_responses']}")
    if all_within_prompt_stds:
        print("  Within-prompt reasoning length std (among correct):")
        print(f"    Mean: {results['overall']['mean_within_prompt_reasoning_len_std']:.0f} chars")
        print(f"    Median: {results['overall']['median_within_prompt_reasoning_len_std']:.0f} chars")
        print(f"    % prompts with std > {reasoning_std_threshold} chars: {results['overall']['pct_prompts_with_reasoning_diversity']:.0%}")

    return results


# ---------------------------------------------------------------------------
# Stage 2: SFT perplexity via vLLM prompt_logprobs
# ---------------------------------------------------------------------------

def _mean_response_neg_logprob(prompt_logprobs, prompt_token_ids, response_start: int) -> float:
    """Return the mean negative log-prob of positions >= response_start, or NaN.

    ``prompt_logprobs`` is the per-position list returned by vLLM (entries are
    either None or a dict mapping token id to an object with a ``.logprob``
    attribute). When ``prompt_token_ids`` is available, the entry of the token
    actually present at each position is used; otherwise, or if that token is
    missing from the dict, the first dict entry is used as a fallback.
    """
    if prompt_logprobs is None:
        return float("nan")

    values = []
    for pos in range(response_start, len(prompt_logprobs)):
        candidates = prompt_logprobs[pos]
        if not candidates:
            continue  # no logprob reported for this position
        entry = None
        if prompt_token_ids is not None and pos < len(prompt_token_ids):
            entry = candidates.get(prompt_token_ids[pos])
        if entry is None:
            entry = next(iter(candidates.values()))
        values.append(entry.logprob)

    return float(-np.mean(values)) if values else float("nan")


def stage2_sft_perplexity(records: list[dict], sft_model: str, system_instruction: str) -> dict:
    """Compute an SFT perplexity proxy for responses using vLLM prompt_logprobs.

    For each response that contains a <reasoning> block, constructs the full
    chat sequence (system + user + assistant response), feeds it to vLLM as a
    "prompt" with prompt_logprobs=1, then averages the negative log-probs of
    all tokens after the prompt portion (the whole assistant turn, not only
    the reasoning text). Responses without reasoning tags are skipped.

    Args:
        records: Scored records with ``puzzle_type``, ``doc``, ``responses``,
            ``scores``, ``task`` and ``doc_id`` keys.
        sft_model: Model path or name passed to both the tokenizer and vLLM.
        system_instruction: System message used when rebuilding the chat.

    Returns:
        The summary dict produced by ``_analyze_perplexity``.
    """
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams

    print(f"\nLoading SFT model for perplexity: {sft_model}")
    tokenizer = AutoTokenizer.from_pretrained(sft_model, trust_remote_code=True)
    llm = LLM(
        model=sft_model,
        gpu_memory_utilization=0.90,
        max_model_len=20000,
        trust_remote_code=True,
    )

    # Build prompt template function (same as lm_eval uses)
    sys.path.insert(0, str(PROJECT_ROOT / "evaluate" / "custom_tasks" / "puzzle" / "_common"))
    from formatting import make_doc_to_text

    template_map = {
        "bridges": "bridges_intformat.txt",
        "pattern": "pattern_intformat.txt",
        "undead": "undead_intformat.txt",
        "galaxies": "galaxies_intformat.txt",
    }

    # One formatter per template instead of rebuilding it for every record.
    # NOTE(review): assumes make_doc_to_text(name) is pure — confirm.
    doc_to_text_cache = {}

    # The system + user prefix is identical for all responses of one record.
    system_message = {"role": "system", "content": system_instruction}

    # Collect all (prompt, response, metadata) tuples
    items = []
    for rec in records:
        puzzle_type = rec["puzzle_type"]
        template_name = template_map.get(puzzle_type, f"{puzzle_type}_intformat.txt")
        if template_name not in doc_to_text_cache:
            doc_to_text_cache[template_name] = make_doc_to_text(template_name)
        user_message = doc_to_text_cache[template_name](rec["doc"])
        prompt_messages = [
            system_message,
            {"role": "user", "content": user_message},
        ]

        # Prompt-only text (with generation prompt) marks where the response
        # starts; its token count is loop-invariant across this record's
        # responses, so it is computed lazily at most once per record.
        prompt_token_count = None

        for j, (resp, score) in enumerate(zip(rec["responses"], rec["scores"])):
            reasoning = extract_reasoning(resp)
            if reasoning is None:
                continue  # skip responses without reasoning tags

            if prompt_token_count is None:
                prompt_text = tokenizer.apply_chat_template(
                    prompt_messages, tokenize=False, add_generation_prompt=True
                )
                prompt_token_count = len(tokenizer.encode(prompt_text))

            # Build full chat sequence: system + user + assistant response
            messages = prompt_messages + [{"role": "assistant", "content": resp}]
            full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

            items.append({
                "full_text": full_text,
                "prompt_token_count": prompt_token_count,
                "task": rec["task"],
                "doc_id": rec["doc_id"],
                "resp_idx": j,
                "score": score,
                "reasoning_len": len(reasoning),
                "response_len": len(resp),
            })

    print(f"  Computing perplexity for {len(items)} responses...")

    # Feed the full sequence as the "prompt" and request per-position prompt
    # log-probs; generation itself is irrelevant, so only one token is sampled.
    sp = SamplingParams(max_tokens=1, prompt_logprobs=1, temperature=0)

    # Process in batches to avoid OOM
    BATCH_SIZE = 32

    for batch_start in range(0, len(items), BATCH_SIZE):
        batch = items[batch_start:batch_start + BATCH_SIZE]
        prompts = [item["full_text"] for item in batch]

        outputs = llm.generate(prompts, sp)

        for item, output in zip(batch, outputs):
            # Score response tokens only: skip the first prompt_token_count
            # positions and look up each remaining position's actual token.
            item["mean_neg_logprob"] = _mean_response_neg_logprob(
                output.prompt_logprobs,
                getattr(output, "prompt_token_ids", None),
                item["prompt_token_count"],
            )

        done = min(batch_start + BATCH_SIZE, len(items))
        print(f"    Processed {done}/{len(items)} responses", end="\r")

    print()

    # Analyze perplexity results
    return _analyze_perplexity(items, records)


def _analyze_perplexity(items: list[dict], records: list[dict]) -> dict:
    """Analyze SFT perplexity distributions."""
    # Group by (task, doc_id)
    groups = defaultdict(list)
    for item in items:
        if not np.isnan(item["mean_neg_logprob"]):
            groups[(item["task"], item["doc_id"])].append(item)

    # Within-prompt variance among correct responses
    within_prompt_stds = []
    within_prompt_stds_by_task = defaultdict(list)
    correct_ppls = []
    incorrect_ppls = []

    for (task, doc_id), group_items in groups.items():
        correct = [it for it in group_items if it["score"] == 1]
        incorrect = [it for it in group_items if it["score"] == 0]

        for it in correct:
            correct_ppls.append(it["mean_neg_logprob"])
        for it in incorrect:
            incorrect_ppls.append(it["mean_neg_logprob"])

        if len(correct) >= 2:
            ppls = [it["mean_neg_logprob"] for it in correct]
            std = np.std(ppls)
            within_prompt_stds.append(std)
            within_prompt_stds_by_task[task].append(std)

    correct_ppls = np.array(correct_ppls)
    incorrect_ppls = np.array(incorrect_ppls)
    within_prompt_stds = np.array(within_prompt_stds)

    print(f"\n{'='*60}")
    print("SFT PERPLEXITY ANALYSIS")
    print(f"{'='*60}")

    print(f"\n  Correct responses: {len(correct_ppls)}")
    if len(correct_ppls) > 0:
        print(f"    Mean neg-logprob: {np.mean(correct_ppls):.4f}")
        print(f"    Std neg-logprob:  {np.std(correct_ppls):.4f}")
    print(f"  Incorrect responses: {len(incorrect_ppls)}")
    if len(incorrect_ppls) > 0:
        print(f"    Mean neg-logprob: {np.mean(incorrect_ppls):.4f}")
        print(f"    Std neg-logprob:  {np.std(incorrect_ppls):.4f}")

    print(f"\n  Within-prompt ppl std (correct only, ≥2 correct):")
    print(f"    N prompts: {len(within_prompt_stds)}")
    if len(within_prompt_stds) > 0:
        print(f"    Mean std:   {np.mean(within_prompt_stds):.4f} nats/token")
        print(f"    Median std: {np.median(within_prompt_stds):.4f} nats/token")
        print(f"    % with std > 0.1: {np.mean(within_prompt_stds > 0.1):.0%}")
        print(f"    % with std > 0.3: {np.mean(within_prompt_stds > 0.3):.0%}")
        print(f"    % with std > 0.5: {np.mean(within_prompt_stds > 0.5):.0%}")

    # Novel correct: correct responses with ppl > mean + 1 std
    if len(correct_ppls) > 0:
        threshold = np.mean(correct_ppls) + np.std(correct_ppls)
        novel_correct = np.sum(correct_ppls > threshold)
        print(f"\n  Novel correct (ppl > mean + 1σ): {novel_correct}/{len(correct_ppls)} ({novel_correct/len(correct_ppls):.0%})")

    # Per-task breakdown
    for task, stds in sorted(within_prompt_stds_by_task.items()):
        stds = np.array(stds)
        print(f"\n  {task}:")
        print(f"    N prompts with ≥2 correct: {len(stds)}")
        if len(stds) > 0:
            print(f"    Mean within-prompt ppl std: {np.mean(stds):.4f}")
            print(f"    % with std > 0.1: {np.mean(stds > 0.1):.0%}")

    results = {
        "n_correct": len(correct_ppls),
        "n_incorrect": len(incorrect_ppls),
        "correct_ppl_mean": float(np.mean(correct_ppls)) if len(correct_ppls) > 0 else None,
        "correct_ppl_std": float(np.std(correct_ppls)) if len(correct_ppls) > 0 else None,
        "incorrect_ppl_mean": float(np.mean(incorrect_ppls)) if len(incorrect_ppls) > 0 else None,
        "within_prompt_std_mean": float(np.mean(within_prompt_stds)) if len(within_prompt_stds) > 0 else None,
        "within_prompt_std_median": float(np.median(within_prompt_stds)) if len(within_prompt_stds) > 0 else None,
        "pct_prompts_std_gt_0.1": float(np.mean(within_prompt_stds > 0.1)) if len(within_prompt_stds) > 0 else None,
        "pct_prompts_std_gt_0.3": float(np.mean(within_prompt_stds > 0.3)) if len(within_prompt_stds) > 0 else None,
        "per_task": {
            task: {
                "n_prompts": len(stds),
                "mean_std": float(np.mean(stds)) if len(stds) > 0 else None,
            }
            for task, stds in within_prompt_stds_by_task.items()
        },
    }
    return results


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Parse CLI arguments, run stage 1 (and optionally stage 2), save results.

    Stage 1 (diversity metrics) always runs. Stage 2 (SFT perplexity) runs
    when ``--stage 2`` is given and requires ``--sft_model``; a missing model
    is reported and the process exits with code 1 before any results are
    loaded or scored. When ``--output_json`` is given, both result dicts are
    written there as JSON, creating the parent directory if needed.
    """
    parser = argparse.ArgumentParser(description="Phase 0: Novelty Signal Analysis")
    parser.add_argument("--results_dir", required=True, help="Path to lm_eval results dir with samples_*.jsonl")
    parser.add_argument("--sft_model", default=None, help="SFT model path for perplexity (stage 2)")
    parser.add_argument("--stage", type=int, default=1, choices=[1, 2], help="1=diversity only, 2=+perplexity")
    parser.add_argument("--output_json", default=None, help="Save results to JSON file")
    parser.add_argument("--system_instruction", default=(
        "A conversation between User and Assistant. The user asks a question, "
        "and the Assistant solves it step by step by reasoning. Provide the reasoning "
        "in <reasoning> reasoning here </reasoning> and the final solution within "
        "<answer> answer here </answer>"
    ))
    args = parser.parse_args()

    # Fail fast: do not load and score everything only to discover afterwards
    # that stage 2 cannot run.
    if args.stage >= 2 and args.sft_model is None:
        print("\nERROR: --sft_model required for stage 2")
        sys.exit(1)

    print(f"Loading results from: {args.results_dir}")
    records = load_jsonl_results(args.results_dir)
    print(f"Loaded {len(records)} problems, {sum(len(r['responses']) for r in records)} total responses")

    # Stage 1: diversity
    diversity_results = stage1_diversity_analysis(records)

    # Stage 2: perplexity
    ppl_results = None
    if args.stage >= 2:
        ppl_results = stage2_sft_perplexity(records, args.sft_model, args.system_instruction)

    # Save results
    if args.output_json:
        output = {
            "diversity": diversity_results,
            "perplexity": ppl_results,
        }
        # dirname is "" for a bare file name; fall back to the current directory.
        os.makedirs(os.path.dirname(args.output_json) or ".", exist_ok=True)
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2)
        print(f"\nResults saved to: {args.output_json}")

# Run the command-line entry point only when this file is executed directly.
if __name__ == "__main__":
    main()
