

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats
from pathlib import Path


def setup_plot_style():
    """Configure matplotlib for clean, ICML-style figures.

    Starts from the ``seaborn-v0_8-paper`` style sheet and then overrides
    fonts, text sizes, output resolution, grid and spine settings through
    ``plt.rcParams``.
    """
    fonts = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica'],
    }
    text_sizes = {
        'font.size': 10,
        'axes.labelsize': 11,
        'axes.titlesize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 9,
        'figure.titlesize': 13,
    }
    resolution = {
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
    }
    axes_look = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
    }

    # The style sheet goes first so the explicit overrides below win.
    plt.style.use('seaborn-v0_8-paper')
    plt.rcParams.update({**fonts, **text_sizes, **resolution, **axes_look})

# Apply the plotting style once, at import time, so it is in effect before
# any figure in this module is created.
setup_plot_style()


class MetricsVisualizer:
    """Plot, export and summarise the metrics held by a tracker object.

    The tracker must expose three mappings keyed by iteration:
    ``adaptivity_metrics``, ``sequential_learning_metrics`` and
    ``conflict_metrics``. Each maps an iteration to a dict of metric values.
    Every mapping is read with its *own* sorted iteration keys, so the three
    do not need to cover the same iterations.
    """

    def __init__(self, tracker, save_dir: str = "plots"):
        """Store the tracker and make sure the output directory exists.

        Args:
            tracker: Object exposing the three metric mappings described on
                the class.
            save_dir: Directory that receives PNG and CSV output. Missing
                parent directories are created as well.
        """
        self.tracker = tracker
        self.save_dir = Path(save_dir)
        # parents=True so a nested path such as "runs/exp1/plots" works.
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Color palette
        self.colors = {
            'blue': '#4C72B0',
            'green': '#55A868',
            'orange': '#DD8452',
            'purple': '#8172B3',
            'red': '#C44E52',
            'cyan': '#64B5CD',
        }

    @staticmethod
    def _required(metrics, iterations, key):
        """Return ``metrics[i][key]`` per iteration; a missing key raises KeyError."""
        return [metrics[i][key] for i in iterations]

    @staticmethod
    def _optional(metrics, iterations, key, default=0):
        """Return ``metrics[i].get(key, default)`` per iteration."""
        return [metrics[i].get(key, default) for i in iterations]

    def _save_figure(self, fig, filename):
        """Write *fig* into ``save_dir`` under *filename* at 300 dpi."""
        fig.savefig(self.save_dir / filename, dpi=300, bbox_inches='tight')

    def plot_all_metrics(self, save: bool = True):
        """Generate all four figures.

        Args:
            save: When true, each figure is also written to ``save_dir``.
        """
        print("Generating visualizations...")
        self.plot_adaptivity_metrics(save=save)
        self.plot_sequential_learning(save=save)
        self.plot_coordination_conflicts(save=save)
        self.plot_overview_dashboard(save=save)
        print(f"Plots saved to {self.save_dir}/" if save else "Plots displayed")

    def plot_adaptivity_metrics(self, save: bool = True):
        """Plot entropy, max probability, uncertainty and confidence (2x2 grid).

        Prints a message and returns without creating a figure when the
        tracker has no adaptivity data.

        Args:
            save: When true, write ``adaptivity_metrics.png`` to ``save_dir``.
        """
        metrics = self.tracker.adaptivity_metrics
        iterations = sorted(metrics)
        # Check for data *before* creating the figure so an empty tracker
        # does not leave a blank figure open.
        if not iterations:
            print("No adaptivity data to plot")
            return

        entropy = self._required(metrics, iterations, 'avg_entropy')
        max_prob = self._required(metrics, iterations, 'max_probability')
        uncertainty = self._required(metrics, iterations, 'uncertainty_level')
        confidence = self._required(metrics, iterations, 'decision_confidence')

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle('Adaptivity Metrics', fontsize=14, fontweight='bold')

        # 1. Entropy
        ax = axes[0, 0]
        ax.plot(iterations, entropy, color=self.colors['blue'], linewidth=2)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Entropy')
        ax.set_title('Action Distribution Entropy')

        # Linear trend line; only drawn with more than two points.
        if len(iterations) > 2:
            z = np.polyfit(iterations, entropy, 1)
            ax.plot(iterations, np.poly1d(z)(iterations), '--', color=self.colors['red'],
                    alpha=0.7, label=f'Trend: {z[0]:.4f}/iter')
            ax.legend()

        # 2. Max Probability
        ax = axes[0, 1]
        ax.plot(iterations, max_prob, color=self.colors['green'], linewidth=2)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Max Probability')
        ax.set_title('Action Probability Concentration')
        ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Random (0.5)')
        ax.legend()

        # 3. Uncertainty Level
        ax = axes[1, 0]
        ax.plot(iterations, uncertainty, color=self.colors['orange'], linewidth=2)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Uncertainty Level')
        ax.set_title('Normalized Uncertainty (0=certain, 1=uncertain)')
        ax.set_ylim(-0.05, 1.05)
        ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)

        # 4. Decision Confidence
        ax = axes[1, 1]
        ax.plot(iterations, confidence, color=self.colors['purple'], linewidth=2)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Confidence')
        ax.set_title('Decision Confidence (1 - Uncertainty)')
        ax.set_ylim(-0.05, 1.05)

        # Annotate the change between the first and last iteration.
        if len(iterations) > 1:
            gain = (confidence[-1] - confidence[0]) * 100
            ax.annotate(f'Gain: {gain:+.1f}%', xy=(0.05, 0.95), xycoords='axes fraction',
                        fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        fig.tight_layout()
        if save:
            self._save_figure(fig, 'adaptivity_metrics.png')

    def plot_sequential_learning(self, save: bool = True):
        """Plot improvement, best reward and the improvement heatmap (2x2 grid).

        Prints a message and returns without creating a figure when the
        tracker has no sequential-learning data.

        Args:
            save: When true, write ``sequential_learning.png`` to ``save_dir``.
        """
        metrics = self.tracker.sequential_learning_metrics
        iterations = sorted(metrics)
        if not iterations:
            print("No learning data to plot")
            return

        iter_imp = self._optional(metrics, iterations, 'total_iteration_improvement')
        avg_imp = self._optional(metrics, iterations, 'avg_improvement_per_subproblem')
        best_reward = self._required(metrics, iterations, 'cumulative_best_reward')
        distributions = [metrics[i].get('improvement_distribution', []) for i in iterations]

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle('Sequential Learning Progress', fontsize=14, fontweight='bold')

        # 1. Total Iteration Improvement
        ax = axes[0, 0]
        ax.plot(iterations, iter_imp, color=self.colors['blue'], linewidth=2)
        ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Improvement')
        ax.set_title('Improvement per Iteration')

        # Share of iterations with a strictly positive improvement
        # (iter_imp is non-empty here because iterations is non-empty).
        positive_rate = sum(1 for x in iter_imp if x > 0) / len(iter_imp) * 100
        ax.annotate(f'Positive: {positive_rate:.0f}%', xy=(0.05, 0.95), xycoords='axes fraction',
                    fontsize=10, bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))

        # 2. Average Improvement per Subproblem
        ax = axes[0, 1]
        ax.plot(iterations, avg_imp, color=self.colors['green'], linewidth=2)
        ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Avg Improvement')
        ax.set_title('Average Improvement per Subproblem')

        # 3. Cumulative Best Reward
        ax = axes[1, 0]
        ax.plot(iterations, best_reward, color=self.colors['purple'], linewidth=2)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Best Reward')
        ax.set_title('Best Reward Over Time')

        if len(iterations) > 1:
            total_gain = best_reward[-1] - best_reward[0]
            ax.annotate(f'Total: {total_gain:+.4f}', xy=(0.05, 0.95), xycoords='axes fraction',
                        fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        # 4. Improvement Heatmap
        ax = axes[1, 1]
        # Only draw the heatmap if at least one distribution has a non-zero entry.
        has_data = any(len(d) > 0 and any(x != 0 for x in d) for d in distributions)

        if has_data:
            max_len = max(len(d) for d in distributions)
            # Rows are subproblems, columns are iterations (in sorted order);
            # shorter distributions are zero-padded.
            matrix = np.zeros((max_len, len(iterations)))
            for col, dist in enumerate(distributions):
                if len(dist):
                    matrix[:len(dist), col] = dist

            im = ax.imshow(matrix, aspect='auto', cmap='RdBu_r',
                           vmin=-1, vmax=1, interpolation='nearest')
            ax.set_xlabel('Iteration')
            ax.set_ylabel('Subproblem')
            ax.set_title('Improvement Distribution')
            fig.colorbar(im, ax=ax, label='Contribution')
        else:
            ax.text(0.5, 0.5, 'No distribution data\n(improvements may be zero)',
                    ha='center', va='center', transform=ax.transAxes, fontsize=11)
            ax.set_title('Improvement Distribution')

        fig.tight_layout()
        if save:
            self._save_figure(fig, 'sequential_learning.png')

    def plot_coordination_conflicts(self, save: bool = True):
        """Plot dual norm, value differentiation, disagreements and coordination score.

        Prints a message and returns without creating a figure when the
        tracker has no conflict data.

        Args:
            save: When true, write ``coordination_conflicts.png`` to ``save_dir``.
        """
        metrics = self.tracker.conflict_metrics
        iterations = sorted(metrics)
        if not iterations:
            print("No conflict data to plot")
            return

        dual_norm = self._required(metrics, iterations, 'dual_var_norm')
        value_diff = self._required(metrics, iterations, 'value_differentiation')
        disagreements = self._required(metrics, iterations, 'actual_disagreements')
        num_overlaps = self._optional(metrics, iterations, 'num_overlaps')
        coord_score = self._required(metrics, iterations, 'coordination_score')

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle('Coordination & Conflict Resolution', fontsize=14, fontweight='bold')

        # 1. Dual Variable Norm
        ax = axes[0, 0]
        ax.plot(iterations, dual_norm, color=self.colors['red'], linewidth=2)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Dual Norm')
        ax.set_title('Dual Variable Norm')

        # Use log scale only if the positive values span more than two
        # orders of magnitude.
        positive = [d for d in dual_norm if d > 0]
        if positive and max(dual_norm) / (min(positive) + 1e-15) > 100:
            ax.set_yscale('log')

        # 2. Value Differentiation
        ax = axes[0, 1]
        ax.plot(iterations, value_diff, color=self.colors['purple'], linewidth=2)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Value Differentiation')
        ax.set_title('Value Learning Quality')

        # 3. Disagreements vs Overlaps
        ax = axes[1, 0]
        ax.bar(iterations, disagreements, alpha=0.7, color=self.colors['orange'], label='Disagreements')
        ax.plot(iterations, num_overlaps, '--', color=self.colors['blue'], label='Total Overlaps')
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Count')
        ax.set_title('Disagreements vs Total Overlaps')
        ax.legend()

        # 4. Coordination Score
        ax = axes[1, 1]
        ax.plot(iterations, coord_score, color=self.colors['green'], linewidth=2)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Coordination Score')
        ax.set_title('Coordination Score (higher=better)')
        ax.set_ylim(0, 1.05)

        fig.tight_layout()
        if save:
            self._save_figure(fig, 'coordination_conflicts.png')

    def plot_overview_dashboard(self, save: bool = True):
        """Draw a 3x3 dashboard (adaptivity / learning / coordination rows).

        Each row is plotted against the iterations of its own metric mapping.
        Prints a message and returns without creating a figure only when all
        three mappings are empty. Always ends with ``plt.show()``.

        Args:
            save: When true, write ``overview_dashboard.png`` to ``save_dir``.
        """
        adapt = self.tracker.adaptivity_metrics
        learn = self.tracker.sequential_learning_metrics
        conflict = self.tracker.conflict_metrics
        adapt_iters = sorted(adapt)
        learn_iters = sorted(learn)
        conflict_iters = sorted(conflict)
        if not (adapt_iters or learn_iters or conflict_iters):
            print("No data for dashboard")
            return

        fig = plt.figure(figsize=(15, 10))
        gs = fig.add_gridspec(3, 3, hspace=0.35, wspace=0.3)
        fig.suptitle('Algorithm Performance Dashboard', fontsize=14, fontweight='bold')

        # Row 1: Adaptivity
        ax1 = fig.add_subplot(gs[0, 0])
        entropy = self._required(adapt, adapt_iters, 'avg_entropy')
        ax1.plot(adapt_iters, entropy, color=self.colors['blue'], linewidth=1.5)
        ax1.set_title('Entropy', fontweight='bold')
        ax1.set_ylabel('Value')

        ax2 = fig.add_subplot(gs[0, 1])
        confidence = self._required(adapt, adapt_iters, 'decision_confidence')
        ax2.plot(adapt_iters, confidence, color=self.colors['green'], linewidth=1.5)
        ax2.set_title('Confidence', fontweight='bold')
        ax2.set_ylim(0, 1.05)

        ax3 = fig.add_subplot(gs[0, 2])
        uncertainty = self._required(adapt, adapt_iters, 'uncertainty_level')
        ax3.plot(adapt_iters, uncertainty, color=self.colors['orange'], linewidth=1.5)
        ax3.set_title('Uncertainty', fontweight='bold')
        ax3.set_ylim(0, 1.05)

        # Row 2: Learning
        ax4 = fig.add_subplot(gs[1, 0])
        best_reward = self._required(learn, learn_iters, 'cumulative_best_reward')
        ax4.plot(learn_iters, best_reward, color=self.colors['purple'], linewidth=1.5)
        ax4.set_title('Best Reward', fontweight='bold')
        ax4.set_ylabel('Value')

        ax5 = fig.add_subplot(gs[1, 1])
        iter_imp = self._optional(learn, learn_iters, 'total_iteration_improvement')
        ax5.plot(learn_iters, iter_imp, color=self.colors['blue'], linewidth=1.5)
        ax5.axhline(y=0, color='red', linestyle='--', alpha=0.5)
        ax5.set_title('Iteration Improvement', fontweight='bold')

        ax6 = fig.add_subplot(gs[1, 2])
        avg_imp = self._optional(learn, learn_iters, 'avg_improvement_per_subproblem')
        ax6.plot(learn_iters, avg_imp, color=self.colors['green'], linewidth=1.5)
        ax6.axhline(y=0, color='red', linestyle='--', alpha=0.5)
        ax6.set_title('Avg Improvement/Subproblem', fontweight='bold')

        # Row 3: Coordination
        ax7 = fig.add_subplot(gs[2, 0])
        dual_norm = self._required(conflict, conflict_iters, 'dual_var_norm')
        ax7.plot(conflict_iters, dual_norm, color=self.colors['red'], linewidth=1.5)
        ax7.set_title('Dual Norm', fontweight='bold')
        ax7.set_ylabel('Value')
        ax7.set_xlabel('Iteration')

        ax8 = fig.add_subplot(gs[2, 1])
        coord_score = self._required(conflict, conflict_iters, 'coordination_score')
        ax8.plot(conflict_iters, coord_score, color=self.colors['green'], linewidth=1.5)
        ax8.set_title('Coordination Score', fontweight='bold')
        ax8.set_xlabel('Iteration')
        ax8.set_ylim(0, 1.05)

        ax9 = fig.add_subplot(gs[2, 2])
        disagree_rate = self._optional(conflict, conflict_iters, 'disagreement_rate')
        ax9.plot(conflict_iters, disagree_rate, color=self.colors['orange'], linewidth=1.5)
        ax9.set_title('Disagreement Rate', fontweight='bold')
        ax9.set_xlabel('Iteration')
        # Keep at least a 0..0.1 range so an all-zero series is still readable.
        ax9.set_ylim(0, max(0.1, max(disagree_rate) * 1.1) if disagree_rate else 1)

        if save:
            self._save_figure(fig, 'overview_dashboard.png')
        # plt.show() displays every figure that is still open, which includes
        # the figures the other plot_* methods created and left open.
        plt.show()

    def export_data_to_csv(self):
        """Write one CSV per non-empty metric mapping into ``save_dir``.

        Files written: ``adaptivity_metrics.csv``, ``sequential_learning.csv``
        and ``coordination_conflicts.csv``. Each file lists the iterations of
        its own mapping; an empty mapping produces no file.
        """
        print(f"Exporting data to {self.save_dir}/...")

        # Adaptivity
        adapt = self.tracker.adaptivity_metrics
        if adapt:
            iters = sorted(adapt)
            df = pd.DataFrame({
                'iteration': iters,
                'avg_entropy': self._required(adapt, iters, 'avg_entropy'),
                'max_probability': self._required(adapt, iters, 'max_probability'),
                'uncertainty_level': self._required(adapt, iters, 'uncertainty_level'),
                'decision_confidence': self._required(adapt, iters, 'decision_confidence'),
            })
            df.to_csv(self.save_dir / 'adaptivity_metrics.csv', index=False)

        # Learning
        learn = self.tracker.sequential_learning_metrics
        if learn:
            iters = sorted(learn)
            df = pd.DataFrame({
                'iteration': iters,
                'cumulative_best_reward': self._required(learn, iters, 'cumulative_best_reward'),
                'total_iteration_improvement': self._optional(learn, iters, 'total_iteration_improvement'),
                'avg_improvement_per_subproblem': self._optional(learn, iters, 'avg_improvement_per_subproblem'),
                'subproblems_improved': self._required(learn, iters, 'subproblems_improved'),
                'total_subproblems': self._required(learn, iters, 'total_subproblems'),
            })
            df.to_csv(self.save_dir / 'sequential_learning.csv', index=False)

        # Conflicts
        conflict = self.tracker.conflict_metrics
        if conflict:
            iters = sorted(conflict)
            df = pd.DataFrame({
                'iteration': iters,
                'dual_var_norm': self._required(conflict, iters, 'dual_var_norm'),
                'value_differentiation': self._required(conflict, iters, 'value_differentiation'),
                'coordination_score': self._required(conflict, iters, 'coordination_score'),
                'actual_disagreements': self._required(conflict, iters, 'actual_disagreements'),
                'disagreement_rate': self._optional(conflict, iters, 'disagreement_rate'),
                'num_overlaps': self._optional(conflict, iters, 'num_overlaps'),
            })
            df.to_csv(self.save_dir / 'coordination_conflicts.csv', index=False)

        print("  CSV files exported")

    def print_report(self):
        """Print linear-trend statistics for confidence, reward and coordination.

        Each section regresses its series on the iterations of its own metric
        mapping and is printed only when that series has more than two points.
        Prints a message and returns when all three mappings are empty.
        """
        adapt = self.tracker.adaptivity_metrics
        learn = self.tracker.sequential_learning_metrics
        conflict = self.tracker.conflict_metrics
        adapt_iters = sorted(adapt)
        learn_iters = sorted(learn)
        conflict_iters = sorted(conflict)
        if not (adapt_iters or learn_iters or conflict_iters):
            print("No data for report")
            return

        print("\n" + "="*70)
        print(" STATISTICAL ANALYSIS REPORT ".center(70))
        print("="*70)

        # Adaptivity
        confidence = self._required(adapt, adapt_iters, 'decision_confidence')
        if len(confidence) > 2:
            slope, _, r_value, p_value, _ = stats.linregress(adapt_iters, confidence)
            print(f"\n1. ADAPTIVITY:")
            print(f"   Confidence trend: slope={slope:.6f}, R²={r_value**2:.4f}, p={p_value:.6f}")
            print(f"   → {'✓ Significant improvement' if slope > 0 and p_value < 0.05 else '○ No significant change'}")

        # Learning
        best_reward = self._required(learn, learn_iters, 'cumulative_best_reward')
        iter_imp = self._optional(learn, learn_iters, 'total_iteration_improvement')

        if len(best_reward) > 2:
            slope, _, r_value, p_value, _ = stats.linregress(learn_iters, best_reward)
            positive_rate = sum(1 for x in iter_imp if x > 0) / len(iter_imp)
            print(f"\n2. LEARNING:")
            print(f"   Reward trend: slope={slope:.6f}, R²={r_value**2:.4f}, p={p_value:.6f}")
            print(f"   Positive improvement rate: {positive_rate*100:.1f}%")
            print(f"   → {'✓ Consistent improvement' if positive_rate > 0.5 else '○ Inconsistent'}")

        # Coordination
        coord_score = self._required(conflict, conflict_iters, 'coordination_score')
        if len(coord_score) > 2:
            slope, _, r_value, p_value, _ = stats.linregress(conflict_iters, coord_score)
            print(f"\n3. COORDINATION:")
            print(f"   Score trend: slope={slope:.6f}, R²={r_value**2:.4f}, p={p_value:.6f}")
            print(f"   Final score: {coord_score[-1]:.4f}")

        print("\n" + "="*70 + "\n")


def quick_analysis(tracker, save_plots: bool = True, show_report: bool = True,
                   save_dir: str = "plots"):
    """Run plotting, the optional text report and the CSV export in one call.

    Args:
        tracker: Metrics tracker handed to ``MetricsVisualizer``.
        save_plots: Forwarded as ``save`` to ``plot_all_metrics``.
        show_report: When true, ``print_report`` is called.
        save_dir: Output directory handed to ``MetricsVisualizer``; defaults
            to ``"plots"``, the visualizer's own default.

    Returns:
        The ``MetricsVisualizer`` instance that was used.
    """
    viz = MetricsVisualizer(tracker, save_dir=save_dir)
    viz.plot_all_metrics(save=save_plots)
    if show_report:
        viz.print_report()
    # The CSV export runs regardless of save_plots / show_report.
    viz.export_data_to_csv()
    return viz