#!/usr/bin/env python3
"""
Script to analyze conversation lengths in dataset_minimal.json

This script counts the number of turns in each conversation (in the "prompts" field)
and returns statistics about the distribution of conversation lengths.
Analyzes both overall distribution and agentic vs non-agentic samples.
"""

import json
import statistics
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np


def analyze_conversation_lengths(dataset_path):
    """
    Analyze conversation lengths in the dataset.

    The length of a conversation is the number of entries in the sample's
    "prompts" list. A sample counts as agentic when the part of its "id"
    before the first ':' (or the whole id when it contains no ':') is one of
    the agentic group names listed below; every other sample, including those
    with a missing or non-string id, counts as non-agentic.

    Samples without a "prompts" field, or whose "prompts" is not a list, are
    reported with a warning and skipped.

    Args:
        dataset_path (str): Path to the dataset JSON file

    Returns:
        dict: Statistics about conversation lengths with the keys 'overall',
        'agentic', 'non_agentic' (each a dict with 'count', 'min_length',
        'max_length', 'average_length', 'median_length' and
        'length_distribution', or an empty dict when the group has no
        samples) and 'role_statistics' (role -> number of turns, per group).
        An empty dict is returned when no usable conversation was found.

    Raises:
        FileNotFoundError: If dataset_path does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    print(f"Loading dataset from: {dataset_path}")

    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Total samples in dataset: {len(data)}")

    # Define agentic groups
    agentic_groups = {
        'aider_commit', 'aider_manual', 'aider_multi', 'continue_dev', 'irl_cursor', 'openhands',
        'smol_real_manager', 'smol_real_search', 'agentharm', 'aider_full', 'cybench', 'gaia_runs',
        'honeypot', 'mvp', 'smol_gaia_manager', 'smol_gaia_search', 'swe_gym', 'swebench'
    }

    # Count conversation lengths
    conversation_lengths = []
    agentic_lengths = []
    non_agentic_lengths = []
    role_stats = defaultdict(int)
    agentic_role_stats = defaultdict(int)
    non_agentic_role_stats = defaultdict(int)

    for i, sample in enumerate(data):
        if not isinstance(sample, dict) or 'prompts' not in sample:
            shown_id = sample.get('id', 'unknown') if isinstance(sample, dict) else 'unknown'
            print(f"Warning: Sample {i} (id: {shown_id}) has no 'prompts' field")
            continue

        prompts = sample['prompts']
        if not isinstance(prompts, list):
            # len() of a string or dict would be a character/key count, not a turn count.
            print(f"Warning: Sample {i} (id: {sample.get('id', 'unknown')}) has a non-list 'prompts' field")
            continue

        conversation_length = len(prompts)
        conversation_lengths.append(conversation_length)

        # Determine if sample is agentic or non-agentic. Ids that are not
        # strings (e.g. null or numeric) cannot name an agentic group.
        sample_id = sample.get('id', '')
        group_name = sample_id.partition(':')[0] if isinstance(sample_id, str) else ''
        is_agentic = group_name in agentic_groups

        if is_agentic:
            agentic_lengths.append(conversation_length)
            group_role_stats = agentic_role_stats
        else:
            non_agentic_lengths.append(conversation_length)
            group_role_stats = non_agentic_role_stats

        # Count roles in this conversation; turns that are not objects with a
        # hashable "role" value are counted in the length but carry no role.
        for prompt in prompts:
            if not isinstance(prompt, dict) or 'role' not in prompt:
                continue
            role = prompt['role']
            try:
                role_stats[role] += 1
            except TypeError:
                continue
            group_role_stats[role] += 1

    # Calculate statistics
    if not conversation_lengths:
        print("No valid conversations found!")
        return {}

    def calculate_stats(lengths):
        """Summarize a list of lengths; an empty list yields an empty dict."""
        if not lengths:
            return {}
        return {
            'count': len(lengths),
            'min_length': min(lengths),
            'max_length': max(lengths),
            'average_length': sum(lengths) / len(lengths),
            # True median: the mean of the two middle values for even counts.
            'median_length': statistics.median(lengths),
            'length_distribution': dict(Counter(lengths))
        }

    return {
        'overall': calculate_stats(conversation_lengths),
        'agentic': calculate_stats(agentic_lengths),
        'non_agentic': calculate_stats(non_agentic_lengths),
        'role_statistics': {
            'overall': dict(role_stats),
            'agentic': dict(agentic_role_stats),
            'non_agentic': dict(non_agentic_role_stats)
        }
    }


def print_statistics(stats):
    """
    Print formatted statistics.

    Prints the overall summary and length-distribution table, then the split
    between agentic and non-agentic samples, then the same summary and table
    for each of those two groups. A group whose statistics dict is empty (or
    missing) is reported as having no data.

    Args:
        stats (dict): Result of analyze_conversation_lengths(); nothing is
            printed when it is empty.
    """
    if not stats:
        return

    banner = "=" * 80

    print("\n" + banner)
    print("CONVERSATION LENGTH ANALYSIS")
    print(banner)

    def print_section_stats(section_stats, section_name):
        """Print the summary lines and distribution table for one group."""
        if not section_stats:
            print(f"\n{section_name.upper()} - No data available")
            return

        print(f"\n{section_name.upper()} Statistics:")
        print(f"  Total conversations: {section_stats['count']:,}")
        print(f"  Minimum length: {section_stats['min_length']} turns")
        print(f"  Maximum length: {section_stats['max_length']} turns")
        print(f"  Average length: {section_stats['average_length']:.2f} turns")
        print(f"  Median length: {section_stats['median_length']} turns")

        print(f"\n{section_name.capitalize()} Length Distribution:")
        print(f"{'Length (turns)':<15} {'Count':<10} {'Percentage':<12} {'Cumulative %':<15}")
        print("-" * 55)

        total = section_stats['count']
        distribution = section_stats['length_distribution']
        cumulative = 0

        # Sort by length for better readability
        for length in sorted(distribution):
            count = distribution[length]
            cumulative += count
            # Attach the '%' before padding so the sign stays next to the
            # number and each row is as wide as the header.
            percentage = f"{(count / total) * 100:.2f}%"
            cumulative_pct = f"{(cumulative / total) * 100:.2f}%"

            print(f"{length:<15} {count:<10,} {percentage:<12} {cumulative_pct:<15}")

    overall = stats.get('overall') or {}
    agentic = stats.get('agentic') or {}
    non_agentic = stats.get('non_agentic') or {}

    # Print overall statistics
    print_section_stats(overall, 'overall')

    # Print agentic vs non-agentic comparison
    print("\n" + banner)
    print("AGENTIC VS NON-AGENTIC COMPARISON")
    print(banner)

    agentic_count = agentic['count'] if agentic else 0
    non_agentic_count = non_agentic['count'] if non_agentic else 0
    total_count = agentic_count + non_agentic_count

    def share(part):
        """Percentage of all samples; 0 when there are no samples at all."""
        return (part / total_count) * 100 if total_count else 0.0

    print("\nSample Distribution:")
    print(f"  Agentic samples: {agentic_count:,} ({share(agentic_count):.2f}%)")
    print(f"  Non-agentic samples: {non_agentic_count:,} ({share(non_agentic_count):.2f}%)")
    print(f"  Total: {total_count:,}")

    # Print agentic statistics
    print_section_stats(agentic, 'agentic')

    # Print non-agentic statistics
    print_section_stats(non_agentic, 'non-agentic')


def plot_distributions(stats):
    """
    Plot the conversation length distributions as a stylish line plot.

    Draws one line for the overall distribution and one for each of the
    agentic / non-agentic groups that actually contains samples, annotates the
    peak of the non-agentic line, and shows the figure. The x-axis is capped
    at 50 turns for readability.

    Args:
        stats (dict): Result of analyze_conversation_lengths(). When it is
            empty or has no overall data, nothing is plotted.
    """
    # Nothing to draw: stats is {} when no valid conversations were found.
    if not stats or not stats.get('overall'):
        return

    # Prepare data
    overall_dist = stats['overall']['length_distribution']
    agentic_dist = stats['agentic']['length_distribution'] if stats.get('agentic') else {}
    non_agentic_dist = stats['non_agentic']['length_distribution'] if stats.get('non_agentic') else {}

    # Get all unique lengths across all distributions
    all_lengths = sorted(set(overall_dist) | set(agentic_dist) | set(non_agentic_dist))
    if not all_lengths:
        return

    # Convert to arrays for plotting
    overall_counts = [overall_dist.get(length, 0) for length in all_lengths]
    agentic_counts = [agentic_dist.get(length, 0) for length in all_lengths]
    non_agentic_counts = [non_agentic_dist.get(length, 0) for length in all_lengths]

    # Set up the style. The seaborn styles were renamed to 'seaborn-v0_8-*'
    # in newer Matplotlib releases, so fall back to the old name (and finally
    # to the default style) instead of failing on an unknown style.
    for style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    # Create figure with elegant styling
    fig, ax = plt.subplots(1, 1, figsize=(14, 8))
    fig.patch.set_facecolor('white')

    # Define colors with good contrast and style
    colors = {
        'overall': '#2E86AB',      # Deep blue
        'agentic': '#A23B72',      # Deep magenta
        'non_agentic': '#F18F01'   # Orange
    }

    def plot_series(key, label, marker, counts):
        """Draw one distribution line using the shared marker/line styling."""
        ax.plot(all_lengths, counts,
                color=colors[key],
                linewidth=3,
                marker=marker,
                markersize=4,
                markerfacecolor='white',
                markeredgecolor=colors[key],
                markeredgewidth=2,
                label=f'{label} (n={stats[key]["count"]})',
                alpha=0.9)

    # Plot lines with markers and styling. The count lists are padded to
    # all_lengths and therefore never empty, so whether a group has data must
    # be decided from its distribution dict.
    plot_series('overall', 'Overall', 'o', overall_counts)

    if agentic_dist:
        plot_series('agentic', 'Agentic', 's', agentic_counts)

    if non_agentic_dist:
        plot_series('non_agentic', 'Non-Agentic', '^', non_agentic_counts)

    # Styling the plot
    ax.set_title('Conversation Length Distributions',
                 fontsize=20,
                 fontweight='bold',
                 pad=20,
                 color='#2c3e50')

    ax.set_xlabel('Conversation Length (turns)',
                  fontsize=14,
                  fontweight='semibold',
                  color='#2c3e50')

    ax.set_ylabel('Number of Conversations',
                  fontsize=14,
                  fontweight='semibold',
                  color='#2c3e50')

    # Customize grid
    ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.8)
    ax.set_facecolor('#fafafa')

    # Set axis limits for better visualization. The x-axis is limited to 50
    # for readability, but never below the shortest conversation, which would
    # otherwise invert the axis when every conversation exceeds 50 turns.
    shortest = min(all_lengths)
    x_upper = max(min(max(all_lengths), 50), shortest)
    ax.set_xlim(shortest - 1, x_upper)
    tallest = max(max(overall_counts), max(agentic_counts), max(non_agentic_counts))
    ax.set_ylim(0, tallest * 1.1)

    # Customize legend
    legend = ax.legend(loc='upper right',
                       frameon=True,
                       fancybox=True,
                       shadow=True,
                       fontsize=12,
                       title='Distribution Types',
                       title_fontsize=13)
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_alpha(0.9)
    legend.get_title().set_fontweight('bold')

    # Customize ticks
    ax.tick_params(axis='both', which='major', labelsize=11, colors='#2c3e50')

    # Add subtle annotations for key insights (only when the non-agentic
    # group has samples; an all-zero line has no meaningful peak).
    if non_agentic_dist:
        max_non_agentic_y = max(non_agentic_counts)
        max_non_agentic_x = all_lengths[non_agentic_counts.index(max_non_agentic_y)]

        ax.annotate(f'Peak: {max_non_agentic_y} conversations\nat {max_non_agentic_x} turns',
                    xy=(max_non_agentic_x, max_non_agentic_y),
                    xytext=(max_non_agentic_x + 10, max_non_agentic_y + 50),
                    arrowprops=dict(arrowstyle='->', color=colors['non_agentic'], alpha=0.7),
                    fontsize=10,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8),
                    color='#2c3e50')

    # Adjust layout
    plt.tight_layout()

    print("\nDisplaying stylish line plot of conversation length distributions...")

    # Show the plot
    plt.show()


def main(dataset_path='datasets/dataset_minimal.json'):
    """
    Run the full analysis: load the dataset, print the statistics, show the plot.

    Args:
        dataset_path (str): Path to the dataset JSON file. Defaults to the
            location this script was written for.

    Errors are reported on standard output rather than raised, so the script
    always finishes with a readable message.
    """
    try:
        stats = analyze_conversation_lengths(dataset_path)
        if not stats:
            # The analysis already reported that no valid conversations were
            # found; there is nothing to print or plot.
            return
        print_statistics(stats)
        plot_distributions(stats)

    except FileNotFoundError:
        print(f"Error: Could not find dataset file: {dataset_path}")
    except json.JSONDecodeError as e:
        print(f"Error: Invalid JSON in dataset file: {e}")
    except Exception as e:
        # Top-level boundary of the script: report anything unexpected.
        print(f"Error: {e}")


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
