#!/usr/bin/env python3
"""
Convert cluster-generated results (batch JSON) into individual response files with token analysis.

Enhanced version that:
1. Analyzes tokens for each response during reformatting
2. Creates a consolidated token data file per cluster results batch
3. Optionally stores token data with individual response files
4. Maintains compatibility with existing evaluation pipeline
"""

import argparse
import json
import os
import traceback
from typing import Dict, Optional, List, Any
from datetime import datetime

# Import the existing token analysis utilities
from scripts.utils.token_utils import extract_reasoning_tokens


def analyze_cluster_batch_tokens(
    cluster_results: List[Dict[str, Any]], model_name: str, verbose: bool = False
) -> Dict[str, Any]:
    """
    Analyze tokens for an entire cluster results batch.

    Each usable entry's "completion" text is passed to
    extract_reasoning_tokens, and the returned counts are stored per response
    and summed into batch-level statistics. Entries that are not JSON objects,
    or that have an empty/missing completion, are skipped rather than aborting
    the batch.

    Parameters:
    - cluster_results: List of result dictionaries from cluster
    - model_name: Name of the model used
    - verbose: Whether to print progress

    Returns:
    - Dict with consolidated token analysis:
      - "per_response_tokens" maps response id -> token usage dict
      - "batch_statistics" holds totals and averages; averages are taken over
        the responses that reported a non-zero output token count
      - input token fields stay at 0/None because prompts are not part of the
        cluster results
    """
    if verbose:
        print(f"🧠 Analyzing tokens for {len(cluster_results)} responses...")

    batch_token_data = {
        "model": model_name,
        "analysis_timestamp": datetime.now().isoformat(),
        "total_responses": len(cluster_results),
        "responses_with_tokens": 0,
        "batch_statistics": {
            "total_input_tokens": 0,
            "total_output_tokens": 0,
            "total_reasoning_tokens": 0,
            "total_answer_tokens": 0,
            "avg_input_tokens": 0,
            "avg_output_tokens": 0,
            "avg_reasoning_tokens": 0,
            "responses_with_reasoning": 0,
        },
        "per_response_tokens": {},
    }

    total_output = 0
    total_reasoning = 0
    total_answer = 0
    responses_with_reasoning = 0
    valid_responses = 0

    for i, result in enumerate(cluster_results):
        # A malformed entry (bare string, null, number, ...) has no .get();
        # skip it instead of letting an AttributeError abort the whole batch.
        if not isinstance(result, dict):
            if verbose:
                print(f"⚠️ Skipping response {i}: entry is not an object")
            continue

        try:
            response_id = result.get("id", f"response_{i}")
            completion = result.get("completion", "")

            if not completion:
                continue

            # We don't have the original prompt in cluster results, so only
            # the response content is analyzed.
            reasoning_breakdown = extract_reasoning_tokens(completion)

            # Create simplified token usage data
            token_usage = {
                "input_tokens": None,  # Not available for cluster results
                "output_tokens": reasoning_breakdown["total_response_tokens"],
                "total_tokens": reasoning_breakdown[
                    "total_response_tokens"
                ],  # Just output since we don't have input
                "reasoning_tokens": reasoning_breakdown["reasoning_tokens"],
                "answer_tokens": reasoning_breakdown["answer_tokens"],
                "estimation_method": "estimated_from_response",
                "has_thinking_section": reasoning_breakdown["has_thinking_section"],
                "has_answer_section": reasoning_breakdown["has_answer_section"],
                "model": model_name,
            }

            # Calculate reasoning ratio if we have both parts
            if token_usage["reasoning_tokens"] and token_usage["answer_tokens"]:
                total_response = (
                    token_usage["reasoning_tokens"] + token_usage["answer_tokens"]
                )
                token_usage["reasoning_ratio"] = (
                    token_usage["reasoning_tokens"] / total_response
                )

            # Store per-response data
            batch_token_data["per_response_tokens"][response_id] = token_usage

            # "metadata" may be absent or explicitly null; treat both as empty.
            metadata = result.get("metadata")
            if not isinstance(metadata, dict):
                metadata = {}

            # Also register the usage under the suffixed id form so lookups by
            # either id variant succeed for full_output/output questions.
            if (
                metadata.get("question_type") == "full_output"
                and metadata.get("target") == "output"
            ):
                # str() keeps non-string ids (e.g. integers) from raising here.
                if not str(response_id).endswith("-full_output-output"):
                    suffixed_id = f"{response_id}-full_output-output"
                    batch_token_data["per_response_tokens"][suffixed_id] = token_usage

            # Update batch statistics
            if token_usage["output_tokens"]:
                total_output += token_usage["output_tokens"]
                valid_responses += 1

            if token_usage["reasoning_tokens"] and token_usage["reasoning_tokens"] > 0:
                total_reasoning += token_usage["reasoning_tokens"]
                responses_with_reasoning += 1

            if token_usage["answer_tokens"] and token_usage["answer_tokens"] > 0:
                total_answer += token_usage["answer_tokens"]

        except (AttributeError, KeyError, TypeError, ValueError) as e:
            # One unanalyzable response must not take down the batch.
            if verbose:
                print(f"⚠️ Error analyzing tokens for response {i}: {e}")
            continue

    # Calculate batch statistics
    if valid_responses > 0:
        batch_token_data["responses_with_tokens"] = valid_responses
        batch_token_data["batch_statistics"].update(
            {
                "total_output_tokens": total_output,
                "total_reasoning_tokens": total_reasoning,
                "total_answer_tokens": total_answer,
                "avg_output_tokens": total_output / valid_responses,
                "avg_reasoning_tokens": (
                    total_reasoning / valid_responses if total_reasoning > 0 else 0
                ),
                "responses_with_reasoning": responses_with_reasoning,
            }
        )

    if verbose:
        print("✅ Token analysis complete:")
        print(f"   - Analyzed {valid_responses}/{len(cluster_results)} responses")
        print(
            f"   - Average output tokens: {batch_token_data['batch_statistics']['avg_output_tokens']:.1f}"
        )
        print(f"   - Responses with reasoning: {responses_with_reasoning}")

    return batch_token_data


def parse_prompt_metadata(prompt_id: str) -> Optional[Dict[str, str]]:
    """
    Split a batch-result prompt ID into its metadata fields.

    The ID is read as ``<benchmark>_<graph_type>_<rest>``, where ``<rest>`` is
    hyphen-separated in one of two layouts:

    - 4-5 components: ``encoding-<size pattern>-system_prompt-n_pairs``
      (question type and target default to "full_output" / "output")
    - 6+ components:
      ``encoding-<size pattern>-system_prompt-n_pairs-question_type-target``

    The size pattern absorbs every component between the encoding and the
    trailing fixed fields, so it may itself contain hyphens.

    Parameters:
    - prompt_id: ID string taken from a batch result entry

    Returns:
    - Dict of parsed fields (plus "original_id"), or None when the ID has
      fewer than three underscore-separated parts or fewer than four
      hyphen-separated components; a warning is printed in both cases.
    """
    try:
        id_fields = prompt_id.split("_")
        if len(id_fields) < 3:
            print(f"⚠️ Unexpected prompt ID format: {prompt_id}")
            return None

        benchmark, graph_type = id_fields[0], id_fields[1]

        # Anything after the second underscore belongs to the filename-style
        # part, which may legitimately contain further underscores.
        tokens = "_".join(id_fields[2:]).split("-")
        if len(tokens) < 4:
            print(f"⚠️ Insufficient components in prompt ID: {prompt_id}")
            return None

        if len(tokens) >= 6:
            # Layout: encoding-pattern-system-pairs-question-target
            system_prompt, n_pairs, question_type, target = tokens[-4:]
            trailing = 4
        else:
            # Layout: encoding-pattern-system-pairs
            system_prompt, n_pairs = tokens[-2:]
            question_type, target = "full_output", "output"
            trailing = 2

        # Everything between the encoding and the trailing fixed fields.
        size_pattern = "-".join(tokens[1:-trailing])

        return {
            "benchmark": benchmark,
            "graph_type": graph_type,
            "encoding": tokens[0],
            "size_pattern": size_pattern,
            "system_prompt": system_prompt,
            "question_type": question_type,
            "target": target,
            "n_pairs": n_pairs,
            "original_id": prompt_id,
        }

    except (IndexError, ValueError) as e:
        print(f"⚠️ Error parsing prompt ID {prompt_id}: {e}")
        return None


def extract_model_name_from_file(filepath: str) -> str:
    """
    Derive a model name from a cluster results file path.

    The directory and the final extension are dropped, everything from the
    first "_results" marker onward is discarded, and the remainder is
    lowercased with underscores turned into hyphens
    (e.g. "Qwen_7B_results.json" -> "qwen-7b").

    Parameters:
    - filepath: Path to the cluster results file

    Returns:
    - Normalized model name string
    """
    stem = os.path.splitext(os.path.basename(filepath))[0]
    # partition() returns the whole stem as the head when the marker is absent.
    model_part = stem.partition("_results")[0]
    return model_part.replace("_", "-").lower()


def generate_response_filename(metadata: Dict[str, str], model_name: str) -> str:
    """
    Build the response filename for one result.

    The name is the hyphen-joined sequence
    encoding, size_pattern, system_prompt, question_type, target, model_name
    followed by ".txt". Question type and target are always included so the
    metadata encoded in the filename is unambiguous.

    Parameters:
    - metadata: Mapping that must contain the five metadata keys listed above
    - model_name: Model name appended as the last filename component

    Returns:
    - Filename string (no directory part)

    Raises:
    - KeyError if a required metadata key is missing
    """
    field_order = (
        "encoding",
        "size_pattern",
        "system_prompt",
        "question_type",
        "target",
    )
    # f-string formatting keeps non-string values usable, as join() needs str.
    pieces = [f"{metadata[key]}" for key in field_order]
    pieces.append(f"{model_name}")
    return "-".join(pieces) + ".txt"


def convert_cluster_results_to_responses(
    cluster_results_file: str,
    model_name: Optional[str] = None,
    overwrite: bool = False,
    verbose: bool = False,
    store_individual_tokens: bool = False,
    analyze_tokens: bool = True,
) -> Dict[str, Any]:
    """
    Convert cluster batch results to individual response files with optional token analysis.

    Each valid result is written to
    ``datasets/<benchmark>/<graph_type>/responses/<generated filename>``
    (relative to the current working directory). When token analysis is
    enabled, a consolidated ``<results file stem>_tokens.json`` is written next
    to the input file. Token analysis is best effort: if it fails, conversion
    still proceeds without token data.

    Parameters:
    - cluster_results_file: Path to cluster results JSON
    - model_name: Model name override (derived from the filename when None)
    - overwrite: Whether to overwrite existing files
    - verbose: Print progress
    - store_individual_tokens: Whether to create individual _tokens.json files
    - analyze_tokens: Whether to perform token analysis

    Returns:
    - Dictionary with conversion statistics including token data

    Raises:
    - FileNotFoundError if cluster_results_file does not exist
    - json.JSONDecodeError if the results file is not valid JSON
    """
    if not os.path.exists(cluster_results_file):
        raise FileNotFoundError(
            f"Cluster results file not found: {cluster_results_file}"
        )

    # Determine model name
    if model_name is None:
        model_name = extract_model_name_from_file(cluster_results_file)
        if verbose:
            print(f"🤖 Detected model name: {model_name}")

    # Load cluster results
    with open(cluster_results_file, "r", encoding="utf-8") as f:
        cluster_results = json.load(f)

    if verbose:
        print(f"📊 Loaded {len(cluster_results)} results from {cluster_results_file}")

    stats = {
        "total_results": len(cluster_results),
        "converted": 0,
        "skipped_existing": 0,
        "skipped_invalid": 0,
        "failed": 0,
        "by_benchmark": {},
        "by_question_type": {},
        "token_analysis": {},
        "token_data_file": None,
    }

    # Perform token analysis for the entire batch if requested
    batch_token_data = {"per_response_tokens": {}}
    if analyze_tokens:
        try:
            batch_token_data = analyze_cluster_batch_tokens(
                cluster_results, model_name, verbose
            )
            stats["token_analysis"] = batch_token_data["batch_statistics"]

            # Swap only the final extension. str.replace(".json", ...) would
            # rewrite every ".json" in the path and, for an input without that
            # extension, yield the input path itself and overwrite it.
            token_root, _ = os.path.splitext(cluster_results_file)
            token_file_path = f"{token_root}_tokens.json"
            with open(token_file_path, "w", encoding="utf-8") as f:
                json.dump(batch_token_data, f, indent=2)

            stats["token_data_file"] = token_file_path

            if verbose:
                print(f"💾 Saved consolidated token data to {token_file_path}")

        except (
            AttributeError,
            OSError,
            IOError,
            json.JSONDecodeError,
            KeyError,
            TypeError,
            ValueError,
        ) as e:
            # Best effort: malformed entries or an unwritable token file must
            # not prevent the response files from being produced.
            if verbose:
                print(f"⚠️ Token analysis failed: {e}")
            batch_token_data = {"per_response_tokens": {}}

    # Fields every result's metadata must carry before it can be converted
    required_fields = [
        "benchmark",
        "graph_type",
        "encoding",
        "size_pattern",
        "system_prompt",
        "question_type",
        "target",
        "n_pairs",
    ]

    # Process individual responses
    for i, result in enumerate(cluster_results):
        # Counter dicts for this result, reset on every iteration so the error
        # handler never touches unbound names or a previous result's counters.
        bench_stats = None
        qtype_stats = None
        try:
            # Expect each result to be a dict with id, completion, and metadata
            if not isinstance(result, dict):
                print(f"⚠️ Result {i+1}: Invalid entry (not an object)")
                stats["skipped_invalid"] += 1
                continue

            prompt_id = result.get("id")
            completion = result.get("completion")
            metadata = result.get("metadata")

            if not prompt_id or not completion or not isinstance(metadata, dict):
                print(f"⚠️ Result {i+1}: Missing id, completion, or metadata")
                stats["skipped_invalid"] += 1
                continue

            # Normalize key 'pattern' → 'size_pattern' if needed
            metadata = metadata.copy()
            if "pattern" in metadata and "size_pattern" not in metadata:
                metadata["size_pattern"] = metadata.pop("pattern")

            # Ensure minimal required metadata fields are present
            missing = [k for k in required_fields if k not in metadata]
            if missing:
                print(f"⚠️ Result {i+1}: Incomplete metadata in JSON; missing {missing}")
                stats["skipped_invalid"] += 1
                continue

            # Track by benchmark and question type. setdefault raises TypeError
            # for unhashable values, which the handler below counts as failed.
            benchmark = metadata["benchmark"]
            question_type = metadata["question_type"]

            bench_stats = stats["by_benchmark"].setdefault(
                benchmark, {"converted": 0, "failed": 0}
            )
            qtype_stats = stats["by_question_type"].setdefault(
                question_type, {"converted": 0, "failed": 0}
            )

            # Generate response filename and path
            response_filename = generate_response_filename(metadata, model_name)
            response_dir = f"datasets/{benchmark}/{metadata['graph_type']}/responses"
            response_path = os.path.join(response_dir, response_filename)

            # Check if response already exists
            if os.path.exists(response_path) and not overwrite:
                stats["skipped_existing"] += 1
                if verbose:
                    print(f"   ⏩ Exists: {response_path}")
                continue

            # Create response directory
            os.makedirs(response_dir, exist_ok=True)

            # Write response file
            with open(response_path, "w", encoding="utf-8") as f:
                f.write(completion)

            # Optionally save individual token file
            if store_individual_tokens and analyze_tokens:
                token_usage = batch_token_data["per_response_tokens"].get(prompt_id)
                if token_usage:
                    # Replace only the trailing ".txt", not every occurrence.
                    token_output_path = (
                        os.path.splitext(response_path)[0] + "_tokens.json"
                    )
                    with open(token_output_path, "w", encoding="utf-8") as f:
                        json.dump(token_usage, f, indent=2)

            stats["converted"] += 1
            bench_stats["converted"] += 1
            qtype_stats["converted"] += 1

            if verbose:
                print(f"   ✅ Converted: {response_path}")

        except (
            OSError,
            IOError,
            json.JSONDecodeError,
            KeyError,
            TypeError,
            ValueError,
        ) as e:
            print(f"❌ Error processing result {i+1}: {e}")
            stats["failed"] += 1
            # Only attribute the failure to categories already registered for
            # this result.
            if bench_stats is not None:
                bench_stats["failed"] += 1
            if qtype_stats is not None:
                qtype_stats["failed"] += 1

    return stats


def _print_conversion_summary(stats):
    """Print the human-readable summary for a finished conversion run."""
    rule = "=" * 60

    print("\n" + rule)
    print("CLUSTER RESULTS CONVERSION SUMMARY")
    print(rule)
    print(f"Total results: {stats['total_results']:,}")
    print(f"Successfully converted: {stats['converted']:,}")
    print(f"Skipped (already exist): {stats['skipped_existing']:,}")
    print(f"Skipped (invalid): {stats['skipped_invalid']:,}")
    print(f"Failed: {stats['failed']:,}")

    token_stats = stats["token_analysis"]
    if token_stats:
        # "Yes" only when at least one output token was counted.
        analyzed = "Yes" if token_stats["total_output_tokens"] else "No"
        print("\nToken Analysis Summary:")
        print(f"Responses analyzed: {analyzed}")
        print(f"Average output tokens: {token_stats['avg_output_tokens']:.1f}")
        print(f"Responses with reasoning: {token_stats['responses_with_reasoning']}")
        if stats["token_data_file"]:
            print(f"Consolidated token data: {stats['token_data_file']}")

    # Per-category breakdowns share one layout; empty sections are omitted.
    for heading, section in (
        ("\nBy benchmark:", stats["by_benchmark"]),
        ("\nBy question type:", stats["by_question_type"]),
    ):
        if not section:
            continue
        print(heading)
        for label, counts in section.items():
            print(
                f"  {label}: {counts['converted']} converted, {counts['failed']} failed"
            )

    print(rule)


def main():
    """
    Command-line entry point.

    Parses the arguments, runs the conversion and prints a summary. Errors of
    the handled exception types are reported on stdout (with a traceback when
    --verbose is set) instead of propagating.
    """
    parser = argparse.ArgumentParser(
        description="Convert a batch JSON of cluster results into individual response files with token analysis"
    )
    parser.add_argument(
        "cluster_results_file",
        help="Path to the cluster results JSON file (e.g., qwen_results.json)",
    )
    parser.add_argument(
        "--model-name", help="Override model name (default: auto-detect from filename)"
    )

    # All remaining options are plain on/off switches.
    switches = (
        ("--overwrite", "Overwrite existing response files"),
        ("--verbose", "Show detailed progress"),
        (
            "--store-individual-tokens",
            "Create individual _tokens.json files (default: only consolidated file)",
        ),
        ("--no-token-analysis", "Skip token analysis entirely"),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    args = parser.parse_args()
    token_analysis_enabled = not args.no_token_analysis

    if args.verbose:
        print("🔄 Converting cluster results to individual response files...")
        if token_analysis_enabled:
            print("🧠 Token analysis enabled...")

    try:
        stats = convert_cluster_results_to_responses(
            args.cluster_results_file,
            model_name=args.model_name,
            overwrite=args.overwrite,
            verbose=args.verbose,
            store_individual_tokens=args.store_individual_tokens,
            analyze_tokens=token_analysis_enabled,
        )
        _print_conversion_summary(stats)

    except (
        OSError,
        IOError,
        json.JSONDecodeError,
        KeyError,
        TypeError,
        ValueError,
    ) as e:
        print(f"❌ Error during conversion: {e}")
        if args.verbose:
            traceback.print_exc()

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
