"""
Publication-Quality Figure Generation for Educational Chatbot Research
Creates all necessary visualizations for the research paper
"""

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import json
import os
from datetime import datetime
import pandas as pd

# Set publication-quality style.
# These statements run once at import time, before any figure function is
# called. The base style is selected first and the rcParams overrides further
# down are applied afterwards.
# NOTE(review): presumably plt.style.use() resets rcParams, so the overrides
# must stay after it — confirm before reordering.
plt.style.use('default')  # More reliable than seaborn styles
sns.set_palette("husl")

# Global figure settings: resolution, font sizes for each text element, and
# the default save options used by the savefig calls in this file.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 12,
    'figure.titlesize': 18,
    'savefig.bbox': 'tight',
    'savefig.format': 'pdf'
})

def load_results():
    """Load the experimental results from the latest results directory.

    Scans the current working directory for sub-directories whose names start
    with ``results_``, picks the one that sorts last by name and parses the
    ``comprehensive_results.json`` file inside it.

    Returns:
        tuple: ``(results, latest_dir)`` where ``results`` is the parsed JSON
        content and ``latest_dir`` is the name of the chosen directory
        (relative to the current working directory).

    Raises:
        FileNotFoundError: If no ``results_*`` directory exists, or if the
            chosen directory does not contain ``comprehensive_results.json``.
    """
    # Only real directories qualify: a stray file such as ``results_notes.txt``
    # could otherwise sort last and turn the JSON path into nonsense.
    result_dirs = sorted(
        entry for entry in os.listdir('.')
        if entry.startswith('results_') and os.path.isdir(entry)
    )
    if not result_dirs:
        raise FileNotFoundError("No results directory found!")

    # NOTE(review): "latest" means last in lexicographic order, which is only
    # chronological if the directory suffix sorts that way (e.g. a timestamp).
    latest_dir = result_dirs[-1]
    results_file = os.path.join(latest_dir, 'comprehensive_results.json')

    with open(results_file, 'r', encoding='utf-8') as f:
        results = json.load(f)

    return results, latest_dir

def create_training_curves(results, output_dir):
    """Plot the loss curves and the validation-accuracy curve side by side.

    Reads ``results['main_results']['training_curves']`` (keys ``epochs``,
    ``train_loss``, ``val_loss`` and ``val_accuracy``) and saves the figure as
    ``training_curves.pdf`` and ``training_curves.png`` inside ``output_dir``.
    """
    print("Creating training curves...")

    curves = results['main_results']['training_curves']
    epoch_axis = curves['epochs']
    loss_train = curves['train_loss']
    loss_val = curves['val_loss']
    acc_val = curves['val_accuracy']

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: both loss series on a shared axis.
    loss_series = (
        (loss_train, 'b-', 'Training Loss', 'o'),
        (loss_val, 'r-', 'Validation Loss', 's'),
    )
    for series, line_style, legend_text, marker_shape in loss_series:
        loss_ax.plot(epoch_axis, series, line_style, linewidth=2,
                     label=legend_text, marker=marker_shape, markersize=4)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)
    # Leave 10% headroom above the largest loss value of either series.
    highest_loss = max(map(max, (loss_train, loss_val)))
    loss_ax.set_ylim(0, highest_loss * 1.1)

    # Right panel: validation accuracy, converted from fraction to percent.
    acc_percent = [acc * 100 for acc in acc_val]
    acc_ax.plot(epoch_axis, acc_percent, 'g-', linewidth=2,
                label='Validation Accuracy', marker='^', markersize=4)
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.set_title('Validation Accuracy')
    acc_ax.legend()
    acc_ax.grid(True, alpha=0.3)
    acc_ax.set_ylim(0, 100)

    plt.tight_layout()
    for filename in ('training_curves.pdf', 'training_curves.png'):
        plt.savefig(os.path.join(output_dir, filename))
    plt.close()

def create_performance_comparison(results, output_dir):
    """Draw accuracy and satisfaction bar charts for all compared methods.

    Uses ``results['baselines']`` and ``results['main_results']``. The two
    non-AI baselines get a fixed accuracy of 0 and are labelled ``N/A``.
    Saves ``performance_comparison.pdf`` and ``performance_comparison.png``
    inside ``output_dir``.
    """
    print("Creating performance comparison...")

    baseline_metrics = results['baselines']
    ours = results['main_results']

    method_labels = ['Offline\nTextbooks', 'Kolibri\nVanilla', 'ChatGPT\nEducational',
                     'DistilBERT\nBaseline', 'Our Method']
    bar_colors = ['#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#1f77b4']

    # Textbooks and vanilla Kolibri have no AI component, hence the two zeros.
    accuracy_pct = [0, 0]
    for baseline_key in ('chatgpt_educational', 'distilbert_baseline'):
        accuracy_pct.append(baseline_metrics[baseline_key]['educational_accuracy'] * 100)
    accuracy_pct.append(ours['educational_accuracy'] * 100)

    satisfaction_scores = [
        baseline_metrics[baseline_key]['user_satisfaction']
        for baseline_key in ('offline_textbooks', 'kolibri_vanilla',
                             'chatgpt_educational', 'distilbert_baseline')
    ]
    satisfaction_scores.append(ours['user_satisfaction'])

    fig, (acc_ax, sat_ax) = plt.subplots(1, 2, figsize=(14, 6))

    # Left panel: educational accuracy.
    acc_bars = acc_ax.bar(method_labels, accuracy_pct, color=bar_colors)
    acc_ax.set_ylabel('Educational Accuracy (%)')
    acc_ax.set_title('Educational Query Accuracy Comparison')
    acc_ax.set_ylim(0, 100)

    for rect, value in zip(acc_bars, accuracy_pct):
        centre_x = rect.get_x() + rect.get_width()/2
        if value > 0:
            label_y, caption = rect.get_height() + 1, f'{value:.1f}%'
        else:
            # Zero-height bars get a fixed-position placeholder instead.
            label_y, caption = 5, 'N/A'
        acc_ax.text(centre_x, label_y, caption,
                    ha='center', va='bottom', fontweight='bold')

    # Right panel: user satisfaction.
    sat_bars = sat_ax.bar(method_labels, satisfaction_scores, color=bar_colors)
    sat_ax.set_ylabel('User Satisfaction (1-7 scale)')
    sat_ax.set_title('User Satisfaction Comparison')
    sat_ax.set_ylim(0, 7)

    for rect, score in zip(sat_bars, satisfaction_scores):
        sat_ax.text(rect.get_x() + rect.get_width()/2, rect.get_height() + 0.1,
                    f'{score:.1f}', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    for filename in ('performance_comparison.pdf', 'performance_comparison.png'):
        plt.savefig(os.path.join(output_dir, filename))
    plt.close()

def _ablation_label(component):
    """Convert an ablation key such as ``no_some_part`` to the label ``Some Part``."""
    label = component.replace('_', ' ')
    # Strip only a leading "no " marker; removing every occurrence would
    # corrupt names that merely contain it (e.g. "techno module").
    if label.startswith('no '):
        label = label[len('no '):]
    return label.title()


def create_ablation_study(results, output_dir):
    """Create the ablation study visualization.

    Reads ``results['ablation_study']``, a mapping from component key to a
    dict holding ``final_accuracy`` plus either ``accuracy_drop`` or
    ``accuracy_gain`` (all fractions, converted to percent here). The left
    panel shows the signed accuracy change per component, the right panel the
    final accuracy. Saves ``ablation_study.pdf`` and ``ablation_study.png``
    inside ``output_dir``.

    Raises:
        KeyError: If an entry has neither ``accuracy_drop`` nor
            ``accuracy_gain``, or lacks ``final_accuracy``.
    """
    print("Creating ablation study visualization...")

    ablations = results['ablation_study']

    components = []
    accuracy_changes = []
    final_accuracies = []

    for component, data in ablations.items():
        components.append(_ablation_label(component))

        # A drop is plotted as a negative change, a gain as a positive one.
        if 'accuracy_drop' in data:
            change = -data['accuracy_drop'] * 100
        else:
            change = data['accuracy_gain'] * 100
        accuracy_changes.append(change)
        final_accuracies.append(data['final_accuracy'] * 100)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Accuracy changes
    colors = ['red' if change < 0 else 'green' for change in accuracy_changes]
    bars1 = ax1.barh(components, accuracy_changes, color=colors, alpha=0.7)
    ax1.set_xlabel('Accuracy Change (%)')
    ax1.set_title('Component Contribution (Ablation Study)')
    ax1.axvline(x=0, color='black', linestyle='-', alpha=0.5)
    ax1.grid(True, alpha=0.3)

    # Add value labels just beyond the end of each bar. A zero change is
    # treated as non-negative here, matching its green bar colour.
    for bar, change in zip(bars1, accuracy_changes):
        non_negative = change >= 0
        ax1.text(change + (0.5 if non_negative else -0.5),
                 bar.get_y() + bar.get_height()/2,
                 f'{change:+.1f}%', ha='left' if non_negative else 'right',
                 va='center')

    # Final accuracies
    bars2 = ax2.barh(components, final_accuracies, color='skyblue', alpha=0.7)
    ax2.set_xlabel('Final Accuracy (%)')
    ax2.set_title('Final Performance with Component Removed/Modified')
    # Keep the usual 70-100 window, but widen it when a result falls below 70
    # so that no bar is cut off at the left edge.
    lowest = min(final_accuracies, default=70)
    lower_bound = 70 if lowest >= 70 else max(0, lowest - 5)
    ax2.set_xlim(lower_bound, 100)
    ax2.grid(True, alpha=0.3)

    # Add value labels
    for bar, acc in zip(bars2, final_accuracies):
        ax2.text(acc + 0.5, bar.get_y() + bar.get_height()/2,
                 f'{acc:.1f}%', ha='left', va='center')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'ablation_study.pdf'))
    plt.savefig(os.path.join(output_dir, 'ablation_study.png'))
    plt.close()

def create_deployment_analysis(results, output_dir):
    """Draw a 2x2 overview of the deployment results per environment.

    Reads ``results['domain_specific']['deployment_results']``, a mapping from
    environment name to a dict with ``institutions``, ``users``,
    ``satisfaction`` and ``technical_success_rate``. Saves
    ``deployment_analysis.pdf`` and ``deployment_analysis.png`` inside
    ``output_dir``.
    """
    print("Creating deployment analysis...")

    per_env = results['domain_specific']['deployment_results']
    env_names = list(per_env.keys())

    def column(field):
        # One value per environment, in the same order as ``env_names``.
        return [per_env[name][field] for name in env_names]

    def label_bars(ax, rects, values, lift, render):
        # Write the rendered value centred above each bar.
        for rect, value in zip(rects, values):
            ax.text(rect.get_x() + rect.get_width()/2, rect.get_height() + lift,
                    render(value), ha='center', va='bottom')

    institution_counts = column('institutions')
    user_counts = column('users')
    satisfaction_scores = column('satisfaction')

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    (inst_ax, users_ax), (sat_ax, success_ax) = axes

    # Top-left: share of institutions per environment.
    inst_ax.pie(institution_counts, labels=env_names, autopct='%1.1f%%', startangle=90)
    inst_ax.set_title(f'Institutions Deployed\n(Total: {sum(institution_counts)})')

    # Top-right: user counts.
    user_bars = users_ax.bar(env_names, user_counts, color='lightcoral')
    users_ax.set_ylabel('Number of Users')
    users_ax.set_title(f'Users per Environment\n(Total: {sum(user_counts):,})')
    users_ax.tick_params(axis='x', rotation=45)
    label_bars(users_ax, user_bars, user_counts, 50, lambda count: f'{count:,}')

    # Bottom-left: satisfaction scores.
    sat_bars = sat_ax.bar(env_names, satisfaction_scores, color='lightgreen')
    sat_ax.set_ylabel('User Satisfaction (1-7 scale)')
    sat_ax.set_title('Satisfaction by Environment')
    sat_ax.tick_params(axis='x', rotation=45)
    sat_ax.set_ylim(0, 7)
    label_bars(sat_ax, sat_bars, satisfaction_scores, 0.1, lambda score: f'{score:.1f}')

    # Bottom-right: technical success rate, converted from fraction to percent.
    success_pct = [fraction * 100 for fraction in column('technical_success_rate')]
    success_bars = success_ax.bar(env_names, success_pct, color='gold')
    success_ax.set_ylabel('Technical Success Rate (%)')
    success_ax.set_title('Deployment Success by Environment')
    success_ax.tick_params(axis='x', rotation=45)
    success_ax.set_ylim(0, 100)
    label_bars(success_ax, success_bars, success_pct, 1, lambda pct: f'{pct:.0f}%')

    plt.tight_layout()
    for filename in ('deployment_analysis.pdf', 'deployment_analysis.png'):
        plt.savefig(os.path.join(output_dir, filename))
    plt.close()

def create_subject_performance(results, output_dir):
    """Draw a 2x2 overview of the per-subject results.

    Reads ``results['domain_specific']['subject_performance']``, a mapping
    from subject name to a dict with ``accuracy``, ``query_count``,
    ``avg_response_time`` and ``user_satisfaction``. Saves
    ``subject_performance.pdf`` and ``subject_performance.png`` inside
    ``output_dir``.
    """
    print("Creating subject performance analysis...")

    per_subject = results['domain_specific']['subject_performance']
    subject_names = list(per_subject.keys())

    def metric(field):
        # One value per subject, in the same order as ``subject_names``.
        return [per_subject[name][field] for name in subject_names]

    def label_bars(ax, rects, values, lift, render, **text_style):
        # Write the rendered value centred above each bar.
        for rect, value in zip(rects, values):
            ax.text(rect.get_x() + rect.get_width()/2, rect.get_height() + lift,
                    render(value), ha='center', va='bottom', **text_style)

    accuracy_pct = [fraction * 100 for fraction in metric('accuracy')]
    query_totals = metric('query_count')
    latency_ms = metric('avg_response_time')
    satisfaction = metric('user_satisfaction')

    primary_colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    (acc_ax, share_ax), (time_ax, sat_ax) = axes

    # Top-left: accuracy per subject.
    acc_bars = acc_ax.bar(subject_names, accuracy_pct, color=primary_colors)
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.set_title('Educational Accuracy by Subject')
    acc_ax.set_ylim(80, 100)
    label_bars(acc_ax, acc_bars, accuracy_pct, 0.5,
               lambda pct: f'{pct:.1f}%', fontweight='bold')

    # Top-right: share of queries per subject.
    share_ax.pie(query_totals, labels=subject_names, autopct='%1.1f%%', startangle=90,
                 colors=primary_colors)
    share_ax.set_title(f'Query Distribution by Subject\n(Total: {sum(query_totals):,})')

    # Bottom-left: response times against the 500 ms reference line.
    time_bars = time_ax.bar(subject_names, latency_ms,
                            color=['#d62728', '#9467bd', '#8c564b'])
    time_ax.set_ylabel('Average Response Time (ms)')
    time_ax.set_title('Response Time by Subject')
    time_ax.axhline(y=500, color='red', linestyle='--', alpha=0.7, label='Target (500ms)')
    time_ax.legend()
    label_bars(time_ax, time_bars, latency_ms, 10, lambda ms: f'{ms:.0f}ms')

    # Bottom-right: satisfaction per subject.
    sat_bars = sat_ax.bar(subject_names, satisfaction,
                          color=['#e377c2', '#7f7f7f', '#bcbd22'])
    sat_ax.set_ylabel('User Satisfaction (1-7 scale)')
    sat_ax.set_title('User Satisfaction by Subject')
    sat_ax.set_ylim(0, 7)
    label_bars(sat_ax, sat_bars, satisfaction, 0.1, lambda score: f'{score:.1f}')

    plt.tight_layout()
    for filename in ('subject_performance.pdf', 'subject_performance.png'):
        plt.savefig(os.path.join(output_dir, filename))
    plt.close()

def create_grade_performance(results, output_dir):
    """Create the grade-level performance analysis.

    Reads ``results['domain_specific']['grade_performance']``, a mapping from
    grade (a key convertible with ``int``) to a dict with ``accuracy`` (a
    fraction, converted to percent here) and ``complexity_score``. The left
    panel plots accuracy against grade with a linear trend line, the right
    panel plots accuracy against complexity, coloured by grade. Saves
    ``grade_performance.pdf`` and ``grade_performance.png`` inside
    ``output_dir``.

    Raises:
        ValueError: If a grade key cannot be converted to ``int``.
    """
    print("Creating grade-level performance analysis...")

    grade_perf = results['domain_specific']['grade_performance']

    # Sort numerically: mapping order is arbitrary (and string order would put
    # "10" before "2"), which would make the line plot zig-zag. Pairing each
    # grade with its entry also avoids looking keys up again via str(int(key)),
    # which fails for keys such as "03".
    entries = sorted(
        ((int(grade), stats) for grade, stats in grade_perf.items()),
        key=lambda entry: entry[0],
    )
    grades = [grade for grade, _ in entries]
    accuracies = [stats['accuracy'] * 100 for _, stats in entries]
    complexity_scores = [stats['complexity_score'] for _, stats in entries]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Grade-level accuracy
    ax1.plot(grades, accuracies, 'bo-', linewidth=2, markersize=8)
    ax1.set_xlabel('Grade Level')
    ax1.set_ylabel('Educational Accuracy (%)')
    ax1.set_title('Performance vs Grade Level')
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(85, 100)
    ax1.set_xticks(grades)

    # Add trend line; a linear fit needs at least two distinct grades.
    if len(set(grades)) >= 2:
        z = np.polyfit(grades, accuracies, 1)
        p = np.poly1d(z)
        ax1.plot(grades, p(grades), 'r--', alpha=0.8,
                 label=f'Trend (slope: {z[0]:.2f}%/grade)')
        ax1.legend()

    # Complexity vs accuracy scatter
    points = ax2.scatter(complexity_scores, accuracies, s=100, c=grades,
                         cmap='viridis', alpha=0.7)
    ax2.set_xlabel('Content Complexity Score')
    ax2.set_ylabel('Educational Accuracy (%)')
    ax2.set_title('Accuracy vs Content Complexity')
    ax2.grid(True, alpha=0.3)

    # Add grade labels
    for grade, complex_score, acc in zip(grades, complexity_scores, accuracies):
        ax2.annotate(f'Grade {grade}', (complex_score, acc),
                     xytext=(5, 5), textcoords='offset points', fontsize=9)

    # Build the colorbar from the scatter itself so that its scale is the
    # grade range used to colour the points; a fresh ScalarMappable without a
    # norm would display 0-1 under the 'Grade Level' label.
    cbar = plt.colorbar(points, ax=ax2)
    cbar.set_label('Grade Level')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'grade_performance.pdf'))
    plt.savefig(os.path.join(output_dir, 'grade_performance.png'))
    plt.close()

def create_resource_efficiency(results, output_dir):
    """Draw the memory-usage bars and the response-time/accuracy scatter.

    Uses ``results['baselines']`` and ``results['main_results']``. Methods
    whose response time or accuracy is not positive are left out of the
    scatter panel. Saves ``resource_efficiency.pdf`` and
    ``resource_efficiency.png`` inside ``output_dir``.
    """
    print("Creating resource efficiency visualization...")

    baseline_metrics = results['baselines']
    ours = results['main_results']

    method_labels = ['Offline\nTextbooks', 'Kolibri\nVanilla', 'DistilBERT\nBaseline', 'Our Method']

    memory_mb = [
        baseline_metrics[baseline_key]['memory_usage_mb']
        for baseline_key in ('offline_textbooks', 'kolibri_vanilla', 'distilbert_baseline')
    ]
    memory_mb.append(ours['memory_usage_mb'])

    # Textbooks are entered with a response time of 0.
    latency_ms = [0]
    for baseline_key in ('kolibri_vanilla', 'distilbert_baseline'):
        latency_ms.append(baseline_metrics[baseline_key]['response_time_ms'])
    latency_ms.append(ours['response_time_ms'])

    # The two non-AI baselines are entered with an accuracy of 0.
    accuracy_pct = [
        0,
        0,
        baseline_metrics['distilbert_baseline']['educational_accuracy'] * 100,
        ours['educational_accuracy'] * 100,
    ]

    fig, (mem_ax, tradeoff_ax) = plt.subplots(1, 2, figsize=(14, 6))

    # Left panel: memory usage against the 4GB reference line.
    mem_bars = mem_ax.bar(method_labels, memory_mb,
                          color=['#ff7f0e', '#2ca02c', '#9467bd', '#1f77b4'])
    mem_ax.set_ylabel('Memory Usage (MB)')
    mem_ax.set_title('Memory Usage Comparison')
    mem_ax.axhline(y=4096, color='red', linestyle='--', alpha=0.7, label='4GB Target')
    mem_ax.legend()

    for rect, megabytes in zip(mem_bars, memory_mb):
        mem_ax.text(rect.get_x() + rect.get_width()/2, rect.get_height() + 100,
                    f'{megabytes:.0f}MB', ha='center', va='bottom')

    # Right panel: one point per method that has both metrics.
    point_colors = ['orange', 'green', 'purple', 'blue']
    for label, latency, accuracy, shade in zip(method_labels, latency_ms,
                                               accuracy_pct, point_colors):
        if not (latency > 0 and accuracy > 0):
            continue
        tradeoff_ax.scatter(latency, accuracy, s=200, c=shade, alpha=0.7, label=label)
        tradeoff_ax.annotate(label, (latency, accuracy), xytext=(10, 10),
                             textcoords='offset points', fontsize=10)

    tradeoff_ax.set_xlabel('Response Time (ms)')
    tradeoff_ax.set_ylabel('Educational Accuracy (%)')
    tradeoff_ax.set_title('Efficiency Trade-off: Response Time vs Accuracy')
    tradeoff_ax.grid(True, alpha=0.3)
    tradeoff_ax.axvline(x=500, color='red', linestyle='--', alpha=0.7, label='500ms Target')
    tradeoff_ax.axhline(y=90, color='red', linestyle='--', alpha=0.7, label='90% Target')
    tradeoff_ax.legend()

    plt.tight_layout()
    for filename in ('resource_efficiency.pdf', 'resource_efficiency.png'):
        plt.savefig(os.path.join(output_dir, filename))
    plt.close()

def generate_all_visualizations():
    """Build every figure and return the directory they were written to.

    Loads the results via ``load_results()``, creates a ``figures``
    sub-directory inside the results directory if needed, runs each
    ``create_*`` function in turn and prints a summary of the generated files.

    Returns:
        str: Path of the ``figures`` directory.
    """
    print("Starting visualization generation...")

    results, results_dir = load_results()

    # All figures go into <results_dir>/figures.
    figures_dir = os.path.join(results_dir, 'figures')
    os.makedirs(figures_dir, exist_ok=True)

    figure_builders = (
        create_training_curves,
        create_performance_comparison,
        create_ablation_study,
        create_deployment_analysis,
        create_subject_performance,
        create_grade_performance,
        create_resource_efficiency,
    )
    for build_figure in figure_builders:
        build_figure(results, figures_dir)

    print("\nAll visualizations completed!")
    print(f"Figures saved to: {figures_dir}")
    print("Generated files:")

    # Static summary of the files the builders above write.
    generated = [
        "training_curves.pdf/png",
        "performance_comparison.pdf/png",
        "ablation_study.pdf/png",
        "deployment_analysis.pdf/png",
        "subject_performance.pdf/png",
        "grade_performance.pdf/png",
        "resource_efficiency.pdf/png",
    ]
    for entry in generated:
        print(f"  ✓ {entry}")

    return figures_dir

if __name__ == "__main__":
    # Script entry point: build all figures, then report where they are.
    figures_directory = generate_all_visualizations()
    print(
        "\nPublication-quality figures ready for paper inclusion!",
        f"Directory: {figures_directory}",
        sep="\n",
    )