#!/usr/bin/env python3
"""
Test script to verify token analysis functionality works correctly with cluster results.
"""

import json
import tempfile
import os
import traceback
from scripts.reformat_cluster_results import (
    analyze_cluster_batch_tokens,
    convert_cluster_results_to_responses,
)


def create_sample_cluster_results():
    """Build the three canned cluster-result records used as test fixtures.

    Every record is a dict holding an ``id``, a ``completion`` text made of a
    THINKING section followed by an ANSWER section, and a ``metadata`` dict of
    string values. A fresh list is built on each call.
    """
    # Order matters here: it fixes the key order of each metadata dict.
    metadata_fields = (
        "benchmark",
        "graph_type",
        "encoding",
        "size_pattern",
        "system_prompt",
        "question_type",
        "target",
        "n_pairs",
    )

    def build_record(record_id, completion, *metadata_values):
        # Pair the positional metadata values with their field names.
        return {
            "id": record_id,
            "completion": completion,
            "metadata": dict(zip(metadata_fields, metadata_values)),
        }

    leaf_coloring_completion = """THINKING
This is a graph transformation task. I need to analyze the examples to understand the pattern.

Looking at the input graphs, I can see they have different structures. Let me identify what changes between input and output.

In the examples, it appears that nodes with degree 1 (leaf nodes) are being colored blue while other nodes remain grey.

ANSWER
G describes a graph among nodes 0, 1, 2, 3, 4.
The edges in G are: (0,1) (1,2) (2,3) (3,4).

Node 0 is blue.
Node 4 is blue."""

    node_count_completion = """THINKING
I need to count how many nodes will be in the output graph after applying the transformation pattern.

Looking at the examples, the transformation seems to be coloring certain nodes but not changing the graph structure.

The input graph has 5 nodes, so the output should also have 5 nodes.

ANSWER
5"""

    connectivity_completion = """THINKING
I need to determine if the input graph is connected.

A graph is connected if there is a path between every pair of nodes.

Looking at the input graph structure, I can trace paths between all nodes.

ANSWER
Yes"""

    return [
        build_record(
            "colorLeaves_star_adjacency-scale_up_3-analyst-3-full_output-output",
            leaf_coloring_completion,
            "colorLeaves",
            "star",
            "adjacency",
            "scale_up_3",
            "analyst",
            "full_output",
            "output",
            "3",
        ),
        build_record(
            "colorDegree2_random_incident-mixed_3-teacher-3-node_count-output",
            node_count_completion,
            "colorDegree2",
            "random",
            "incident",
            "mixed_3",
            "teacher",
            "node_count",
            "output",
            "3",
        ),
        build_record(
            "addHub_tree_adjacency-large_2-none-2-is_connected-input",
            connectivity_completion,
            "addHub",
            "tree",
            "adjacency",
            "large_2",
            "none",
            "is_connected",
            "input",
            "2",
        ),
    ]


def test_token_analysis():
    """Exercise ``analyze_cluster_batch_tokens`` on the sample fixtures.

    First checks the batch-level summary (model name, response counts), then
    confirms every per-response entry carries the expected token keys and
    prints each entry.

    Returns:
        The dict returned by ``analyze_cluster_batch_tokens``.
    """
    print("🧪 Testing token analysis functionality...")

    # Fixture data and the model label passed through to the analyzer
    fixtures = create_sample_cluster_results()
    model_name = "qwen3-32b"

    print("\n1. Testing batch token analysis...")
    batch_token_data = analyze_cluster_batch_tokens(fixtures, model_name, verbose=True)

    # Batch-level checks
    assert batch_token_data["model"] == model_name
    assert batch_token_data["total_responses"] == 3
    assert batch_token_data["responses_with_tokens"] > 0
    per_response = batch_token_data["per_response_tokens"]
    assert len(per_response) == 3

    print("✅ Batch token analysis passed!")

    # (printed label, required key) pairs, in the order they are checked/shown
    reported_fields = (
        ("Output tokens", "output_tokens"),
        ("Reasoning tokens", "reasoning_tokens"),
        ("Answer tokens", "answer_tokens"),
        ("Has THINKING", "has_thinking_section"),
        ("Has ANSWER", "has_answer_section"),
    )

    print("\n2. Testing individual response token data...")
    for rid in per_response:
        entry = per_response[rid]

        # All keys must be present before anything is printed for this entry
        for _, required_key in reported_fields:
            assert required_key in entry

        print(f"   Response {rid}:")
        for label, key in reported_fields:
            print(f"     {label}: {entry[key]}")

    print("✅ Individual token data passed!")

    return batch_token_data


def test_full_conversion():
    """Test the full conversion process with token analysis.

    Writes the sample cluster results to a temporary JSON file, runs
    ``convert_cluster_results_to_responses`` on it with token analysis
    enabled, and checks the returned stats. If the reported token data file
    exists, its content is loaded and checked as well; if it does not exist,
    a message is printed but the test does not fail.

    Returns:
        The stats dict returned by ``convert_cluster_results_to_responses``.

    Raises:
        AssertionError: If any of the checks on the stats or on the token
            file content fails.
    """
    print("\n🧪 Testing full conversion process...")

    # Bind both names before the try so the cleanup in ``finally`` never hits
    # an unbound local (which would mask the real error) when an early step
    # raises.
    cluster_file = None
    stats = None

    try:
        # delete=False: the file must outlive the ``with`` so the converter
        # can open it by name; it is removed explicitly in ``finally``.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            cluster_file = f.name
            json.dump(create_sample_cluster_results(), f)

        # Test conversion with token analysis
        print("\n3. Testing full conversion with token analysis...")
        stats = convert_cluster_results_to_responses(
            cluster_file,
            model_name="qwen3-32b-test",
            verbose=True,
            analyze_tokens=True,
            store_individual_tokens=False,  # Only consolidated file
        )

        # Verify conversion stats
        assert stats["total_results"] == 3
        assert stats["converted"] > 0
        assert stats["token_analysis"] is not None
        assert stats["token_data_file"] is not None

        print("✅ Full conversion passed!")

        # Check that token file was created
        if os.path.exists(stats["token_data_file"]):
            print(f"✅ Token data file created: {stats['token_data_file']}")

            # Load and verify token file content
            with open(stats["token_data_file"], "r", encoding="utf-8") as f:
                token_data = json.load(f)

            assert "batch_statistics" in token_data
            assert "per_response_tokens" in token_data
            assert len(token_data["per_response_tokens"]) == 3

            print("✅ Token file content verified!")
        else:
            print("❌ Token data file was not created")

        return stats

    finally:
        # Clean up whatever was actually created; either step may not have
        # run if an earlier one raised.
        if cluster_file and os.path.exists(cluster_file):
            os.unlink(cluster_file)
        token_file = stats.get("token_data_file") if stats else None
        if token_file and os.path.exists(token_file):
            os.unlink(token_file)


def main():
    """Run both test functions and print a short summary.

    Returns:
        True when both tests complete and the summary is printed, False when
        an ``AssertionError`` is raised along the way (the failure and its
        traceback are printed). Other exception types are not caught here.
    """
    print("🚀 Starting token analysis tests...\n")

    try:
        # Token analysis on its own, then the end-to-end conversion
        batch_data = test_token_analysis()
        conversion_stats = test_full_conversion()

        print("\n🎉 All tests passed!")
        print("\nSummary:")
        print(f"✅ Analyzed {batch_data['total_responses']} responses")

        batch_stats = batch_data["batch_statistics"]
        print(f"✅ Average output tokens: {batch_stats['avg_output_tokens']:.1f}")
        print(f"✅ Responses with reasoning: {batch_stats['responses_with_reasoning']}")

        converted = conversion_stats["converted"]
        total = conversion_stats["total_results"]
        print(f"✅ Conversion success rate: {converted}/{total}")
    except AssertionError as failure:
        print(f"\n❌ Test failed: {failure}")
        traceback.print_exc()
        return False

    return True


if __name__ == "__main__":
    # Exit status mirrors the test outcome: 0 on success, 1 on failure.
    # SystemExit is raised directly rather than calling the ``exit()`` helper,
    # which is injected by the ``site`` module for interactive sessions and is
    # not guaranteed to exist (e.g. under ``python -S``).
    success = main()
    raise SystemExit(0 if success else 1)
