#!/usr/bin/env python3
import json
import re
import matplotlib.pyplot as plt
from collections import Counter
import seaborn as sns

def extract_answer_types(generated_questions):
    """Extract all answer_type values from the generated_questions field.

    Finds every ``<answer_type>...</answer_type>`` tag in the text. The tag
    body may span several lines. Each match is stripped of surrounding
    whitespace. Empty values and the sentinel ``"No leakage found"`` are
    dropped.

    Args:
        generated_questions: Raw generated text. ``None`` or any other
            non-string value (e.g. a JSON ``null`` field) is treated as
            containing no tags.

    Returns:
        list[str]: The answer types, in the order they appear.
    """
    # A record may carry ``"generated_questions": null``; re.findall would
    # raise TypeError on None, aborting the whole analysis.
    if not isinstance(generated_questions, str):
        return []

    pattern = r'<answer_type>(.*?)</answer_type>'
    # DOTALL lets '.' match newlines so multi-line tag bodies are captured;
    # the non-greedy '.*?' stops at the first closing tag.
    cleaned = (match.strip() for match in re.findall(pattern, generated_questions, re.DOTALL))
    return [answer_type for answer_type in cleaned
            if answer_type and answer_type != "No leakage found"]

def analyze_answer_types(path='qgen/debug/combined_dw_30_free1.jsonl'):
    """Read a JSONL file and collect every answer type it contains.

    Each non-blank line is parsed as a JSON object. Its
    ``generated_questions`` field is passed to ``extract_answer_types``.
    Blank lines are skipped silently. Lines that are not valid JSON, or not
    a JSON object, are reported on stdout and skipped.

    Args:
        path: JSONL file to read. Defaults to the combined debug file this
            script was written for.

    Returns:
        list[str]: All extracted answer types, in file order.

    Raises:
        OSError: If *path* cannot be opened (e.g. FileNotFoundError).
    """
    answer_types = []

    with open(path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            stripped = line.strip()
            if not stripped:
                # Blank / trailing lines are normal in JSONL, not errors.
                continue

            try:
                data = json.loads(stripped)
            except json.JSONDecodeError as e:
                print(f"Error parsing line {line_num}: {e}")
                continue

            # A valid JSON line might be a list/scalar; .get would crash on it.
            if not isinstance(data, dict):
                print(f"Error parsing line {line_num}: expected a JSON object, "
                      f"got {type(data).__name__}")
                continue

            generated_questions = data.get('generated_questions', '')
            # An explicit JSON null (or other non-string) holds no tags.
            if not isinstance(generated_questions, str):
                generated_questions = ''

            answer_types.extend(extract_answer_types(generated_questions))

    return answer_types

def group_low_frequency_types(type_counts, threshold=1):
    """Group answer types with count <= threshold into 'Others' category.

    Args:
        type_counts: Mapping of answer type -> count (e.g. a ``Counter``).
        threshold: Types whose count is at most this value are folded into
            the ``'Others'`` bucket.

    Returns:
        tuple[dict, list[str]]: ``(grouped_counts, others_items)``.
        ``grouped_counts`` maps each kept type to its count, plus an
        ``'Others'`` entry when anything was folded. ``others_items`` lists
        the folded types as ``"<type> (<count>)"`` strings, in input order.
    """
    grouped_counts = {}
    others_count = 0
    others_items = []

    for answer_type, count in type_counts.items():
        if count <= threshold:
            others_count += count
            others_items.append(f"{answer_type} ({count})")
        else:
            grouped_counts[answer_type] = count

    if others_count > 0:
        # Add rather than assign. A real answer type literally named 'Others'
        # may already be present, and overwriting it would drop its count
        # from every total.
        grouped_counts['Others'] = grouped_counts.get('Others', 0) + others_count

    return grouped_counts, others_items

def create_pie_chart(answer_types, output_path='answer_types_distribution.png'):
    """Create and display a beautiful pie chart of answer type distributions.

    Counts *answer_types* and folds types that occur only once into an
    ``'Others'`` slice. It then draws the pie with a legend, saves it to
    *output_path* as a 300 dpi image, shows it, and prints a text summary.

    Args:
        answer_types: Iterable of answer-type strings. Callers are expected
            to pass a non-empty sequence (``main`` checks this).
        output_path: File the chart is saved to.
    """
    type_counts = Counter(answer_types)
    grouped_counts, others_items = group_low_frequency_types(type_counts, threshold=1)

    _draw_pie_chart(grouped_counts, others_items, output_path)
    _print_summary(type_counts, grouped_counts, others_items)


def _draw_pie_chart(grouped_counts, others_items, output_path):
    """Render, save, show and then close the pie chart for grouped counts."""
    labels = list(grouped_counts.keys())
    sizes = list(grouped_counts.values())

    plt.style.use('default')
    sns.set_palette("husl")
    # One distinct husl colour per slice.
    colors = sns.color_palette("husl", len(labels))

    fig, ax = plt.subplots(figsize=(16, 12))

    wedges, texts, autotexts = ax.pie(
        sizes,
        labels=labels,
        autopct='%1.1f%%',
        colors=colors,
        startangle=90,
        # Pull the aggregated 'Others' slice out slightly so it stands apart.
        explode=[0.05 if label == 'Others' else 0 for label in labels],
        shadow=True,
        textprops={'fontsize': 18, 'fontweight': 'bold'},
    )

    # Percentage labels are drawn over the coloured wedges; white reads best.
    for autotext in autotexts:
        autotext.set_color('white')
        autotext.set_fontweight('bold')
        autotext.set_fontsize(20)

    for text in texts:
        text.set_fontsize(20)
        text.set_fontweight('bold')

    ax.set_title('Distribution of Answer Types in Generated Questions',
                 fontsize=28, fontweight='bold', pad=30, color='darkblue')

    ax.legend(wedges, _build_legend_labels(grouped_counts, others_items),
              title="Answer Types (Count)",
              title_fontsize=18,
              fontsize=14,
              loc="center left",
              bbox_to_anchor=(1, 0, 0.5, 1),
              frameon=True,
              fancybox=True,
              shadow=True)

    fig.tight_layout()
    # bbox_inches='tight' keeps the legend, which sits to the right of the
    # axes, inside the saved image.
    fig.savefig(output_path, dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    plt.show()
    # Free the figure. With non-interactive backends show() does not block,
    # and without this the figure would stay open for the process lifetime.
    plt.close(fig)


def _build_legend_labels(grouped_counts, others_items, max_listed=5):
    """Return one ``'<type> (<count>)'`` legend entry per slice.

    The ``'Others'`` entry also lists up to *max_listed* of the folded types.
    """
    legend_labels = []
    for label, count in grouped_counts.items():
        entry = f'{label} ({count})'
        if label == 'Others' and others_items:
            others_detail = ', '.join(others_items[:max_listed])
            if len(others_items) > max_listed:
                others_detail += f', and {len(others_items) - max_listed} more'
            entry += f'\n{others_detail}'
        legend_labels.append(entry)
    return legend_labels


def _print_summary(type_counts, grouped_counts, others_items):
    """Print the text summary that accompanies the chart."""
    # NOTE: this counts answer_type tags; it equals the number of questions
    # only if each question carries exactly one tag — confirm with the data.
    total = sum(grouped_counts.values())

    print(f"\n{'='*60}")
    print(f"{'ANSWER TYPE ANALYSIS SUMMARY':^60}")
    print(f"{'='*60}")
    print(f"Total questions analyzed: {total}")
    print(f"Number of unique answer types: {len(type_counts)}")
    print(f"Number of categories shown: {len(grouped_counts)}")
    print("\nDetailed breakdown:")
    print('-' * 60)

    for answer_type, count in Counter(grouped_counts).most_common():
        percentage = (count / total) * 100
        print(f"{answer_type:<30}: {count:>3} ({percentage:>5.1f}%)")

    if others_items:
        print("\nItems in 'Others' category:")
        print('-' * 40)
        for item in others_items:
            print(f"  • {item}")

def main():
    """Script entry point: gather answer types, then chart and summarize them.

    Prints a message and stops early when the input yields no answer types.
    """
    print("🔍 Analyzing answer types from combined JSONL file...")

    found = analyze_answer_types()

    if found:
        print(f"✅ Found {len(found)} total answer types")
        create_pie_chart(found)
        print(f"\n📊 Chart saved as 'answer_types_distribution.png'")
    else:
        print("❌ No answer types found in the file.")

# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()