#!/usr/bin/env python
"""Compare trend tokens generated by different dataset variants."""

from __future__ import annotations

import argparse
import json
from collections import defaultdict
from collections.abc import Iterable, Sequence
from pathlib import Path
from typing import List


def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace with ``reference`` and ``candidate`` (both ``Path``) and
        ``max_mismatches`` (``int``).
    """
    summary = (
        "Compare LLM generated trend tokens (e.g., ver_gen9) against a "
        "ground-truth-derived reference (e.g., ver_gen5)."
    )
    cli = argparse.ArgumentParser(description=summary)

    # The two dataset options differ only in flag, default and help text.
    dataset_options = (
        (
            "--reference",
            Path("dataset/FNSPID/ver_gen5/vali.json"),
            "JSON file whose `news` field is treated as the ground-truth trends.",
        ),
        (
            "--candidate",
            Path("dataset/FNSPID/ver_gen9/vali.json"),
            "JSON file whose `news` field is evaluated against the reference.",
        ),
    )
    for flag, fallback, usage in dataset_options:
        cli.add_argument(flag, type=Path, default=fallback, help=usage)

    cli.add_argument(
        "--max-mismatches",
        type=int,
        default=5,
        help="Number of representative mismatches to print for manual inspection.",
    )
    return cli.parse_args()


def parse_news_sequence(raw_value: Sequence | str) -> List[int]:
    """Convert a news field into an integer sequence of trend tokens.

    Args:
        raw_value: Either a comma-separated string or an iterable of tokens.
            ``None`` entries and tokens that are blank after stripping
            whitespace are skipped.

    Returns:
        One integer per remaining token: each token is parsed as a float and
        rounded with the built-in ``round``.

    Raises:
        TypeError: If ``raw_value`` is neither a string nor an iterable.
        ValueError: If a token is not numeric, is not a finite number
            (``nan``/``inf``), or if no tokens remain after filtering.
    """
    if isinstance(raw_value, str):
        tokens: Iterable = raw_value.split(",")
    elif isinstance(raw_value, Iterable):
        tokens = raw_value
    else:
        raise TypeError(f"Unsupported news type: {type(raw_value)}")

    parsed: List[int] = []
    for token in tokens:
        if token is None:
            continue
        token_str = str(token).strip()
        if not token_str:
            continue
        try:
            value = float(token_str)
        except ValueError as exc:  # pragma: no cover - defensive path
            raise ValueError(f"Cannot convert token `{token_str}` to float.") from exc
        try:
            # round() on a float returns an int; it raises ValueError for NaN
            # and OverflowError for +/-inf, neither of which names the token.
            rounded = round(value)
        except (ValueError, OverflowError) as exc:
            raise ValueError(
                f"Token `{token_str}` is not a finite number."
            ) from exc
        parsed.append(rounded)
    if not parsed:
        raise ValueError("Encountered empty news sequence after parsing.")
    return parsed


def load_dataset(path: Path) -> list[dict]:
    """Read the UTF-8 encoded JSON document at ``path`` and return its decoded value."""
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)


def _accuracy(correct: int, total: int) -> float:
    """Return ``correct / total``, or 0.0 when nothing was counted."""
    return correct / total if total else 0.0


def compare_trends(
    reference: list[dict],
    candidate: list[dict],
    max_mismatches: int,
) -> dict:
    """Compare candidate trend tokens against reference tokens step by step.

    Items are paired by position. Each item's ``news`` field is parsed with
    ``parse_news_sequence`` and the resulting sequences are compared token by
    token. A human-readable report is printed to stdout and the same numbers
    are returned as a dictionary.

    Args:
        reference: Items whose ``news`` field holds the ground-truth tokens.
        candidate: Items whose ``news`` field is evaluated; must have the same
            length as ``reference``.
        max_mismatches: Maximum number of mismatching steps to record.

    Returns:
        A dictionary with the keys ``summary`` (total/matched step counts and
        overall accuracy), ``per_horizon_accuracy`` (keyed ``step_<n>``,
        1-based), ``per_class_accuracy`` (keyed ``trend_<value>``, grouped by
        the reference token) and ``mismatches`` (the first recorded
        mismatches, each with sample index, 0-based horizon, both tokens and
        the first 80 elements of the reference item's ``historical_data``).
        All accuracies are 0.0 when there is nothing to compare.

    Raises:
        ValueError: If the datasets differ in length, or a pair of parsed
            sequences differs in length.
        KeyError: If an item has no ``news`` key.
    """
    if len(reference) != len(candidate):
        raise ValueError(
            f"Dataset length mismatch: {len(reference)} vs {len(candidate)}."
        )

    total_steps = 0
    matched_steps = 0
    # Each value is a mutable [correct, total] pair.
    per_horizon_counts: dict[int, list[int]] = defaultdict(lambda: [0, 0])
    per_class_counts: dict[int, list[int]] = defaultdict(lambda: [0, 0])
    mismatches: list[dict] = []

    for idx, (ref_item, cand_item) in enumerate(zip(reference, candidate)):
        ref_seq = parse_news_sequence(ref_item["news"])
        cand_seq = parse_news_sequence(cand_item["news"])
        if len(ref_seq) != len(cand_seq):
            raise ValueError(
                f"Sequence length mismatch at sample {idx}: "
                f"{len(ref_seq)} vs {len(cand_seq)}."
            )

        for horizon, (ref_token, cand_token) in enumerate(zip(ref_seq, cand_seq)):
            total_steps += 1
            per_horizon_counts[horizon][1] += 1
            # Per-class statistics are conditioned on the reference token.
            per_class_counts[ref_token][1] += 1
            if ref_token == cand_token:
                matched_steps += 1
                per_horizon_counts[horizon][0] += 1
                per_class_counts[ref_token][0] += 1
            elif len(mismatches) < max_mismatches:
                mismatches.append(
                    {
                        "sample_index": idx,
                        "horizon": horizon,
                        "reference": ref_token,
                        "candidate": cand_token,
                        "historical_data": ref_item.get("historical_data", "")[:80],
                    }
                )

    # Guarded like the per-horizon/per-class ratios: two empty datasets must
    # yield 0.0 instead of raising ZeroDivisionError.
    overall_accuracy = _accuracy(matched_steps, total_steps)
    print(f"Compared {total_steps} horizons across {len(reference)} samples.")
    print(f"Overall accuracy: {overall_accuracy:.4%}")

    print("\nPer-horizon accuracy:")
    for horizon in sorted(per_horizon_counts):
        correct, count = per_horizon_counts[horizon]
        accuracy = _accuracy(correct, count)
        print(f"  Step {horizon + 1}: {accuracy:.4%} ({correct}/{count})")

    print("\nPer-class accuracy (reference-conditioned):")
    for trend_value in sorted(per_class_counts):
        correct, count = per_class_counts[trend_value]
        accuracy = _accuracy(correct, count)
        print(f"  Trend {trend_value:+}: {accuracy:.4%} ({correct}/{count})")

    if mismatches:
        print(f"\nFirst {len(mismatches)} mismatches:")
        for miss in mismatches:
            print(
                "  "
                f"sample={miss['sample_index']} "
                f"horizon={miss['horizon'] + 1} "
                f"ref={miss['reference']} "
                f"cand={miss['candidate']} "
                f"historical={miss['historical_data']!r}"
            )

    # Entries are emitted in sorted key order so the returned structure lists
    # steps and classes in the same order as the printed report.
    return {
        "summary": {
            "total_horizons": total_steps,
            "matched_steps": matched_steps,
            "overall_accuracy": overall_accuracy,
        },
        "per_horizon_accuracy": {
            f"step_{horizon + 1}": {
                "correct": correct,
                "total": count,
                "accuracy": _accuracy(correct, count),
            }
            for horizon, (correct, count) in sorted(per_horizon_counts.items())
        },
        "per_class_accuracy": {
            f"trend_{trend_value}": {
                "correct": correct,
                "total": count,
                "accuracy": _accuracy(correct, count),
            }
            for trend_value, (correct, count) in sorted(per_class_counts.items())
        },
        "mismatches": mismatches,
    }


def main() -> None:
    """Run the comparison for the CLI-selected datasets and save the report.

    Loads the reference and candidate JSON files, prints the accuracy report
    via ``compare_trends`` and writes the returned dictionary to
    ``comparison_results_vali.json``.
    """
    cli_args = parse_args()
    report = compare_trends(
        load_dataset(cli_args.reference),
        load_dataset(cli_args.candidate),
        cli_args.max_mismatches,
    )

    # Default output path.
    output_path = Path("comparison_results_vali.json")
    with output_path.open("w", encoding="utf-8") as handle:
        json.dump(report, handle, indent=2)

    print(f"Results saved to {output_path}")


if __name__ == "__main__":
    main()
