#!/usr/bin/env python3
"""
Check all prompt files for duplicate graphs between examples and test cases.
Focuses on patterns where the first example size equals the test case size.
"""

import os
import re
import json
from typing import Dict, List, Tuple
from collections import defaultdict

# Known prompt patterns, keyed by the name that appears in prompt filenames.
# Each value lists the graph sizes used by the pattern in order; per the notes
# below, the last entry is the test-case size and the earlier ones are examples.
# A pattern is considered at risk of duplicate graphs when its first size
# equals its last size. Despite the name, this table also lists the patterns
# that are NOT at risk, so that they can be reported as safe.
PATTERNS_WITH_POTENTIAL_DUPLICATES = {
    "mixed_3": [5, 10, 5],  # first 5 == test 5 -> at risk
    "cap10_3": [10, 10, 10],  # first 10 == test 10 -> at risk
    "small_4": [5, 4, 5, 5],  # first 5 == test 5 -> at risk
    "large_2": [15, 15],  # first 15 == test 15 -> at risk
    # The remaining patterns have a first size that differs from the test size:
    "cap25_3": [10, 10, 25],  # first 10 != test 25 -> safe
    "cap50_3": [10, 10, 50],  # first 10 != test 50 -> safe
    "cap100_3": [10, 10, 100],  # first 10 != test 100 -> safe
    "cap250_3": [10, 10, 250],  # first 10 != test 250 -> safe
    "scale_up_2": [5, 15],  # first 5 != test 15 -> safe
    "scale_up_3": [5, 10, 15],  # first 5 != test 15 -> safe
    "scale_up_4": [5, 10, 15, 15],  # first 5 != test 15 -> safe
    "progressive_5": [3, 5, 8, 10, 15],  # first 3 != test 15 -> safe
}


def extract_graph_sections(content: str) -> Dict[str, str]:
    """
    Pull every "Input N:" section out of a prompt.

    A section begins right after an "Input N:" label and runs up to the next
    "Output N:" or "Input N:" label, or to the end of the text.

    Returns a dict keyed by "Input N" whose values are the section bodies with
    surrounding whitespace stripped. If a label number occurs more than once,
    the later section replaces the earlier one.
    """
    # DOTALL lets a section body span multiple lines; the lazy body plus the
    # lookahead stops each section at the next label without consuming it.
    section_re = re.compile(
        r"Input (\d+):(.*?)(?=(?:Output \d+:|Input \d+:|$))", re.DOTALL
    )
    return {
        f"Input {found.group(1)}": found.group(2).strip()
        for found in section_re.finditer(content)
    }


def check_for_duplicates(filepath: str) -> Tuple[bool, List[str]]:
    """
    Look for repeated input graphs inside one prompt file.

    Sections are compared after collapsing every run of whitespace to a single
    space; sections that are empty after that are never reported. A file that
    cannot be opened or decoded as UTF-8 is treated as having no duplicates.

    Returns (has_duplicates, pairs), where each pair is a string of the form
    "Input A == Input B".
    """
    try:
        with open(filepath, "r", encoding="utf-8") as handle:
            text = handle.read()
    except (IOError, UnicodeDecodeError):
        return False, []

    sections = extract_graph_sections(text)
    if not sections:
        return False, []

    # Collapse whitespace once per section so the pairwise pass only compares.
    normalized = [
        (label, " ".join(body.split())) for label, body in sections.items()
    ]

    pairs = []
    for idx, (first_label, first_body) in enumerate(normalized):
        if not first_body:
            # An empty section never counts as a duplicate of anything.
            continue
        for second_label, second_body in normalized[idx + 1 :]:
            if first_body == second_body:
                pairs.append(f"{first_label} == {second_label}")

    return bool(pairs), pairs


def identify_pattern_from_filename(filename: str) -> str:
    """
    Find which known pattern a prompt filename belongs to.

    Returns the first key of PATTERNS_WITH_POTENTIAL_DUPLICATES (in dict
    order) that occurs as a substring of filename, or None if no key does.
    """
    matching = (
        name for name in PATTERNS_WITH_POTENTIAL_DUPLICATES if name in filename
    )
    return next(matching, None)


def scan_all_prompts(
    verbose: bool = False, root_dir: str = "datasets"
) -> Dict[str, List[str]]:
    """
    Scan prompt files under root_dir and report duplicate graphs.

    Only ".txt" files in directories whose path contains "prompts" are
    checked. Progress, totals and a per-pattern analysis are printed to
    stdout.

    Args:
        verbose: When True, print every file that has duplicates together
            with its duplicate pairs.
        root_dir: Directory tree to walk. Defaults to "datasets", which is
            resolved against the current working directory.

    Returns:
        Dict mapping a pattern name (or "unknown" when the filename matches
        no known pattern) to the relative paths of the files with duplicates.
    """
    results = defaultdict(list)
    files_checked = 0
    duplicates_found = 0

    # Per-pattern counters; only files whose name matches a known pattern
    # are counted here, while files_checked counts every scanned file.
    pattern_stats = defaultdict(lambda: {"total": 0, "duplicates": 0})

    print("Scanning for duplicate graphs in prompts...")
    print("-" * 60)

    for root, dirs, files in os.walk(root_dir):
        # Sorting dirs in place makes os.walk descend in a stable order, so
        # the report does not depend on filesystem listing order.
        dirs.sort()
        if "prompts" not in root:
            continue

        for filename in sorted(files):
            if not filename.endswith(".txt"):
                continue

            filepath = os.path.join(root, filename)
            pattern = identify_pattern_from_filename(filename)

            if pattern:
                pattern_stats[pattern]["total"] += 1

            files_checked += 1
            has_dups, dup_list = check_for_duplicates(filepath)
            if not has_dups:
                continue

            duplicates_found += 1
            relative_path = os.path.relpath(filepath)
            results[pattern or "unknown"].append(relative_path)

            if pattern:
                pattern_stats[pattern]["duplicates"] += 1

            if verbose:
                print(f"❌ DUPLICATE in {relative_path}")
                for dup in dup_list:
                    print(f"   - {dup}")

    print(f"\nScanned {files_checked} prompt files")
    print(f"Found {duplicates_found} files with duplicates")
    print("\n" + "=" * 60)
    print("PATTERN ANALYSIS:")
    print("=" * 60)

    for pattern, sizes in PATTERNS_WITH_POTENTIAL_DUPLICATES.items():
        stats = pattern_stats[pattern]
        total = stats["total"]
        dups = stats["duplicates"]

        # Patterns with no files on disk are left out of the report.
        if total == 0:
            continue

        # A pattern is at risk when its first size equals its last size.
        should_have_issue = sizes[0] == sizes[-1]
        percentage = dups / total * 100
        status = "⚠️ AT RISK" if should_have_issue else "✓ SAFE"

        print(f"\n{pattern}: {sizes}")
        print(f"  Status: {status}")
        print(f"  Files: {total}")
        print(f"  Duplicates: {dups} ({percentage:.1f}%)")

        if dups > 0 and not should_have_issue:
            print("  🔍 UNEXPECTED: This pattern shouldn't have duplicates!")
        elif dups == 0 and should_have_issue:
            print("  ✅ GOOD: Fix appears to be working!")

    return results


def _collect_prompt_files(root_dir: str = "datasets") -> List[str]:
    """Return paths of all ".txt" files in directories whose path contains "prompts"."""
    found = []
    for root, _dirs, files in os.walk(root_dir):
        if "prompts" not in root:
            continue
        found.extend(
            os.path.join(root, name) for name in files if name.endswith(".txt")
        )
    return found


def _run_fix_check() -> None:
    """Sample files of each at-risk pattern and print whether duplicates remain."""
    import random

    print("🔍 Checking if duplicate fix is working...")
    print("-" * 60)

    # Check patterns that should have had issues
    risky_patterns = ["mixed_3", "cap10_3", "small_4", "large_2"]

    # Walk the tree once and filter per pattern, rather than once per pattern.
    prompt_files = _collect_prompt_files()

    all_good = True
    patterns_checked = 0

    for pattern in risky_patterns:
        pattern_files = [
            path for path in prompt_files if pattern in os.path.basename(path)
        ]

        if not pattern_files:
            print(f"⚠️ No files found for pattern: {pattern}")
            continue

        patterns_checked += 1

        # Check a random sample of at most 10 files per pattern.
        sample_size = min(10, len(pattern_files))
        samples = random.sample(pattern_files, sample_size)
        dups_found = sum(1 for path in samples if check_for_duplicates(path)[0])

        if dups_found > 0:
            print(
                f"❌ {pattern}: Found {dups_found}/{sample_size} files with duplicates"
            )
            all_good = False
        else:
            print(f"✅ {pattern}: No duplicates in {sample_size} sampled files")

    if patterns_checked == 0:
        # Nothing was sampled, so there is no evidence that the fix works;
        # do not report success.
        print("\n⚠️ No prompt files found for any risky pattern; nothing was checked.")
    elif all_good:
        print("\n🎉 Fix appears to be working! No duplicates found in risky patterns.")
    else:
        print("\n⚠️ Fix may not be fully applied. Some duplicates still exist.")


def main():
    """
    Parse command-line options and run the duplicate check.

    With --fix-check, a random sample of files for each at-risk pattern is
    checked and a verdict is printed. Otherwise all prompts are scanned and a
    per-pattern summary is printed; --verbose adds per-file detail and
    --pattern restricts that summary to one pattern (the scan itself still
    covers every file, and --pattern has no effect with --fix-check).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Check prompt files for duplicate graphs"
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Show detailed output for each file",
    )
    parser.add_argument("--pattern", help="Check only specific pattern (e.g., mixed_3)")
    parser.add_argument(
        "--fix-check",
        action="store_true",
        help="Specifically check if the fix is working",
    )

    args = parser.parse_args()

    if args.fix_check:
        _run_fix_check()
        return

    results = scan_all_prompts(verbose=args.verbose)

    # Print summary
    print("\n" + "=" * 60)
    print("SUMMARY BY PATTERN:")
    print("=" * 60)

    for pattern, files in sorted(results.items()):
        if args.pattern and pattern != args.pattern:
            continue

        print(f"\n{pattern}: {len(files)} files with duplicates")
        if args.verbose and files:
            for f in files[:5]:  # Show first 5
                print(f"  - {f}")
            if len(files) > 5:
                print(f"  ... and {len(files) - 5} more")


# Run the duplicate check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
