#!/usr/bin/env python3
"""
Fast analysis script for QWEN-3-32B response patterns.
Analyzes odd behaviors like repetition, missing sections, etc.
"""

import json
import os
import glob
from collections import Counter
from typing import Dict, List, Any
import argparse


def detect_repetition(text: str) -> Dict[str, Any]:
    """
    Detect repetitive patterns in text.

    Two signals are combined:

    * ``max_consecutive_lines`` - the longest run of identical, non-blank,
      consecutive lines (after stripping), counted as *extra* repeats: two
      identical lines in a row give 1. A blank line breaks a run.
    * ``max_chunk_repeats`` - how often the most frequent non-blank chunk
      occurs anywhere in the text, where chunks are obtained by splitting on
      two consecutive newlines (i.e. paragraphs).

    ``repetition_score = max_consecutive_lines + 5 * (max_chunk_repeats - 1)``,
    with the chunk term floored at 0 so empty or whitespace-only text scores 0
    instead of going negative. ``is_highly_repetitive`` is ``score > 10``.

    Args:
        text: The response text to inspect.

    Returns:
        Dict with ``max_consecutive_lines``, ``max_chunk_repeats``,
        ``repetition_score`` and ``is_highly_repetitive``.
    """
    # Longest run of identical consecutive (non-blank) lines.
    consecutive_repeats = 0
    max_consecutive = 0
    prev_line = ""
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if line and line == prev_line:
            consecutive_repeats += 1
            max_consecutive = max(max_consecutive, consecutive_repeats)
        else:
            consecutive_repeats = 0
        prev_line = line

    # Repeated paragraph-sized chunks (like "THINKING\n\nThe transformation..." blocks).
    chunk_counts = Counter(
        chunk for chunk in (part.strip() for part in text.split("\n\n")) if chunk
    )
    max_chunk_repeats = max(chunk_counts.values(), default=0)

    # With no chunks at all, "max_chunk_repeats - 1" would be -1; floor it so
    # empty/whitespace-only text does not get a negative score.
    repetition_score = max_consecutive + max(max_chunk_repeats - 1, 0) * 5

    return {
        "max_consecutive_lines": max_consecutive,
        "max_chunk_repeats": max_chunk_repeats,
        "repetition_score": repetition_score,
        "is_highly_repetitive": repetition_score > 10,
    }


def check_format_compliance(text: str) -> Dict[str, Any]:
    """
    Report how well a response follows the THINKING/ANSWER layout.

    Matching is case-insensitive. ``thinking_count`` and ``answer_count``
    count raw substring hits anywhere in the text (so e.g. "ANSWERS" also
    counts). ``proper_structure`` instead requires stand-alone header lines:
    the last line that is exactly "ANSWER" (after stripping) must come after
    the last line that is exactly "THINKING".

    Args:
        text: The response text to inspect.

    Returns:
        Dict with ``has_thinking``, ``has_answer``, ``thinking_count``,
        ``answer_count``, ``proper_structure`` and ``multiple_sections``.
    """
    upper_text = text.upper()
    n_thinking = upper_text.count("THINKING")
    n_answer = upper_text.count("ANSWER")

    # Index of the last stand-alone header line of each kind (None if absent).
    header_pos: Dict[str, Any] = {"THINKING": None, "ANSWER": None}
    for pos, raw_line in enumerate(text.split("\n")):
        normalized = raw_line.strip().upper()
        if normalized in header_pos:
            header_pos[normalized] = pos

    think_pos = header_pos["THINKING"]
    answer_pos = header_pos["ANSWER"]
    structured = (
        think_pos is not None and answer_pos is not None and answer_pos > think_pos
    )

    return {
        "has_thinking": n_thinking > 0,
        "has_answer": n_answer > 0,
        "thinking_count": n_thinking,
        "answer_count": n_answer,
        "proper_structure": structured,
        "multiple_sections": n_thinking > 1 or n_answer > 1,
    }


def analyze_response_quality(text: str) -> Dict[str, Any]:
    """
    Quick analysis of response quality indicators.

    Returns basic size metrics plus heuristic flags:

    * ``ends_mid_sentence`` - the stripped text is non-empty and its last
      character is not ``.``, ``!`` or ``?``.
    * ``ends_abruptly`` - the stripped text ends with "..." or "THINK"
      (presumably a cut-off "THINKING" header).
    * ``is_too_short`` / ``is_too_long`` - fewer than 5 / more than 2000
      whitespace-separated words.

    All flags are real booleans, including for empty or whitespace-only text.

    Args:
        text: The response text to inspect.

    Returns:
        Dict with ``char_count``, ``word_count``, ``line_count`` (non-blank
        lines only) and the boolean flags above.
    """
    stripped = text.strip()

    # Basic metrics
    char_count = len(text)
    word_count = len(text.split())
    line_count = sum(1 for line in text.split("\n") if line.strip())

    # Truncation heuristics look at the stripped text so trailing whitespace
    # or newlines do not hide how the response actually ends. bool() keeps
    # the flag a real boolean when the text is empty.
    ends_mid_sentence = bool(stripped) and stripped[-1] not in ".!?"
    ends_abruptly = stripped.endswith(("...", "THINK"))

    # Check for empty or very short responses
    is_too_short = word_count < 5
    is_too_long = word_count > 2000  # Suspiciously long

    return {
        "char_count": char_count,
        "word_count": word_count,
        "line_count": line_count,
        "ends_mid_sentence": ends_mid_sentence,
        "ends_abruptly": ends_abruptly,
        "is_too_short": is_too_short,
        "is_too_long": is_too_long,
    }


def _preview(text: str, limit: int = 200, from_end: bool = False) -> str:
    """Return an excerpt of at most *limit* characters of *text* for example listings.

    Head excerpts (the default) get "..." appended when the text was cut;
    tail excerpts (``from_end=True``) are returned without a marker.
    """
    if len(text) <= limit:
        return text
    if from_end:
        return text[-limit:]
    return text[:limit] + "..."


def analyze_single_file(file_path: str) -> Dict[str, Any]:
    """
    Analyze a single JSON results file.

    The file must contain a JSON list. Each element that is a dict with a
    non-empty string ``"completion"`` is run through ``detect_repetition``,
    ``check_format_compliance`` and ``analyze_response_quality``, and the
    outcomes are tallied. Other elements are counted in ``total_responses``
    but skipped.

    Args:
        file_path: Path to a UTF-8 encoded JSON file.

    Returns:
        A dict with ``file_path``, ``total_responses``, ``responses_analyzed``,
        per-category counters under ``patterns`` and up to three example
        previews per category under ``examples``. If the file cannot be read,
        decoded or parsed, or does not hold a list, returns
        ``{"error": <message>, "file_path": file_path}`` instead.
    """
    print(f"Analyzing {os.path.basename(file_path)}...")

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError *and* UnicodeDecodeError
        # (non-UTF-8 bytes), so one bad file no longer aborts the whole run.
        return {"error": f"Failed to load {file_path}: {e}", "file_path": file_path}

    if not isinstance(data, list):
        return {
            "error": f"Expected list in {file_path}, got {type(data)}",
            "file_path": file_path,
        }

    max_examples = 3
    results = {
        "file_path": file_path,
        "total_responses": len(data),
        "responses_analyzed": 0,
        "patterns": {
            "repetition": {
                "highly_repetitive": 0,
                "some_repetition": 0,
                "max_repetition_score": 0,
            },
            "format": {
                "missing_thinking": 0,
                "missing_answer": 0,
                "proper_structure": 0,
                "multiple_sections": 0,
                "no_format": 0,
            },
            "quality": {"too_short": 0, "too_long": 0, "truncated": 0, "normal": 0},
        },
        "examples": {"repetitive": [], "malformed": [], "truncated": []},
    }
    rep_stats = results["patterns"]["repetition"]
    fmt_stats = results["patterns"]["format"]
    qual_stats = results["patterns"]["quality"]
    examples = results["examples"]

    for i, item in enumerate(data):
        if not isinstance(item, dict):
            continue
        completion = item.get("completion")
        # Skip missing/empty completions and non-string payloads (null,
        # numbers, nested objects) that the text analyzers cannot handle.
        if not isinstance(completion, str) or not completion:
            continue

        results["responses_analyzed"] += 1
        item_id = item.get("id", f"item_{i}")

        # Analyze patterns
        repetition = detect_repetition(completion)
        format_check = check_format_compliance(completion)
        quality = analyze_response_quality(completion)

        # Repetition: "highly repetitive" and "some repetition" are exclusive.
        if repetition["is_highly_repetitive"]:
            rep_stats["highly_repetitive"] += 1
            if len(examples["repetitive"]) < max_examples:
                examples["repetitive"].append(
                    {
                        "id": item_id,
                        "repetition_score": repetition["repetition_score"],
                        "preview": _preview(completion),
                    }
                )
        elif repetition["repetition_score"] > 3:
            rep_stats["some_repetition"] += 1

        rep_stats["max_repetition_score"] = max(
            rep_stats["max_repetition_score"], repetition["repetition_score"]
        )

        # Format: these flags are independent and may overlap.
        if not format_check["has_thinking"]:
            fmt_stats["missing_thinking"] += 1
        if not format_check["has_answer"]:
            fmt_stats["missing_answer"] += 1
        if format_check["proper_structure"]:
            fmt_stats["proper_structure"] += 1
        if format_check["multiple_sections"]:
            fmt_stats["multiple_sections"] += 1
        if not format_check["has_thinking"] and not format_check["has_answer"]:
            fmt_stats["no_format"] += 1
            if len(examples["malformed"]) < max_examples:
                examples["malformed"].append(
                    {
                        "id": item_id,
                        "issue": "no_format",
                        "preview": _preview(completion),
                    }
                )

        # Quality: exactly one bucket per response, checked in priority order.
        if quality["is_too_short"]:
            qual_stats["too_short"] += 1
        elif quality["is_too_long"]:
            qual_stats["too_long"] += 1
        elif quality["ends_mid_sentence"] or quality["ends_abruptly"]:
            qual_stats["truncated"] += 1
            if len(examples["truncated"]) < max_examples:
                examples["truncated"].append(
                    {
                        "id": item_id,
                        "word_count": quality["word_count"],
                        # Show the tail: that is where truncation is visible.
                        "preview": _preview(completion, from_end=True),
                    }
                )
        else:
            qual_stats["normal"] += 1

    return results


def _pct(count: int, total: int) -> float:
    """Return *count* as a percentage of *total*, or 0.0 when *total* is 0."""
    return 100 * count / total if total else 0.0


def _print_stat(label: str, count: int, total: int) -> None:
    """Print one indented ``label: count (pct%)`` summary line."""
    print(f"  {label}: {count:,} ({_pct(count, total):.1f}%)")


def _sum_pattern(results: List[Dict[str, Any]], category: str, key: str) -> int:
    """Sum ``r["patterns"][category][key]`` over all per-file results."""
    return sum(r["patterns"][category][key] for r in results)


def print_summary(all_results: List[Dict[str, Any]]):
    """
    Print a summary of all analyzed files.

    Error entries (dicts with an ``"error"`` key) are listed first. The
    remaining per-file results are summed into repetition, format and quality
    sections, followed by example previews and tuning recommendations.
    Percentages are relative to the total number of analyzed responses and
    are reported as 0.0% when that total is zero (e.g. every file was empty).

    Args:
        all_results: Per-file dicts as produced by ``analyze_single_file``.
    """
    # Filter out error results
    valid_results = [r for r in all_results if "error" not in r]
    error_results = [r for r in all_results if "error" in r]

    if error_results:
        print(f"\n❌ Failed to analyze {len(error_results)} files:")
        for err in error_results:
            print(f"  - {err['error']}")

    if not valid_results:
        print("\n❌ No valid results to analyze!")
        return

    # Aggregate stats
    total_responses = sum(r["responses_analyzed"] for r in valid_results)
    total_files = len(valid_results)

    print("\n📊 QWEN-3-32B Response Analysis Summary")
    print("=" * 50)
    print(f"Files analyzed: {total_files}")
    print(f"Total responses: {total_responses:,}")

    # Repetition analysis
    total_highly_repetitive = _sum_pattern(
        valid_results, "repetition", "highly_repetitive"
    )
    total_some_repetition = _sum_pattern(valid_results, "repetition", "some_repetition")
    max_repetition_score = max(
        r["patterns"]["repetition"]["max_repetition_score"] for r in valid_results
    )

    print("\n🔄 Repetition Issues:")
    _print_stat("Highly repetitive", total_highly_repetitive, total_responses)
    _print_stat("Some repetition", total_some_repetition, total_responses)
    print(f"  Max repetition score: {max_repetition_score}")

    # Format analysis
    total_missing_thinking = _sum_pattern(valid_results, "format", "missing_thinking")
    total_missing_answer = _sum_pattern(valid_results, "format", "missing_answer")
    total_proper_structure = _sum_pattern(valid_results, "format", "proper_structure")
    total_multiple_sections = _sum_pattern(
        valid_results, "format", "multiple_sections"
    )
    total_no_format = _sum_pattern(valid_results, "format", "no_format")

    print("\n📝 Format Issues:")
    _print_stat("Missing THINKING", total_missing_thinking, total_responses)
    _print_stat("Missing ANSWER", total_missing_answer, total_responses)
    _print_stat("Proper structure", total_proper_structure, total_responses)
    _print_stat("Multiple sections", total_multiple_sections, total_responses)
    _print_stat("No format at all", total_no_format, total_responses)

    # Quality analysis
    total_too_short = _sum_pattern(valid_results, "quality", "too_short")
    total_too_long = _sum_pattern(valid_results, "quality", "too_long")
    total_truncated = _sum_pattern(valid_results, "quality", "truncated")
    total_normal = _sum_pattern(valid_results, "quality", "normal")

    print("\n📏 Quality Issues:")
    _print_stat("Too short", total_too_short, total_responses)
    _print_stat("Too long", total_too_long, total_responses)
    _print_stat("Truncated", total_truncated, total_responses)
    _print_stat("Normal", total_normal, total_responses)

    # Show examples of worst issues: the first stored example of each file.
    if total_highly_repetitive > 0:
        print("\n🔍 Examples of highly repetitive responses:")
        for result in valid_results:
            for example in result["examples"]["repetitive"][:1]:
                print(f"  - {example['id']} (score: {example['repetition_score']})")
                print(f"    Preview: {example['preview'][:100]}...")

    print("\n💡 Recommendations:")
    if total_highly_repetitive > total_responses * 0.05:  # >5% repetitive
        print("  - HIGH: Add repetition_penalty=1.1-1.2 to vLLM config")
    # >10% of responses missing either the THINKING or the ANSWER section
    if max(total_missing_thinking, total_missing_answer) > total_responses * 0.1:
        print("  - MEDIUM: Improve prompt formatting or add stop sequences")
    if total_truncated > total_responses * 0.05:  # >5% truncated
        print("  - MEDIUM: Increase MAX_LENGTH or check memory pressure")


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Analyze QWEN-3-32B response patterns")
    parser.add_argument(
        "--pattern",
        default="llm-inference/results/qwen-3-32B-v2_results*_batch*.json",
        # %(default)s keeps the help text in sync with the real default value.
        help="File pattern to analyze (default: %(default)s)",
    )
    parser.add_argument("--output", help="Save detailed results to JSON file")
    parser.add_argument(
        "--verbose", action="store_true", help="Show per-file statistics"
    )
    return parser


def main(argv=None):
    """
    Command-line entry point: analyze all matching result files and print a summary.

    Files ending in ``_tokens.json`` are excluded. Files are processed in
    sorted order, and the combined per-file results are optionally written
    to ``--output`` as JSON.

    Args:
        argv: Optional list of command-line arguments to parse instead of
            ``sys.argv[1:]``; ``None`` keeps argparse's default behavior.
    """
    args = build_arg_parser().parse_args(argv)

    # Find all matching files and exclude _tokens files
    all_files = glob.glob(args.pattern)
    files = [f for f in all_files if not f.endswith("_tokens.json")]

    if not files:
        print(f"❌ No files found matching pattern: {args.pattern}")
        if all_files:
            print(f"   (Found {len(all_files)} files but all were _tokens.json files)")
        return

    print(f"🔍 Found {len(files)} files to analyze...")

    # Analyze each file
    all_results = []
    for file_path in sorted(files):
        result = analyze_single_file(file_path)
        all_results.append(result)

        if args.verbose and "error" not in result:
            print(
                f"  {os.path.basename(file_path)}: {result['responses_analyzed']} responses"
            )

    # Print summary
    print_summary(all_results)

    # Save detailed results if requested
    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(all_results, f, indent=2)
        print(f"\n💾 Detailed results saved to: {args.output}")


if __name__ == "__main__":
    # main() returns None, so SystemExit(None) exits with status 0 just like
    # falling off the end of the script.
    raise SystemExit(main())
