#!/usr/bin/env python3
"""
Lightweight script for the figure's calculation:
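1. Sample up to --sample-size (default 1000) trajectories, taken at the same
   positions from each of three JSONL result files.
2. Compute R_context as the distinct 10-gram ratio (n is set by --ngram-n).
3. Compute R_semantic as 1 - average pairwise similarity between the embedded
   steps of a trajectory.
4. Report the Pearson correlation between R_context and R_semantic, separately
   for each file.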

It assumes the three files are already aligned by problem order.
"""

from __future__ import annotations

import argparse
import csv
import json
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
from scipy.stats import pearsonr

# Third ancestor directory of this file (parents[0] is the containing folder).
# Presumably the repository root -- TODO confirm against the checkout layout.
ROOT = Path(__file__).resolve().parents[2]
# Put the reward_ours recipe directory first on sys.path so that
# `semantic_repetition` can be imported below as a top-level module.
sys.path.insert(0, str(ROOT / "verl041" / "verl" / "recipe" / "reward_ours"))

from semantic_repetition import VLLMEmbeddingModel  # noqa: E402


# Default JSONL inputs for the three runs compared by this script. They are
# absolute paths (presumably specific to one shared-storage environment) and
# serve only as argparse defaults; override them with --baseline, --lie and
# --base-model respectively.
DEFAULT_BASELINE = (
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/"
    "project_tts_extrapolation/results_dec/"
    "gspo-step500-valid-all_32768_test.jsonl"
)
DEFAULT_LIE = (
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/"
    "project_tts_extrapolation/results_dec/"
    "gspo-skip-right-step600-valid-all_32768_test.jsonl"
)
DEFAULT_BASE_MODEL = (
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/"
    "project_tts_extrapolation/results_nov/"
    "Qwen3-4B-Base-valid-all_32768_test.jsonl"
)


def load_records(path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file into a list of normalised trajectory records.

    Blank lines are skipped. Each remaining line must be a JSON object; the
    generated text is taken from the first truthy value among the keys
    ``generated_text``, ``output`` and ``response``, and the prompt from the
    first truthy value among ``problem``, ``prompt`` and ``input``.

    Args:
        path: Location of the JSONL file (read as UTF-8).

    Returns:
        One dict per non-blank line with the keys ``index`` (zero-based line
        number in the file, blank lines included in the count), ``prompt``,
        ``text`` (with ``<|endoftext|>`` removed and whitespace stripped) and
        ``correctness`` (the ``correctness`` field, else ``score``, else None).

    Raises:
        ValueError: If a non-blank line is not valid JSON.
    """

    def first_truthy(payload: Dict[str, Any], keys: Tuple[str, ...]) -> Any:
        # Mirrors an `a or b or c or ""` chain: falsy values are skipped.
        for key in keys:
            candidate = payload.get(key)
            if candidate:
                return candidate
        return ""

    loaded: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            stripped = raw_line.strip()
            if stripped == "":
                continue

            try:
                payload = json.loads(stripped)
            except json.JSONDecodeError as e:
                preview = stripped[:200]
                raise ValueError(
                    f"invalid json in {path}:{line_no}\n"
                    f"line preview: {preview!r}"
                ) from e

            completion = first_truthy(
                payload, ("generated_text", "output", "response"))
            # A list of candidates collapses to its first entry.
            if isinstance(completion, list):
                completion = completion[0] if completion else ""
            if not isinstance(completion, str):
                completion = str(completion)

            prompt = first_truthy(payload, ("problem", "prompt", "input"))
            if "correctness" in payload:
                correctness = payload["correctness"]
            else:
                correctness = payload.get("score")

            loaded.append(
                {
                    "index": line_no - 1,
                    "prompt": prompt,
                    "text": completion.replace("<|endoftext|>", "").strip(),
                    "correctness": correctness,
                }
            )
    return loaded


def split_text_to_words(text: str) -> List[str]:
    """Tokenise *text* on whitespace, then split every token on underscores.

    Empty pieces produced by leading, trailing or doubled underscores are
    kept, so ``"a__b"`` yields ``["a", "", "b"]``.
    """
    return [piece for token in text.split() for piece in token.split("_")]


def calculate_distinct_ngram_ratio(text: str, n: int = 10) -> float:
    """Return the fraction of word n-grams in *text* that are distinct.

    Words come from ``split_text_to_words``. The result is
    ``(#distinct n-grams) / (#n-gram positions)``.

    Args:
        text: Text to analyse.
        n: N-gram order; must be at least 1.

    Returns:
        A value in (0, 1]. Texts with fewer than *n* words contain no n-gram
        and are scored 1.0.

    Raises:
        ValueError: If *n* is smaller than 1. Previously such values were
            accepted and produced a meaningless ratio.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    words = split_text_to_words(text)
    total = len(words) - n + 1
    if total <= 0:
        return 1.0

    distinct = {tuple(words[i: i + n]) for i in range(total)}
    return len(distinct) / total


def split_steps(text: str) -> List[str]:
    """Break *text* into step strings at line breaks and sentence punctuation.

    Line endings are normalised to ``\\n`` first. A boundary is any run of
    newlines, the position right after one of ``。！？!?；;`` (plus any
    following whitespace), or whitespace that follows a ``.``. Punctuation
    stays attached to the preceding step, whitespace inside a step is
    collapsed to single spaces, and blank fragments are dropped.

    Returns:
        The list of steps; empty when *text* is empty or whitespace only.
    """
    normalized = text.replace("\r\n", "\n").replace("\r", "\n").strip()
    if normalized == "":
        return []

    boundary = re.compile(r"\n+|(?<=[。！？!?；;])\s*|(?<=\.)\s+")
    steps: List[str] = []
    for fragment in boundary.split(normalized):
        if not fragment.strip():
            continue
        steps.append(re.sub(r"\s+", " ", fragment).strip())

    if steps:
        return steps
    return [normalized]


def batched_embeddings(texts: List[str], model: VLLMEmbeddingModel, batch_size: int) -> np.ndarray:
    """Embed *texts* in consecutive chunks and stack the results row-wise.

    Args:
        texts: Strings to embed, in order.
        model: Object exposing ``get_embeddings_batch(list_of_str)``; each
            call is assumed to return one embedding row per input text, in
            order -- TODO confirm against the model implementation.
        batch_size: Maximum number of texts per call; must be at least 1.

    Returns:
        The per-batch results concatenated along axis 0, or an empty array of
        shape ``(0, 0)`` when *texts* is empty.

    Raises:
        ValueError: If *batch_size* is smaller than 1. A negative value used
            to yield an empty range, silently discarding every text.
    """
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}")

    chunks = [
        np.asarray(model.get_embeddings_batch(texts[start: start + batch_size]))
        for start in range(0, len(texts), batch_size)
    ]
    if not chunks:
        return np.empty((0, 0))
    return np.concatenate(chunks, axis=0)


def calculate_semantic_non_repetition(
    texts: List[str],
    model: VLLMEmbeddingModel,
    batch_size: int = 256,
) -> List[Dict[str, float]]:
    """Score each text by how dissimilar its steps are from one another.

    Every text is split with ``split_steps``; the steps of all texts are
    embedded together through ``batched_embeddings`` and L2-normalised, so
    the dot product of two rows is their cosine similarity.

    Args:
        texts: Trajectories to score.
        model: Embedding model passed through to ``batched_embeddings``.
        batch_size: Number of steps embedded per model call.

    Returns:
        One dict per input text, in order, with the keys ``r_semantic``
        (1 - mean pairwise similarity), ``num_steps`` and ``avg_similarity``.
        A text with no steps gets ``r_semantic`` 0.0 and a text with a single
        step gets ``r_semantic`` 1.0; both report ``avg_similarity`` 0.0.
    """
    # Flatten all steps into one list and remember each text's [begin, end).
    flat_steps: List[str] = []
    spans: List[Tuple[int, int]] = []
    for text in texts:
        begin = len(flat_steps)
        flat_steps.extend(split_steps(text))
        spans.append((begin, len(flat_steps)))

    vectors = batched_embeddings(flat_steps, model, batch_size)
    if vectors.size > 0:
        lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
        # Zero vectors are divided by 1 so they stay zero instead of NaN.
        vectors = vectors / np.where(lengths == 0, 1, lengths)

    def score_span(begin: int, end: int) -> Dict[str, float]:
        count = end - begin
        if count == 0:
            return {"r_semantic": 0.0, "num_steps": 0, "avg_similarity": 0.0}
        if count == 1:
            return {"r_semantic": 1.0, "num_steps": 1, "avg_similarity": 0.0}

        block = vectors[begin:end]
        gram = block @ block.T
        # Strict upper triangle: each unordered pair once, no self-pairs.
        pair_values = gram[np.triu_indices(count, k=1)]
        if len(pair_values) > 0:
            mean_similarity = float(np.mean(pair_values))
        else:
            mean_similarity = 0.0

        return {
            "r_semantic": 1.0 - mean_similarity,
            "num_steps": count,
            "avg_similarity": mean_similarity,
        }

    return [score_span(begin, end) for begin, end in spans]


def summarize(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate per-sample rows into summary statistics.

    Args:
        rows: Dicts carrying at least ``r_context``, ``r_semantic`` and
            ``num_steps``.

    Returns:
        A dict with the sample count, the mean and (population) standard
        deviation of both scores, the Pearson correlation with its p-value,
        and the mean step count. Statistics are None when *rows* is empty;
        the correlation fields are also None when there are fewer than two
        rows or either score has zero standard deviation.
    """
    n_rows = len(rows)
    context_vals = np.array([row["r_context"] for row in rows], dtype=np.float64)
    semantic_vals = np.array([row["r_semantic"] for row in rows], dtype=np.float64)

    corr = None
    p_value = None
    # Pearson's r is undefined for fewer than two points or constant input.
    if n_rows >= 2 and np.std(context_vals) > 0 and np.std(semantic_vals) > 0:
        corr, p_value = pearsonr(context_vals, semantic_vals)

    def stat(reducer, values) -> Any:
        return float(reducer(values)) if n_rows else None

    summary: Dict[str, Any] = {"num_samples": n_rows}
    summary["mean_r_context"] = stat(np.mean, context_vals)
    summary["std_r_context"] = stat(np.std, context_vals)
    summary["mean_r_semantic"] = stat(np.mean, semantic_vals)
    summary["std_r_semantic"] = stat(np.std, semantic_vals)
    summary["pearson_r"] = float(corr) if corr is not None else None
    summary["pearson_pvalue"] = float(p_value) if p_value is not None else None
    if rows:
        summary["mean_num_steps"] = float(
            np.mean([row["num_steps"] for row in rows]))
    else:
        summary["mean_num_steps"] = None
    return summary


def build_rows(records: List[Dict[str, Any]], model: VLLMEmbeddingModel, ngram_n: int, batch_size: int) -> List[Dict[str, Any]]:
    """Compute both repetition scores for every record.

    Args:
        records: Dicts with ``index``, ``prompt``, ``text`` and
            ``correctness`` keys.
        model: Embedding model used for the semantic score.
        ngram_n: N-gram order for the distinct n-gram ratio (R_context).
        batch_size: Number of steps embedded per model call.

    Returns:
        One row per record with ``index``, ``prompt`` (first 160 characters
        of its string form), ``correctness``, ``r_context``, ``r_semantic``,
        ``num_steps`` and ``avg_similarity``.
    """
    texts = [record["text"] for record in records]
    context_scores = [
        calculate_distinct_ngram_ratio(text, n=ngram_n) for text in texts
    ]
    semantic_scores = calculate_semantic_non_repetition(
        texts=texts,
        model=model,
        batch_size=batch_size,
    )

    def make_row(record: Dict[str, Any], r_context: float,
                 semantic: Dict[str, float]) -> Dict[str, Any]:
        return {
            "index": record["index"],
            "prompt": str(record["prompt"])[:160],
            "correctness": record["correctness"],
            "r_context": float(r_context),
            "r_semantic": float(semantic["r_semantic"]),
            "num_steps": int(semantic["num_steps"]),
            "avg_similarity": float(semantic["avg_similarity"]),
        }

    return [
        make_row(record, r_context, semantic)
        for record, r_context, semantic in zip(
            records, context_scores, semantic_scores)
    ]


def save_long_csv(output_path: Path, all_rows: Dict[str, List[Dict[str, Any]]]) -> None:
    """Write all per-sample rows to one long-format CSV file.

    Args:
        output_path: Destination file; it is overwritten.
        all_rows: Mapping from model name to that model's rows. Each row is
            written with an additional leading ``model`` column.
    """
    columns = [
        "model",
        "index",
        "correctness",
        "r_context",
        "r_semantic",
        "num_steps",
        "avg_similarity",
        "prompt",
    ]
    # newline="" lets the csv module control line endings itself.
    with output_path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(
            {"model": model_name, **row}
            for model_name, rows in all_rows.items()
            for row in rows
        )


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the script.

    Returns:
        A namespace with the input paths (``baseline``, ``lie``,
        ``base_model``), the sampling settings (``sample_size``, ``seed``),
        ``ngram_n``, the embedding endpoint settings and ``output_dir``.
    """
    # (flag, add_argument keyword options), in the order shown by --help.
    option_specs = [
        ("--baseline", {"type": str, "default": DEFAULT_BASELINE}),
        ("--lie", {"type": str, "default": DEFAULT_LIE}),
        ("--base-model", {"dest": "base_model", "type": str,
                          "default": DEFAULT_BASE_MODEL}),
        ("--sample-size", {"type": int, "default": 1000}),
        ("--seed", {"type": int, "default": 42}),
        ("--ngram-n", {"type": int, "default": 10}),
        ("--embedding-base-url", {"type": str,
                                  "default": "http://100.102.3.50:8000/v1"}),
        ("--embedding-model-name", {"type": str, "default": "embed"}),
        ("--embedding-batch-size", {"type": int, "default": 256}),
        ("--output-dir", {
            "type": str,
            "default": "eval_scripts/analysis/context_semantic_correlation",
        }),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def _format_metric(value: Any) -> str:
    """Format a summary statistic with six decimals, or "None" if missing."""
    return "None" if value is None else f"{value:.6f}"


def main() -> None:
    """Run the analysis end to end.

    Loads the three input files, samples the same record positions from each,
    computes R_context and R_semantic per sample, then writes
    ``summary.json`` and ``per_sample.csv`` to the output directory and
    prints a short report.

    Raises:
        ValueError: If ``--sample-size`` is negative, or if an input file
            contains a line that is not valid JSON.
    """
    args = parse_args()
    # Fail before loading large files rather than deep inside the sampler.
    if args.sample_size < 0:
        raise ValueError(
            f"--sample-size must be non-negative, got {args.sample_size}")

    paths = {
        "baseline": args.baseline,
        "lie": args.lie,
        "base_model": args.base_model,
    }

    all_data = {name: load_records(path) for name, path in paths.items()}
    sizes = {name: len(records) for name, records in all_data.items()}
    shared_size = min(sizes.values())
    if len(set(sizes.values())) > 1:
        # The files are assumed to be aligned; differing lengths suggest they
        # may not be, so make the truncation visible instead of silent.
        print(
            f"warning: input files differ in record count {sizes}; "
            f"sampling only from the first {shared_size} records of each",
            file=sys.stderr,
        )
    sample_size = min(args.sample_size, shared_size)

    # One set of positions is drawn and reused for every file. These are
    # positions in the loaded record lists, which differ from the per-record
    # "index" field whenever a file contains blank lines.
    rng = np.random.default_rng(args.seed)
    sampled_indices = sorted(rng.choice(
        shared_size, size=sample_size, replace=False).tolist())

    model = VLLMEmbeddingModel(
        base_url=args.embedding_base_url,
        model_name=args.embedding_model_name,
    )

    all_rows: Dict[str, List[Dict[str, Any]]] = {}
    summaries: Dict[str, Any] = {}
    for name, records in all_data.items():
        sampled_records = [records[i] for i in sampled_indices]
        rows = build_rows(
            records=sampled_records,
            model=model,
            ngram_n=args.ngram_n,
            batch_size=args.embedding_batch_size,
        )
        all_rows[name] = rows
        summaries[name] = summarize(rows)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with (output_dir / "summary.json").open("w", encoding="utf-8") as f:
        json.dump(
            {
                "config": {
                    "paths": paths,
                    "sample_size": sample_size,
                    "seed": args.seed,
                    "ngram_n": args.ngram_n,
                    "embedding_base_url": args.embedding_base_url,
                    "embedding_model_name": args.embedding_model_name,
                },
                "sampled_indices": sampled_indices,
                "results": summaries,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    save_long_csv(output_dir / "per_sample.csv", all_rows)

    print(f"sample_size: {sample_size}")
    for name, summary in summaries.items():
        # The means are None when nothing was sampled; formatting them with
        # ":.6f" directly would raise TypeError after the outputs are written.
        print(
            f"{name}: "
            f"mean_r_context={_format_metric(summary['mean_r_context'])}, "
            f"mean_r_semantic={_format_metric(summary['mean_r_semantic'])}, "
            f"pearson_r={summary['pearson_r']}"
        )
    print(f"saved to: {output_dir}")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
