"""
Visualization and Analysis for Tracked Combinatorial Bandit Algorithm

This file provides comprehensive plotting and statistical analysis for:
1. Adaptivity metrics (entropy, exploration/exploitation, uncertainty)
2. Sequential learning metrics (per-subproblem improvements, transfer learning)
3. Coordination conflict metrics (dual variables, overlap consistency)

Usage:
    from visualization_and_analysis import analyze_and_plot
    tracking_data = optimizer.get_tracking_data()
    analyze_and_plot(tracking_data, save_dir='./plots')
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional
import os
from scipy import stats
from matplotlib.gridspec import GridSpec

# Module-wide plotting defaults: seaborn's "whitegrid" theme first, then the
# default figure size and base font size layered on top of it.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (15, 10),
    'font.size': 10,
})


def analyze_and_plot(tracking_data: Dict, 
                     save_dir: str = './plots',
                     show_plots: bool = True,
                     save_plots: bool = True):
    """
    Main function to create all visualizations and analyses.

    Prints the basic run statistics, renders the adaptivity, sequential
    learning, coordination conflict and overview figures, and finally runs
    the statistical analysis, which writes its text report into ``save_dir``.

    Args:
        tracking_data: Dictionary returned from optimizer.get_tracking_data().
            The keys 'adaptivity_metrics', 'sequential_learning_metrics',
            'conflict_metrics', 'basic_stats' and 'rewards' are read.
        save_dir: Directory for the saved plots and the statistical report.
            It is created if it does not exist.
        show_plots: Whether to display plots
        save_plots: Whether to save plots to files

    Raises:
        KeyError: If one of the required keys is missing from
            ``tracking_data`` or from its 'basic_stats' entry.
    """
    # The statistical report is written to save_dir even when save_plots is
    # False, so the directory has to exist in both cases.
    os.makedirs(save_dir, exist_ok=True)

    banner = "=" * 70

    print(banner)
    print("COMPREHENSIVE ANALYSIS OF TRACKED METRICS")
    print(banner)

    # Extract metrics
    adaptivity = tracking_data['adaptivity_metrics']
    sequential = tracking_data['sequential_learning_metrics']
    conflicts = tracking_data['conflict_metrics']
    basic_stats = tracking_data['basic_stats']
    rewards = tracking_data['rewards']

    print("\nBasic Statistics:")
    print(f"  Total iterations: {basic_stats['total_iterations']}")
    print(f"  Best reward: {basic_stats['best_reward']:.4f}")
    print(f"  Number of subproblems: {basic_stats['n_subproblems']}")
    print(f"  Convergence iteration: {basic_stats['convergence_iteration']}")

    # Create comprehensive plots
    print("\n" + banner)
    print("CREATING VISUALIZATIONS")
    print(banner)

    # 1. ADAPTIVITY ANALYSIS
    print("\n1. Plotting Adaptivity Metrics...")
    plot_adaptivity_analysis(adaptivity, rewards, save_dir, show_plots, save_plots)

    # 2. SEQUENTIAL LEARNING ANALYSIS
    print("2. Plotting Sequential Learning Metrics...")
    plot_sequential_learning_analysis(sequential, rewards, basic_stats, save_dir, show_plots, save_plots)

    # 3. COORDINATION CONFLICT ANALYSIS
    print("3. Plotting Coordination Conflict Metrics...")
    plot_conflict_analysis(conflicts, rewards, save_dir, show_plots, save_plots)

    # 4. COMPREHENSIVE OVERVIEW
    print("4. Creating Comprehensive Overview...")
    plot_comprehensive_overview(adaptivity, sequential, conflicts, rewards, save_dir, show_plots, save_plots)

    # 5. STATISTICAL ANALYSIS (always writes its report into save_dir)
    print("5. Performing Statistical Analysis...")
    perform_statistical_analysis(adaptivity, sequential, conflicts, rewards, save_dir)

    print("\n" + banner)
    print("ANALYSIS COMPLETE!")
    print(banner)
    if save_plots:
        print(f"\nPlots saved to: {save_dir}/")


def plot_adaptivity_analysis(adaptivity: Dict, rewards: List[float], 
                             save_dir: str, show: bool, save: bool):
    """Plot all adaptivity-related metrics on a single 3x3 figure.

    Panels: average entropy with a linear trend, mean of the per-position
    maximum probability, temperature (log scale), overlap size, uncertainty,
    an entropy heatmap of the first 50 iterations, and reward against average
    entropy on twin y-axes.

    Args:
        adaptivity: Metric dictionary. The keys 'avg_entropy',
            'max_prob_per_position', 'temperature_history',
            'overlap_history', 'uncertainty_by_iteration' and
            'entropy_per_position' are read.
        rewards: Reward per iteration. Its length defines the shared x-axis,
            so every per-iteration series must have the same length.
        save_dir: Directory receiving 'adaptivity_analysis.png'; created if
            missing when ``save`` is true.
        show: Whether to display the figure.
        save: Whether to write the figure to disk.
    """
    fig = plt.figure(figsize=(18, 12))
    # try/finally guarantees the figure is released from pyplot's registry
    # even when one of the plotting calls below raises.
    try:
        gs = GridSpec(3, 3, figure=fig, hspace=0.3, wspace=0.3)

        iterations = range(len(rewards))
        xs = list(iterations)

        # 1. Average Entropy Over Time
        ax1 = fig.add_subplot(gs[0, 0])
        avg_entropy = adaptivity['avg_entropy']
        ax1.plot(iterations, avg_entropy, 'b-', linewidth=2, label='Avg Entropy')
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Average Entropy')
        ax1.set_title('Entropy Evolution (Exploration → Exploitation)')
        ax1.grid(True, alpha=0.3)

        # Add a first-degree least-squares trend line (needs >= 2 points)
        if len(xs) > 1:
            z = np.polyfit(xs, avg_entropy, 1)
            p = np.poly1d(z)
            ax1.plot(iterations, p(xs), "r--", alpha=0.5, label=f'Trend (slope={z[0]:.4f})')
        # One legend call after all labelled lines exist.
        ax1.legend()

        # 2. Maximum Probability Evolution (Concentration)
        ax2 = fig.add_subplot(gs[0, 1])
        if adaptivity['max_prob_per_position']:
            max_probs = np.array(adaptivity['max_prob_per_position'])
            # Averages along axis 1; presumably rows are iterations and
            # columns are positions - confirm against the tracker.
            avg_max_prob = np.mean(max_probs, axis=1)
            ax2.plot(iterations, avg_max_prob, 'g-', linewidth=2, label='Avg Max Prob')
            ax2.set_xlabel('Iteration')
            ax2.set_ylabel('Maximum Probability')
            ax2.set_title('Action Concentration Over Time')
            ax2.grid(True, alpha=0.3)
            ax2.legend()

        # 3. Temperature Decay
        ax3 = fig.add_subplot(gs[0, 2])
        temp_history = adaptivity['temperature_history']
        ax3.plot(iterations, temp_history, 'r-', linewidth=2)
        ax3.set_xlabel('Iteration')
        ax3.set_ylabel('Temperature')
        ax3.set_title('Temperature Decay (Annealing)')
        ax3.set_yscale('log')
        ax3.grid(True, alpha=0.3)

        # 4. Overlap Evolution
        ax4 = fig.add_subplot(gs[1, 0])
        overlap_history = adaptivity['overlap_history']
        ax4.plot(iterations, overlap_history, 'm-', linewidth=2)
        ax4.set_xlabel('Iteration')
        ax4.set_ylabel('Overlap Size')
        ax4.set_title('Overlap Diminishing Over Time')
        ax4.grid(True, alpha=0.3)

        # 5. Uncertainty Evolution
        ax5 = fig.add_subplot(gs[1, 1])
        uncertainty = adaptivity['uncertainty_by_iteration']
        ax5.plot(iterations, uncertainty, 'c-', linewidth=2)
        ax5.set_xlabel('Iteration')
        ax5.set_ylabel('Uncertainty')
        ax5.set_title('Uncertainty Reduction')
        ax5.grid(True, alpha=0.3)

        # 6. Entropy Heatmap (per position over time)
        ax6 = fig.add_subplot(gs[1, 2])
        if adaptivity['entropy_per_position']:
            # Only the first 50 iterations are shown to keep the heatmap readable
            entropy_matrix = np.array(adaptivity['entropy_per_position'][:50])
            if len(entropy_matrix) > 0:
                # Transposed so iterations run along x and positions along y
                im = ax6.imshow(entropy_matrix.T, aspect='auto', cmap='viridis', interpolation='nearest')
                ax6.set_xlabel('Iteration')
                ax6.set_ylabel('Position')
                ax6.set_title('Entropy Heatmap (First 50 Iterations)')
                plt.colorbar(im, ax=ax6, label='Entropy')

        # 7. Reward vs Entropy (Dual Axis)
        ax7 = fig.add_subplot(gs[2, :])
        ax7_twin = ax7.twinx()

        ax7.plot(iterations, rewards, 'b-', linewidth=2, label='Reward', alpha=0.7)
        ax7_twin.plot(iterations, avg_entropy, 'r-', linewidth=2, label='Avg Entropy', alpha=0.7)

        ax7.set_xlabel('Iteration')
        ax7.set_ylabel('Reward', color='b')
        ax7_twin.set_ylabel('Average Entropy', color='r')
        ax7.set_title('Reward Improvement vs Entropy Reduction')
        ax7.tick_params(axis='y', labelcolor='b')
        ax7_twin.tick_params(axis='y', labelcolor='r')
        ax7.grid(True, alpha=0.3)

        # Combine the legends of both y-axes into a single box
        lines1, labels1 = ax7.get_legend_handles_labels()
        lines2, labels2 = ax7_twin.get_legend_handles_labels()
        ax7.legend(lines1 + lines2, labels1 + labels2, loc='best')

        fig.suptitle('ADAPTIVITY ANALYSIS: Exploration to Exploitation Transition', 
                     fontsize=16, fontweight='bold')

        if save:
            # Allow direct calls that bypass analyze_and_plot's directory setup
            os.makedirs(save_dir, exist_ok=True)
            fig.savefig(f"{save_dir}/adaptivity_analysis.png", dpi=300, bbox_inches='tight')
        if show:
            plt.show()
    finally:
        plt.close(fig)


def plot_sequential_learning_analysis(sequential: Dict, rewards: List[float],
                                      basic_stats: Dict, save_dir: str, show: bool, save: bool):
    """Plot all sequential learning metrics on a single 3x3 figure.

    Panels: reward per cumulative subproblem with iteration boundaries,
    success rate with a linear trend, improvement heatmap (first 50
    iterations), mean improvement per subproblem index, early-vs-late
    improvement histograms, moving average of improvements, and the
    improvement of the first subproblem of every iteration with a trend.

    Args:
        sequential: Metric dictionary. The keys 'cumulative_rewards',
            'subproblem_success_rate', 'subproblem_improvements' and
            'improvement_per_subproblem' are read. Each entry of
            'subproblem_improvements' is iterated as a sequence of dicts
            with the keys 'improvement' and 'sp_idx'.
        rewards: Reward per iteration; its length defines the iteration axis.
        basic_stats: Reads 'n_subproblems', used as the spacing of the
            iteration boundary markers.
        save_dir: Directory receiving 'sequential_learning_analysis.png';
            created if missing when ``save`` is true.
        show: Whether to display the figure.
        save: Whether to write the figure to disk.
    """
    fig = plt.figure(figsize=(18, 12))
    # try/finally guarantees the figure is released even if plotting raises.
    try:
        gs = GridSpec(3, 3, figure=fig, hspace=0.3, wspace=0.3)

        iterations = range(len(rewards))

        # 1. Cumulative Rewards from Subproblem Solutions
        ax1 = fig.add_subplot(gs[0, :])
        cumulative_rewards = sequential['cumulative_rewards']
        subproblem_indices = range(len(cumulative_rewards))
        ax1.plot(subproblem_indices, cumulative_rewards, 'b-', linewidth=1.5, alpha=0.7)
        ax1.set_xlabel('Cumulative Subproblems Solved')
        ax1.set_ylabel('Reward')
        ax1.set_title('Learning Curve: Reward vs Cumulative Subproblems Solved')
        ax1.grid(True, alpha=0.3)

        # Add iteration markers every `marker_spacing` subproblems. A spacing
        # of zero would make range() raise, so markers are skipped then.
        marker_spacing = basic_stats['n_subproblems']
        if marker_spacing > 0:
            for boundary in range(0, len(cumulative_rewards), marker_spacing):
                ax1.axvline(x=boundary, color='r', linestyle='--', alpha=0.2)

        # 2. Success Rate Over Iterations
        ax2 = fig.add_subplot(gs[1, 0])
        success_rate = sequential['subproblem_success_rate']
        ax2.plot(iterations, success_rate, 'g-', linewidth=2)
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('Success Rate')
        ax2.set_title('Subproblem Success Rate per Iteration')
        ax2.set_ylim([0, 1.05])
        ax2.grid(True, alpha=0.3)

        # Add trend line (needs >= 2 points)
        if len(iterations) > 1:
            xs = list(iterations)
            z = np.polyfit(xs, success_rate, 1)
            p = np.poly1d(z)
            ax2.plot(iterations, p(xs), "r--", alpha=0.5, 
                     label=f'Trend (slope={z[0]:.4f})')
            ax2.legend()

        # 3. Improvement per Subproblem (Heatmap)
        ax3 = fig.add_subplot(gs[1, 1])
        subproblem_improvements = sequential['subproblem_improvements']

        if subproblem_improvements:
            # Matrix of iterations x subproblems, limited to the first 50 iterations
            improvement_matrix = np.array([
                [sp['improvement'] for sp in iter_data]
                for iter_data in subproblem_improvements[:50]
            ])
            if len(improvement_matrix) > 0:
                im = ax3.imshow(improvement_matrix, aspect='auto', cmap='RdYlGn', 
                                interpolation='nearest', vmin=-0.1, vmax=0.5)
                ax3.set_xlabel('Subproblem Index')
                ax3.set_ylabel('Iteration')
                ax3.set_title('Improvement Heatmap (First 50 Iterations)')
                plt.colorbar(im, ax=ax3, label='Improvement')

        # 4. Average Improvement per Subproblem Position
        ax4 = fig.add_subplot(gs[1, 2])
        if subproblem_improvements:
            # Number of bars comes from the first iteration's entry count
            n_positions = len(subproblem_improvements[0])
            totals = [0.0] * n_positions

            for iter_data in subproblem_improvements:
                for sp_data in iter_data:
                    totals[sp_data['sp_idx']] += sp_data['improvement']

            n_iterations = len(subproblem_improvements)
            avg_improvements = [total / n_iterations for total in totals]

            ax4.bar(range(n_positions), avg_improvements, color='skyblue', edgecolor='navy')
            ax4.set_xlabel('Subproblem Index')
            ax4.set_ylabel('Average Improvement')
            ax4.set_title('Average Improvement by Subproblem Position')
            ax4.grid(True, alpha=0.3, axis='y')

        # 5. Early vs Late Iteration Comparison (needs two disjoint groups of 10)
        ax5 = fig.add_subplot(gs[2, 0])
        if len(subproblem_improvements) >= 20:
            early_iters = subproblem_improvements[:10]
            late_iters = subproblem_improvements[-10:]

            early_improvements = [sp['improvement'] for iter_data in early_iters for sp in iter_data]
            late_improvements = [sp['improvement'] for iter_data in late_iters for sp in iter_data]

            ax5.hist(early_improvements, bins=20, alpha=0.5, label='Early (1-10)', color='red')
            ax5.hist(late_improvements, bins=20, alpha=0.5, label='Late (last 10)', color='green')
            ax5.set_xlabel('Improvement')
            ax5.set_ylabel('Frequency')
            ax5.set_title('Distribution: Early vs Late Iterations')
            ax5.legend()
            ax5.grid(True, alpha=0.3)

        # 6. Convergence Speed: Improvements over Time
        ax6 = fig.add_subplot(gs[2, 1])
        improvements_per_subproblem = sequential['improvement_per_subproblem']

        # Moving average over a tenth of the series, capped at 50 samples
        window_size = min(50, len(improvements_per_subproblem) // 10)
        if window_size > 1:
            moving_avg = np.convolve(improvements_per_subproblem, 
                                     np.ones(window_size) / window_size, mode='valid')
            ax6.plot(range(len(moving_avg)), moving_avg, 'b-', linewidth=2)
            ax6.set_xlabel('Subproblem Number')
            ax6.set_ylabel('Improvement (Moving Avg)')
            ax6.set_title(f'Improvement Trend (Window={window_size})')
            ax6.grid(True, alpha=0.3)

        # 7. Transfer Learning Evidence
        ax7 = fig.add_subplot(gs[2, 2])
        if subproblem_improvements:
            # Compare the first subproblem of each iteration
            first_sp_improvements = [iter_data[0]['improvement'] 
                                     for iter_data in subproblem_improvements]
            first_sp_x = range(len(first_sp_improvements))

            ax7.plot(first_sp_x, first_sp_improvements, 
                     'o-', markersize=4, linewidth=1.5, color='purple')
            ax7.set_xlabel('Iteration')
            ax7.set_ylabel('Improvement (First Subproblem)')
            ax7.set_title('Transfer Learning: First Subproblem Performance')
            ax7.grid(True, alpha=0.3)

            # Add trend (needs >= 2 points)
            if len(first_sp_improvements) > 1:
                z = np.polyfit(first_sp_x, first_sp_improvements, 1)
                p = np.poly1d(z)
                ax7.plot(first_sp_x, p(first_sp_x), 
                         "r--", alpha=0.5, label=f'Trend (slope={z[0]:.4f})')
                ax7.legend()

        fig.suptitle('SEQUENTIAL LEARNING ANALYSIS: Improvement Over Subproblems', 
                     fontsize=16, fontweight='bold')

        if save:
            # Allow direct calls that bypass analyze_and_plot's directory setup
            os.makedirs(save_dir, exist_ok=True)
            fig.savefig(f"{save_dir}/sequential_learning_analysis.png", dpi=300, bbox_inches='tight')
        if show:
            plt.show()
    finally:
        plt.close(fig)


def plot_conflict_analysis(conflicts: Dict, rewards: List[float],
                           save_dir: str, show: bool, save: bool):
    """Plot all coordination conflict metrics on a single 3x3 figure.

    Panels: dual variable norm, overlap conflict score, action consistency,
    constraint violations, dual variable heatmap (first 50 iterations),
    reward against dual norm on twin y-axes, and all conflict metrics
    min-max normalized on one axis.

    Args:
        conflicts: Metric dictionary. The keys 'dual_variable_norm',
            'overlap_conflict_scores', 'action_consistency',
            'constraint_violations' and 'dual_variables_history' are read.
        rewards: Reward per iteration. Its length defines the shared x-axis,
            so every per-iteration series must have the same length.
        save_dir: Directory receiving 'conflict_analysis.png'; created if
            missing when ``save`` is true.
        show: Whether to display the figure.
        save: Whether to write the figure to disk.
    """

    def normalize(data):
        """Min-max scale ``data`` to [0, 1].

        A (near-)constant series maps to all ones; an empty series is
        returned unchanged because it has no minimum or maximum.
        """
        data = np.asarray(data)
        if data.size == 0:
            return data
        min_val, max_val = np.min(data), np.max(data)
        if max_val - min_val < 1e-10:
            return np.ones_like(data)
        return (data - min_val) / (max_val - min_val)

    fig = plt.figure(figsize=(18, 12))
    # try/finally guarantees the figure is released even if plotting raises.
    try:
        gs = GridSpec(3, 3, figure=fig, hspace=0.3, wspace=0.3)

        iterations = range(len(rewards))

        # 1. Dual Variable Norm Convergence
        ax1 = fig.add_subplot(gs[0, 0])
        dual_norm = conflicts['dual_variable_norm']
        ax1.plot(iterations, dual_norm, 'r-', linewidth=2)
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('||λ|| (Dual Variable Norm)')
        ax1.set_title('Dual Variable Convergence')
        ax1.grid(True, alpha=0.3)
        # A log axis cannot display a series without any positive value
        # (e.g. duals that never left zero); keep the linear scale then.
        if np.any(np.asarray(dual_norm) > 0):
            ax1.set_yscale('log')

        # 2. Overlap Conflict Scores
        ax2 = fig.add_subplot(gs[0, 1])
        conflict_scores = conflicts['overlap_conflict_scores']
        ax2.plot(iterations, conflict_scores, 'orange', linewidth=2)
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('Conflict Score')
        ax2.set_title('Overlap Conflict Reduction')
        ax2.grid(True, alpha=0.3)

        # 3. Action Consistency
        ax3 = fig.add_subplot(gs[0, 2])
        consistency = conflicts['action_consistency']
        ax3.plot(iterations, consistency, 'g-', linewidth=2)
        ax3.set_xlabel('Iteration')
        ax3.set_ylabel('Consistency Ratio')
        ax3.set_title('Action Consistency in Overlaps')
        ax3.set_ylim([0, 1.05])
        ax3.grid(True, alpha=0.3)

        # 4. Constraint Violations
        ax4 = fig.add_subplot(gs[1, 0])
        violations = conflicts['constraint_violations']
        ax4.plot(iterations, violations, 'purple', linewidth=2)
        ax4.set_xlabel('Iteration')
        ax4.set_ylabel('Avg Violation')
        ax4.set_title('Constraint Violation Reduction')
        ax4.grid(True, alpha=0.3)

        # 5. Dual Variables Heatmap (first 50 iterations only)
        ax5 = fig.add_subplot(gs[1, 1])
        dual_history = np.array(conflicts['dual_variables_history'][:50])
        if len(dual_history) > 0:
            # Transposed so iterations run along x and positions along y
            im = ax5.imshow(dual_history.T, aspect='auto', cmap='coolwarm', interpolation='nearest')
            ax5.set_xlabel('Iteration')
            ax5.set_ylabel('Position')
            ax5.set_title('Dual Variables Evolution (First 50 Iterations)')
            plt.colorbar(im, ax=ax5, label='λ value')

        # 6. Reward vs Dual Norm (Dual Axis)
        ax6 = fig.add_subplot(gs[1, 2])
        ax6_twin = ax6.twinx()

        ax6.plot(iterations, rewards, 'b-', linewidth=2, label='Reward', alpha=0.7)
        ax6_twin.plot(iterations, dual_norm, 'r-', linewidth=2, label='||λ||', alpha=0.7)

        ax6.set_xlabel('Iteration')
        ax6.set_ylabel('Reward', color='b')
        ax6_twin.set_ylabel('||λ||', color='r')
        ax6.set_title('Reward vs Dual Variable Norm')
        ax6.tick_params(axis='y', labelcolor='b')
        ax6_twin.tick_params(axis='y', labelcolor='r')
        ax6.grid(True, alpha=0.3)

        # Combine the legends of both y-axes into a single box
        lines1, labels1 = ax6.get_legend_handles_labels()
        lines2, labels2 = ax6_twin.get_legend_handles_labels()
        ax6.legend(lines1 + lines2, labels1 + labels2, loc='best')

        # 7. All Conflict Metrics Together (normalized to [0, 1] for comparison)
        ax7 = fig.add_subplot(gs[2, :])

        ax7.plot(iterations, normalize(dual_norm), 'r-', linewidth=2, label='Dual Norm (norm)', alpha=0.7)
        ax7.plot(iterations, normalize(conflict_scores), 'orange', linewidth=2, 
                 label='Conflict Score (norm)', alpha=0.7)
        # Consistency is inverted so that, like the other curves, lower is better
        ax7.plot(iterations, 1 - normalize(consistency), 'g--', linewidth=2, 
                 label='Inconsistency (norm)', alpha=0.7)
        ax7.plot(iterations, normalize(violations), 'purple', linewidth=2, 
                 label='Violations (norm)', alpha=0.7)

        ax7.set_xlabel('Iteration')
        ax7.set_ylabel('Normalized Value')
        ax7.set_title('All Conflict Metrics (Normalized) - Should Decrease')
        ax7.legend(loc='best')
        ax7.grid(True, alpha=0.3)

        fig.suptitle('COORDINATION CONFLICT ANALYSIS: Conflict Resolution', 
                     fontsize=16, fontweight='bold')

        if save:
            # Allow direct calls that bypass analyze_and_plot's directory setup
            os.makedirs(save_dir, exist_ok=True)
            fig.savefig(f"{save_dir}/conflict_analysis.png", dpi=300, bbox_inches='tight')
        if show:
            plt.show()
    finally:
        plt.close(fig)


def plot_comprehensive_overview(adaptivity: Dict, sequential: Dict, 
                                conflicts: Dict, rewards: List[float],
                                save_dir: str, show: bool, save: bool):
    """Create a single comprehensive overview figure (3x4 grid).

    Panels: reward, average entropy, success rate, cumulative subproblem
    rewards, dual norm, overlap conflict score, and a bottom row combining
    normalized entropy, success rate and normalized dual norm on three
    y-axes.

    Args:
        adaptivity: Reads the key 'avg_entropy'.
        sequential: Reads the keys 'subproblem_success_rate' and
            'cumulative_rewards'.
        conflicts: Reads the keys 'dual_variable_norm' and
            'overlap_conflict_scores'.
        rewards: Reward per iteration. Its length defines the shared x-axis,
            so every per-iteration series must have the same length.
        save_dir: Directory receiving 'comprehensive_overview.png'; created
            if missing when ``save`` is true.
        show: Whether to display the figure.
        save: Whether to write the figure to disk.
    """

    def normalize(data):
        """Min-max scale ``data`` to [0, 1].

        A (near-)constant series maps to all ones; an empty series is
        returned unchanged because it has no minimum or maximum.
        """
        data = np.asarray(data)
        if data.size == 0:
            return data
        min_val, max_val = np.min(data), np.max(data)
        if max_val - min_val < 1e-10:
            return np.ones_like(data)
        return (data - min_val) / (max_val - min_val)

    fig = plt.figure(figsize=(20, 12))
    # try/finally guarantees the figure is released even if plotting raises.
    try:
        gs = GridSpec(3, 4, figure=fig, hspace=0.3, wspace=0.3)

        iterations = range(len(rewards))

        # 1. Reward Evolution
        ax1 = fig.add_subplot(gs[0, :2])
        ax1.plot(iterations, rewards, 'b-', linewidth=2)
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Reward')
        ax1.set_title('Reward Evolution', fontsize=14, fontweight='bold')
        ax1.grid(True, alpha=0.3)

        # 2. Entropy
        ax2 = fig.add_subplot(gs[0, 2])
        avg_entropy = adaptivity['avg_entropy']
        ax2.plot(iterations, avg_entropy, 'orange', linewidth=2)
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('Avg Entropy')
        ax2.set_title('Adaptivity: Entropy')
        ax2.grid(True, alpha=0.3)

        # 3. Success Rate
        ax3 = fig.add_subplot(gs[0, 3])
        success_rate = sequential['subproblem_success_rate']
        ax3.plot(iterations, success_rate, 'g-', linewidth=2)
        ax3.set_xlabel('Iteration')
        ax3.set_ylabel('Success Rate')
        ax3.set_title('Learning: Success Rate')
        ax3.grid(True, alpha=0.3)

        # 4. Cumulative Learning
        ax4 = fig.add_subplot(gs[1, :2])
        cumulative_rewards = sequential['cumulative_rewards']
        ax4.plot(range(len(cumulative_rewards)), cumulative_rewards, 'b-', linewidth=1.5, alpha=0.7)
        ax4.set_xlabel('Cumulative Subproblems')
        ax4.set_ylabel('Reward')
        ax4.set_title('Sequential Learning: Cumulative Progress', fontsize=14, fontweight='bold')
        ax4.grid(True, alpha=0.3)

        # 5. Dual Norm
        ax5 = fig.add_subplot(gs[1, 2])
        dual_norm = conflicts['dual_variable_norm']
        ax5.plot(iterations, dual_norm, 'r-', linewidth=2)
        ax5.set_xlabel('Iteration')
        ax5.set_ylabel('||λ||')
        ax5.set_title('Conflict: Dual Norm')
        # A log axis cannot display a series without any positive value
        # (e.g. duals that never left zero); keep the linear scale then.
        if np.any(np.asarray(dual_norm) > 0):
            ax5.set_yscale('log')
        ax5.grid(True, alpha=0.3)

        # 6. Conflict Score
        ax6 = fig.add_subplot(gs[1, 3])
        conflict_scores = conflicts['overlap_conflict_scores']
        ax6.plot(iterations, conflict_scores, 'purple', linewidth=2)
        ax6.set_xlabel('Iteration')
        ax6.set_ylabel('Conflict Score')
        ax6.set_title('Conflict: Overlap Score')
        ax6.grid(True, alpha=0.3)

        # 7. Three Key Metrics Together, each on its own y-axis
        ax7 = fig.add_subplot(gs[2, :])

        ax7_1 = ax7
        ax7_2 = ax7.twinx()
        ax7_3 = ax7.twinx()
        # Push the third axis' spine outward so it does not overlap the second
        ax7_3.spines['right'].set_position(('outward', 60))

        p1 = ax7_1.plot(iterations, normalize(avg_entropy), 'orange', linewidth=2, 
                        label='Entropy (Adaptivity)', alpha=0.8)
        p2 = ax7_2.plot(iterations, success_rate, 'g-', linewidth=2, 
                        label='Success Rate (Learning)', alpha=0.8)
        p3 = ax7_3.plot(iterations, normalize(dual_norm), 'r-', linewidth=2, 
                        label='Dual Norm (Conflict)', alpha=0.8)

        ax7_1.set_xlabel('Iteration', fontsize=12)
        ax7_1.set_ylabel('Entropy (norm)', color='orange', fontsize=10)
        ax7_2.set_ylabel('Success Rate', color='g', fontsize=10)
        ax7_3.set_ylabel('Dual Norm (norm)', color='r', fontsize=10)

        ax7_1.tick_params(axis='y', labelcolor='orange')
        ax7_2.tick_params(axis='y', labelcolor='g')
        ax7_3.tick_params(axis='y', labelcolor='r')

        ax7_1.set_title('Key Metrics: Adaptivity, Learning, Conflict Resolution', 
                        fontsize=14, fontweight='bold')
        ax7_1.grid(True, alpha=0.3)

        # Combine the lines of all three axes into a single legend
        lns = p1 + p2 + p3
        labs = [line.get_label() for line in lns]
        ax7_1.legend(lns, labs, loc='best', fontsize=10)

        fig.suptitle('COMPREHENSIVE OVERVIEW: Algorithm Behavior', 
                     fontsize=18, fontweight='bold')

        if save:
            # Allow direct calls that bypass analyze_and_plot's directory setup
            os.makedirs(save_dir, exist_ok=True)
            fig.savefig(f"{save_dir}/comprehensive_overview.png", dpi=300, bbox_inches='tight')
        if show:
            plt.show()
    finally:
        plt.close(fig)


def _append_trend_section(report: List[str], title: str, x, y,
                          expect_decrease: bool, pass_label: str, fail_label: str):
    """Append a linear-regression trend summary of ``y`` over ``x`` to ``report``.

    The pass label is used only when the slope has the expected sign AND the
    regression p-value is below 0.05; otherwise the fail label is used. With
    fewer than two points no regression is possible, so the fail label is
    reported together with an explanatory note.
    """
    report.append(title)
    if len(y) < 2:
        report.append(f"  Result: {fail_label} (insufficient data)")
        return

    slope, _intercept, r_value, p_value, _std_err = stats.linregress(x, y)
    right_direction = slope < 0 if expect_decrease else slope > 0
    verdict = pass_label if right_direction and p_value < 0.05 else fail_label

    report.append(f"  Slope: {slope:.6f}")
    report.append(f"  R²: {r_value**2:.4f}")
    report.append(f"  P-value: {p_value:.4e}")
    report.append(f"  Result: {verdict}")


def perform_statistical_analysis(adaptivity: Dict, sequential: Dict,
                                 conflicts: Dict, rewards: List[float],
                                 save_dir: str):
    """Perform statistical tests, print the report and save it to disk.

    Runs linear regressions against the iteration index for average entropy,
    subproblem success rate, dual variable norm, overlap conflict score and
    action consistency, reports the temperature change, and (when at least
    20 iterations of subproblem data exist) compares the improvements of the
    first 10 and last 10 iterations with an independent two-sample t-test.

    Args:
        adaptivity: Reads the keys 'avg_entropy' and 'temperature_history'.
        sequential: Reads the keys 'subproblem_success_rate' and
            'subproblem_improvements'.
        conflicts: Reads the keys 'dual_variable_norm',
            'overlap_conflict_scores' and 'action_consistency'.
        rewards: Reward per iteration; only its length is used, to build the
            regression x-axis.
        save_dir: Directory receiving 'statistical_analysis.txt'; created if
            it does not exist.

    Raises:
        ValueError: From ``scipy.stats.linregress`` if a series does not
            have the same length as ``rewards``.
    """
    report = []
    report.append("="*70)
    report.append("STATISTICAL ANALYSIS REPORT")
    report.append("="*70)

    iterations = np.arange(len(rewards))

    # 1. ADAPTIVITY: Entropy decreases?
    report.append("\n1. ADAPTIVITY ANALYSIS")
    report.append("-" * 70)

    avg_entropy = np.array(adaptivity['avg_entropy'])
    _append_trend_section(report, "Entropy Trend:", iterations, avg_entropy,
                          True, '✓ DECREASING', '✗ NOT SIGNIFICANT')

    # Temperature
    temp = np.array(adaptivity['temperature_history'])
    if len(temp) > 0:
        # N samples span N-1 iteration steps; a single sample has no step,
        # so clamp the divisor to 1 (the difference is zero in that case).
        n_steps = max(len(temp) - 1, 1)
        slope_temp = (temp[-1] - temp[0]) / n_steps
        report.append("\nTemperature Decay:")
        report.append(f"  Initial: {temp[0]:.4f}, Final: {temp[-1]:.4f}")
        report.append(f"  Avg Decay Rate: {slope_temp:.6f} per iteration")

    # 2. SEQUENTIAL LEARNING: Success rate improves?
    report.append("\n2. SEQUENTIAL LEARNING ANALYSIS")
    report.append("-" * 70)

    success_rate = np.array(sequential['subproblem_success_rate'])
    # Requires p < 0.05 like every other verdict, matching the report legend.
    _append_trend_section(report, "Success Rate Trend:", iterations, success_rate,
                          False, '✓ IMPROVING', '✗ NOT IMPROVING')

    # Compare early vs late (needs two disjoint groups of 10 iterations)
    subproblem_improvements = sequential['subproblem_improvements']
    if len(subproblem_improvements) >= 20:
        early = subproblem_improvements[:10]
        late = subproblem_improvements[-10:]

        early_improvements = [sp['improvement'] for iter_data in early for sp in iter_data]
        late_improvements = [sp['improvement'] for iter_data in late for sp in iter_data]

        # Late first, so a positive t-statistic means late > early
        t_stat, p_value_ttest = stats.ttest_ind(late_improvements, early_improvements)

        report.append("\nEarly vs Late Iterations (t-test):")
        report.append(f"  Early mean improvement: {np.mean(early_improvements):.6f}")
        report.append(f"  Late mean improvement: {np.mean(late_improvements):.6f}")
        report.append(f"  T-statistic: {t_stat:.4f}")
        report.append(f"  P-value: {p_value_ttest:.4e}")
        report.append(f"  Result: {'✓ SIGNIFICANT IMPROVEMENT' if p_value_ttest < 0.05 and t_stat > 0 else '✗ NOT SIGNIFICANT'}")

    # 3. CONFLICT RESOLUTION: Conflicts decrease?
    report.append("\n3. COORDINATION CONFLICT ANALYSIS")
    report.append("-" * 70)

    dual_norm = np.array(conflicts['dual_variable_norm'])
    # Skipped entirely when the dual norm never becomes positive
    if len(dual_norm) > 0 and np.max(dual_norm) > 0:
        _append_trend_section(report, "Dual Variable Norm Trend:", iterations, dual_norm,
                              True, '✓ CONVERGING', '✗ NOT CONVERGING')

    conflict_scores = np.array(conflicts['overlap_conflict_scores'])
    _append_trend_section(report, "\nOverlap Conflict Score Trend:", iterations, conflict_scores,
                          True, '✓ RESOLVING', '✗ NOT RESOLVING')

    consistency = np.array(conflicts['action_consistency'])
    _append_trend_section(report, "\nAction Consistency Trend:", iterations, consistency,
                          False, '✓ IMPROVING', '✗ NOT IMPROVING')

    # 4. OVERALL SUMMARY
    report.append("\n" + "="*70)
    report.append("SUMMARY OF FINDINGS")
    report.append("="*70)

    report.append("\n✓ = Statistically significant evidence (p < 0.05)")
    report.append("✗ = Not statistically significant")

    # Save report
    report_text = "\n".join(report)
    print("\n" + report_text)

    # The caller may have skipped directory creation (e.g. save_plots=False)
    os.makedirs(save_dir, exist_ok=True)
    report_path = f"{save_dir}/statistical_analysis.txt"
    # The report contains non-ASCII symbols (✓, ✗, ²); an explicit encoding
    # avoids UnicodeEncodeError on platforms whose default is not UTF-8.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report_text)

    print(f"\nStatistical analysis saved to: {report_path}")


if __name__ == "__main__":
    # Running the module directly only prints a short usage reminder;
    # the actual entry point is analyze_and_plot().
    _usage_lines = (
        "This module provides visualization functions for the tracked algorithm.",
        "\nUsage example:",
        "  from visualization_and_analysis import analyze_and_plot",
        "  tracking_data = optimizer.get_tracking_data()",
        "  analyze_and_plot(tracking_data, save_dir='./plots')",
    )
    for _usage_line in _usage_lines:
        print(_usage_line)