import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches
import warnings

# Install a global "ignore" filter so no warnings are shown while figures are built.
warnings.filterwarnings('ignore')

# Academic style configuration, applied in groups by concern.

# Typography: serif text with explicit sizes for each text element.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif', 'serif'],
    'font.size': 10,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 16,
    'text.usetex': False,  # Set to True if LaTeX is available
})

# Axes: faint grid drawn underneath the plotted data.
plt.rcParams.update({
    'axes.grid': True,
    'grid.alpha': 0.3,
    'axes.axisbelow': True,
})

# Output: 300 DPI on screen and on disk, with tight bounding boxes when saving.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
})

# Keys shared by both palettes; each palette lists its colors in this order
# (blue, orange, green).
_SPLIT_COLOR_KEYS = ('geom-gcn', 'public', 'random')

# Academic color palette
ACADEMIC_COLORS = dict(zip(_SPLIT_COLOR_KEYS, ('#1f77b4', '#ff7f0e', '#2ca02c')))

# Alternative colorblind-friendly palette
COLORBLIND_PALETTE = dict(zip(_SPLIT_COLOR_KEYS, ('#0173B2', '#DE8F05', '#029E73')))


class AcademicKStepsVisualizer:
    def __init__(self, use_colorblind_palette=False):
        """Set up styling tables, the ablation results and per-split look-up maps.

        Args:
            use_colorblind_palette: When True, split colors are taken from
                ``COLORBLIND_PALETTE``; otherwise from ``ACADEMIC_COLORS``.

        Raises:
            KeyError: If the selected palette has no entry for a split's
                lower-cased name.
        """

        # Font configuration - every font size used by the plotting methods
        self.FONT_CONFIG = {
            'title_main': 18,  # Main figure titles
            'title_subplot': 15,  # Subplot titles
            'suptitle': 18,  # Figure super titles
            'axis_label': 16,  # X and Y axis labels
            'tick_label': 16,  # Axis tick labels
            'legend': 16,  # Legend text
            'annotation': 14,  # Point annotations
            'annotation_small': 12,  # Small annotations
            'heatmap_annot': 14,  # Heatmap cell annotations
            'colorbar': 14  # Colorbar labels
        }

        # Plot styling configuration
        self.PLOT_CONFIG = {
            'line_width': 4,  # Main line width
            'line_width_small': 2,  # Smaller line width for subplots
            'marker_size': 8,  # Main marker size
            'marker_size_small': 6,  # Smaller marker size for subplots
            'marker_edge_width': 1,  # Marker edge width
            'marker_edge_width_small': 0.5,  # Smaller marker edge width
            'grid_alpha': 0.3,  # Grid transparency
            'spine_width': 0.8,  # Plot border width
            'tick_length': 4,  # Tick mark length
            'tick_length_small': 3  # Smaller tick mark length
        }

        self.steps = [1, 2, 3, 4, 5, 7, 10, 15, 20, 25]

        # Results table: one tuple per run, in the column order given by `fields`.
        # Data organized by split (renamed: fixed -> geom-gcn, full -> random, public -> public)
        fields = ('steps', 'val_loss', 'val_f1', 'test_accuracy', 'test_f1', 'duration')
        results = {
            'Geom-GCN': [
                (1, 1.2603, 85.33, 84.71, 83.87, 81.870),
                (2, 1.2639, 86.56, 84.53, 83.36, 78.367),
                (3, 1.2619, 85.53, 85.82, 84.64, 61.508),
                (4, 1.2601, 85.54, 84.16, 83.18, 67.351),
                (5, 1.2603, 84.54, 83.61, 82.32, 71.850),
                (7, 1.2618, 86.44, 85.45, 84.77, 85.947),
                (10, 1.2588, 87.44, 84.90, 83.85, 101.341),
                (15, 1.2615, 86.18, 83.06, 82.00, 130.009),
                (20, 1.2621, 86.02, 84.71, 83.38, 158.724),
                (25, 1.2651, 86.52, 84.53, 83.13, 188.220),
            ],
            'Public': [
                (1, 1.3724, 77.98, 78.30, 77.16, 52.432),
                (2, 1.3613, 77.91, 77.80, 76.64, 62.481),
                (3, 1.3728, 75.45, 79.30, 78.11, 62.681),
                (4, 1.3763, 76.62, 79.10, 78.57, 66.106),
                (5, 1.3683, 78.29, 81.30, 79.93, 71.591),
                (7, 1.3806, 75.72, 77.40, 76.71, 81.074),
                (10, 1.3728, 76.50, 78.20, 77.18, 98.142),
                (15, 1.3712, 74.74, 79.80, 78.55, 126.408),
                (20, 1.3727, 76.33, 76.10, 75.72, 153.507),
                (25, 1.3724, 76.58, 77.70, 76.27, 198.301),
            ],
            'Random': [
                (1, 1.3151, 72.90, 78.80, 70.92, 78.349),
                (2, 1.2668, 86.47, 85.90, 84.85, 63.251),
                (3, 1.2695, 87.33, 84.90, 83.84, 63.726),
                (4, 1.2687, 87.01, 84.90, 83.76, 67.727),
                (5, 1.2676, 85.73, 86.30, 85.10, 72.184),
                (7, 1.3141, 74.51, 80.30, 73.13, 81.419),
                (10, 1.2700, 87.26, 85.30, 83.80, 97.647),
                (15, 1.2687, 85.21, 85.30, 84.09, 125.266),
                (20, 1.2682, 85.55, 85.30, 84.13, 152.736),
                (25, 1.2661, 86.03, 84.70, 83.35, 183.008),
            ],
        }

        # Expand every tuple into a dict keyed by field name, grouped by split.
        self.data = {
            split: [dict(zip(fields, run)) for run in runs]
            for split, runs in results.items()
        }

        # Convert to DataFrame for easier manipulation
        self.df = self._create_dataframe()

        # Color scheme selection
        palette = COLORBLIND_PALETTE if use_colorblind_palette else ACADEMIC_COLORS

        # Palette keys are the lower-cased split names; look each color up by
        # name so the pairing does not depend on dict ordering.
        self.colors = {split: palette[split.lower()] for split in self.data}

        # Line styles and markers for better distinction
        self.line_styles = {
            'Geom-GCN': '-',
            'Public': '--',
            'Random': ':'
        }

        self.markers = {
            'Geom-GCN': 'o',
            'Public': 's',
            'Random': '^'
        }

    def _create_dataframe(self):
        """Convert the nested data structure to a pandas DataFrame"""
        rows = []
        for split, split_data in self.data.items():
            for row in split_data:
                row['split'] = split
                rows.append(row)
        return pd.DataFrame(rows)

    def plot_performance_vs_steps(self, metric='test_f1', splits=None, figsize=(8, 6),
                                  show_confidence=False, add_annotations=False):
        """
        Draw one styled line per split showing how a metric varies with k.

        Args:
            metric: Column of ``self.df`` for the y-axis ('test_f1', 'test_accuracy',
                'val_f1', 'val_loss', 'duration'); any other name gets an axis label
                and title derived from the column name
            splits: Split names to draw (default: every key of ``self.data``)
            figsize: Figure size tuple (width, height) in inches
            show_confidence: Accepted for interface compatibility; not used by this method
            add_annotations: Whether to write each point's value (one decimal) above it

        Returns:
            Tuple ``(fig, ax)``: the created figure and its single axes
        """
        chosen_splits = list(self.data.keys()) if splits is None else splits
        style = self.PLOT_CONFIG
        fonts = self.FONT_CONFIG

        fig, ax = plt.subplots(figsize=figsize, facecolor='white')
        ax.set_facecolor('white')

        for name in chosen_splits:
            subset = self.df[self.df['split'] == name]
            xs, ys = subset['steps'], subset[metric]
            colour = self.colors[name]

            ax.plot(xs, ys,
                    color=colour,
                    linestyle=self.line_styles[name],
                    marker=self.markers[name],
                    linewidth=style['line_width'],
                    markersize=style['marker_size'],
                    markerfacecolor=colour,
                    markeredgecolor='white',
                    markeredgewidth=style['marker_edge_width'],
                    label=name,
                    alpha=0.8)

            if not add_annotations:
                continue

            # Value label centred 10 points above every marker.
            for _, point in subset.iterrows():
                value = point[metric]
                ax.annotate(f'{value:.1f}',
                            (point['steps'], value),
                            xytext=(0, 10), textcoords='offset points',
                            ha='center', va='bottom',
                            fontsize=fonts['annotation'], alpha=0.7)

        # (y-axis label, title) for the known metrics; unknown metrics fall
        # back to text generated from the column name.
        known_text = {
            'test_f1': ('Test F1-Score (%)',
                        'Effect of K-Steps on Test F1-Score Performance'),
            'test_accuracy': ('Test Accuracy (%)',
                              'Effect of K-Steps on Test Accuracy'),
            'val_f1': ('Validation F1-Score (%)',
                       'Effect of K-Steps on Validation F1-Score'),
            'val_loss': ('Validation Loss',
                         'Effect of K-Steps on Validation Loss'),
            'duration': ('Training Duration (seconds)',
                         'Effect of K-Steps on Training Duration'),
        }
        readable = metric.replace('_', ' ').title()
        y_text, title_text = known_text.get(metric, (readable, f'K-Steps vs {readable}'))

        ax.set_xlabel('Number of Steps (k)', fontweight='bold',
                      fontsize=fonts['axis_label'])
        ax.set_ylabel(y_text, fontweight='bold', fontsize=fonts['axis_label'])
        ax.set_title(title_text, fontweight='bold', pad=15,
                     fontsize=fonts['title_main'])

        # Framed legend on a white background.
        legend = ax.legend(loc='best', frameon=True, fancybox=True, shadow=True,
                           framealpha=0.9, edgecolor='gray',
                           fontsize=fonts['legend'])
        legend.get_frame().set_facecolor('white')

        # Light grid kept behind the data.
        ax.grid(True, alpha=style['grid_alpha'], linestyle='-', linewidth=0.5)
        ax.set_axisbelow(True)

        # Thin black frame around the plot area.
        for border in ax.spines.values():
            border.set_color('black')
            border.set_linewidth(style['spine_width'])

        # Inward ticks; one tick for every k value present in the data.
        ax.tick_params(colors='black', direction='in',
                       length=style['tick_length'],
                       labelsize=fonts['tick_label'])
        ax.set_xticks(sorted(self.df['steps'].unique()))

        plt.tight_layout()
        return fig, ax

    def plot_efficiency_analysis(self, figsize=(8, 6)):
        """
        Scatter test F1 against training duration, one marker series per split.

        Every point carries a small ``k=<steps>`` tag tinted with its split's color.

        Args:
            figsize: Figure size tuple (width, height) in inches

        Returns:
            Tuple ``(fig, ax)``: the created figure and its single axes
        """
        fonts = self.FONT_CONFIG
        style = self.PLOT_CONFIG

        fig, ax = plt.subplots(figsize=figsize, facecolor='white')
        ax.set_facecolor('white')

        for name in self.data:
            subset = self.df[self.df['split'] == name]
            colour = self.colors[name]

            ax.scatter(subset['duration'], subset['test_f1'],
                       c=colour,
                       marker=self.markers[name],
                       s=80, alpha=0.7,
                       edgecolors='white',
                       linewidth=1,
                       label=name)

            # Tag each point with the k that produced it.
            for _, point in subset.iterrows():
                tag_box = {'boxstyle': 'round,pad=0.2', 'facecolor': colour, 'alpha': 0.3}
                ax.annotate(f'k={int(point["steps"])}',
                            (point['duration'], point['test_f1']),
                            xytext=(5, 5), textcoords='offset points',
                            fontsize=fonts['annotation'], alpha=0.8,
                            bbox=tag_box)

        # Axis labels and title.
        bold_label = {'fontweight': 'bold', 'fontsize': fonts['axis_label']}
        ax.set_xlabel('Training Duration (seconds)', **bold_label)
        ax.set_ylabel('Test F1-Score (%)', **bold_label)
        ax.set_title('Performance-Efficiency Trade-off Analysis', fontweight='bold',
                     pad=15, fontsize=fonts['title_main'])

        # Framed legend on a white background.
        legend = ax.legend(loc='best', frameon=True, fancybox=True, shadow=True,
                           framealpha=0.9, edgecolor='gray',
                           fontsize=fonts['legend'])
        legend.get_frame().set_facecolor('white')

        # Light grid kept behind the data.
        ax.grid(True, alpha=style['grid_alpha'], linestyle='-', linewidth=0.5)
        ax.set_axisbelow(True)

        # Thin black frame and inward ticks.
        for border in ax.spines.values():
            border.set_color('black')
            border.set_linewidth(style['spine_width'])

        ax.tick_params(colors='black', direction='in',
                       length=style['tick_length'],
                       labelsize=fonts['tick_label'])

        plt.tight_layout()
        return fig, ax

    def plot_comprehensive_analysis(self, figsize=(12, 10)):
        """
        Build a 2x2 grid of metric-vs-k line plots under one super title.

        Panels, row by row: test F1, test accuracy, validation loss and training
        duration. Only the first panel carries a legend.

        Args:
            figsize: Figure size tuple (width, height) in inches

        Returns:
            The created figure (the axes are not returned)
        """
        fonts, style = self.FONT_CONFIG, self.PLOT_CONFIG

        fig, axes = plt.subplots(2, 2, figsize=figsize, facecolor='white')
        fig.suptitle('K-Steps Ablation Study: Comprehensive Performance Analysis on Cora Dataset',
                     fontsize=fonts['suptitle'], fontweight='bold', y=0.99)

        # (metric column, y-axis label), listed in the row-major order of axes.flat.
        panels = [
            ('test_f1', 'Test F1-Score (%)'),
            ('test_accuracy', 'Test Accuracy (%)'),
            ('val_loss', 'Validation Loss'),
            ('duration', 'Training Duration (s)'),
        ]
        tick_positions = sorted(self.df['steps'].unique())

        for position, (ax, (metric, ylabel)) in enumerate(zip(axes.flat, panels)):
            ax.set_facecolor('white')

            # One line per split.
            for name in self.data:
                subset = self.df[self.df['split'] == name]
                colour = self.colors[name]

                ax.plot(subset['steps'], subset[metric],
                        color=colour,
                        linestyle=self.line_styles[name],
                        marker=self.markers[name],
                        linewidth=style['line_width_small'],
                        markersize=style['marker_size_small'],
                        markerfacecolor=colour,
                        markeredgecolor='white',
                        markeredgewidth=style['marker_edge_width_small'],
                        label=name,
                        alpha=0.8)

            # Labels and panel title.
            ax.set_xlabel('Number of Steps (k)', fontsize=fonts['axis_label'],
                          fontweight='bold')
            ax.set_ylabel(ylabel, fontsize=fonts['axis_label'], fontweight='bold')
            panel_title = f'{metric.replace("_", " ").title()} Performance'
            ax.set_title(panel_title, fontsize=fonts['title_subplot'], fontweight='bold')

            # The legend is drawn once, on the top-left panel.
            if position == 0:
                legend = ax.legend(loc='best', fontsize=fonts['legend'],
                                   frameon=True, fancybox=True, shadow=True, framealpha=0.9)
                legend.get_frame().set_facecolor('white')

            # Light grid kept behind the data.
            ax.grid(True, alpha=style['grid_alpha'], linestyle='-', linewidth=0.5)
            ax.set_axisbelow(True)

            # Thin black frame and inward ticks at every k value.
            for border in ax.spines.values():
                border.set_color('black')
                border.set_linewidth(style['spine_width'])

            ax.tick_params(colors='black', direction='in',
                           length=style['tick_length_small'],
                           labelsize=fonts['tick_label'])
            ax.set_xticks(tick_positions)

        plt.tight_layout()
        return fig

    def plot_performance_heatmap(self, metric='test_f1', figsize=(10, 4)):
        """
        Draw an annotated heatmap of one metric with splits as rows and k as columns.

        Args:
            metric: Column of ``self.df`` whose values fill the cells
            figsize: Figure size tuple (width, height) in inches

        Returns:
            Tuple ``(fig, ax)``: the created figure and the heatmap axes
        """
        fonts = self.FONT_CONFIG
        metric_title = metric.replace("_", " ").title()

        # Reshape to a split x steps grid of metric values.
        grid = self.df.pivot(index='split', columns='steps', values=metric)

        fig, ax = plt.subplots(figsize=figsize, facecolor='white')

        heatmap_options = {
            'annot': True,
            'fmt': '.2f',
            'cmap': 'RdYlBu_r',
            'center': None,
            'square': False,
            'linewidths': 0.5,
            'cbar_kws': {'label': metric_title},
            'annot_kws': {'fontsize': fonts['heatmap_annot']},
            'ax': ax,
        }
        heatmap = sns.heatmap(grid, **heatmap_options)

        # Title and axis labels.
        ax.set_title(f'{metric_title} Performance Across Different K-Steps',
                     fontweight='bold', pad=15, fontsize=fonts['title_main'])
        bold_label = {'fontweight': 'bold', 'fontsize': fonts['axis_label']}
        ax.set_xlabel('Number of Steps (k)', **bold_label)
        ax.set_ylabel('Data Split', **bold_label)

        # Colorbar tick size, then the tick size of the heatmap itself.
        colorbar = heatmap.collections[0].colorbar
        colorbar.ax.tick_params(labelsize=fonts['colorbar'])
        ax.tick_params(labelsize=fonts['tick_label'])

        plt.tight_layout()
        return fig, ax

    def generate_academic_report(self):
        """
        Generate a comprehensive academic-style report
        """
        print("=" * 90)
        print("K-STEPS ABLATION STUDY - ACADEMIC ANALYSIS REPORT")
        print("=" * 90)
        print("Dataset: Cora Citation Network")
        print(f"Node splits analyzed: {', '.join(self.data.keys())}")
        print(f"K-steps evaluated: {', '.join(map(str, self.steps))}")
        print(f"Total experiments: {len(self.df)} configurations")
        print("\n" + "=" * 90)

        # Statistical summary
        print("\nSTATISTICAL SUMMARY:")
        print("-" * 50)

        metrics = ['test_f1', 'test_accuracy', 'val_f1', 'val_loss', 'duration']

        summary_stats = []
        for split in self.data.keys():
            split_data = self.df[self.df['split'] == split]

            print(f"\n{split} Split:")
            print("  Metric                 Best Value    Optimal k    Mean ± Std")
            print("  " + "-" * 60)

            for metric in metrics:
                values = split_data[metric].values

                if metric == 'val_loss':
                    best_idx = np.argmin(values)
                    best_value = values[best_idx]
                else:
                    best_idx = np.argmax(values)
                    best_value = values[best_idx]

                best_k = split_data.iloc[best_idx]['steps']
                mean_val = np.mean(values)
                std_val = np.std(values)

                unit = "%" if metric in ['test_f1', 'test_accuracy', 'val_f1'] else ""
                if metric == 'duration':
                    unit = "s"

                print(f"  {metric.replace('_', ' ').title():<20} "
                      f"{best_value:>8.2f}{unit:<2} "
                      f"{int(best_k):>8}      "
                      f"{mean_val:>6.2f} ± {std_val:.2f}{unit}")

        # Performance ranking
        print("\n\nPERFORMANCE RANKING (by Test F1-Score):")
        print("-" * 50)

        performance_summary = []
        for split in self.data.keys():
            split_data = self.df[self.df['split'] == split]
            best_f1 = split_data['test_f1'].max()
            best_k = split_data.loc[split_data['test_f1'].idxmax(), 'steps']
            performance_summary.append((split, best_f1, best_k))

        # Sort by performance
        performance_summary.sort(key=lambda x: x[1], reverse=True)

        for rank, (split, f1, k) in enumerate(performance_summary, 1):
            print(f"{rank}. {split:<12}: {f1:.2f}% F1-Score (k={int(k)})")

        # Efficiency analysis
        print("\n\nEFFICIENCY ANALYSIS:")
        print("-" * 50)
        print("Best performance-to-time ratio for each split:")

        for split in self.data.keys():
            split_data = self.df[self.df['split'] == split]
            # Calculate efficiency as F1/duration ratio
            split_data_copy = split_data.copy()
            split_data_copy['efficiency'] = split_data_copy['test_f1'] / split_data_copy['duration']
            best_eff_idx = split_data_copy['efficiency'].idxmax()
            best_eff_row = split_data_copy.loc[best_eff_idx]

            print(f"  {split:<12}: k={int(best_eff_row['steps']):<2} "
                  f"({best_eff_row['test_f1']:.2f}% F1, {best_eff_row['duration']:.1f}s, "
                  f"ratio={best_eff_row['efficiency']:.3f})")

        print("\n" + "=" * 90)

    def save_publication_figures(self, output_dir='./vis/ablation', formats=('pdf', 'png'), dpi=300):
        """
        Render the four standard figures and save each one in every requested format.

        The figures are, in order: test-F1 vs k, the efficiency scatter, the
        2x2 comprehensive panel and the test-F1 heatmap.

        Args:
            output_dir: Directory the files are written to; created if missing
            formats: File extensions to save under; a single string is treated
                as one format rather than iterated character by character
            dpi: Resolution passed to ``savefig``

        Returns:
            List of the file paths written, in the order they were saved
        """
        import os

        # Normalise once: wrap a bare string, and materialise any other
        # iterable so it can be traversed once per figure.
        formats = [formats] if isinstance(formats, str) else list(formats)

        os.makedirs(output_dir, exist_ok=True)

        print(f"Generating publication-ready figures...")
        print(f"Output directory: {output_dir}")
        print(f"Formats: {', '.join(formats)}")
        print(f"Resolution: {dpi} DPI")

        # (progress label, file stem, factory returning the Figure to save).
        # The factories run lazily so each figure is built right before saving.
        figure_specs = [
            ('main performance figure', 'k_steps_performance_analysis',
             lambda: self.plot_performance_vs_steps(metric='test_f1', figsize=(8, 6))[0]),
            ('efficiency analysis figure', 'k_steps_efficiency_analysis',
             lambda: self.plot_efficiency_analysis(figsize=(8, 6))[0]),
            ('comprehensive analysis figure', 'k_steps_comprehensive_analysis',
             lambda: self.plot_comprehensive_analysis(figsize=(12, 10))),
            ('performance heatmap', 'k_steps_performance_heatmap',
             lambda: self.plot_performance_heatmap(metric='test_f1', figsize=(10, 4))[0]),
        ]

        figures_generated = []
        for label, filename, build_figure in figure_specs:
            print(f"  • Generating {label}...")
            fig = build_figure()
            try:
                for fmt in formats:
                    filepath = f"{output_dir}/{filename}.{fmt}"
                    fig.savefig(filepath, dpi=dpi, bbox_inches='tight',
                                facecolor='white', edgecolor='none')
                    figures_generated.append(filepath)
            finally:
                # Close the figure even if a save fails, so it is not left open.
                plt.close(fig)

        print(f"\nSuccessfully generated {len(figures_generated)} figure files:")
        for filepath in figures_generated:
            print(f"  • {filepath}")

        return figures_generated



class AcademicKStepsVisualizer2:
    """Academic-style plots and a text report for the k-steps ablation study.

    The measurements are hard-coded in ``__init__`` for three node splits
    ('Geom-GCN', 'Public', 'Random'). Every plotting and reporting method
    reads them through ``self.df``, a flat DataFrame with one row per
    (split, steps) measurement.
    """

    # Metrics for which the smallest value is the best one; all other
    # metrics are treated as higher-is-better in the report.
    _LOWER_IS_BETTER = frozenset({'val_loss', 'duration'})

    def __init__(self, use_colorblind_palette=False):
        """Initialize the academic visualizer with ablation study data

        Args:
            use_colorblind_palette: If True, take the split colors from
                ``COLORBLIND_PALETTE``; otherwise from ``ACADEMIC_COLORS``.
        """
        self.steps = [1, 2, 3, 4, 5, 7, 10, 15, 20, 25]

        # Data organized by split (renamed: fixed -> geom-gcn, full -> random, public -> public)
        self.data = {
            'Geom-GCN': [
                {'steps': 1, 'val_loss': 1.2603, 'val_f1': 85.33, 'test_accuracy': 84.71, 'test_f1': 83.87,
                 'duration': 81.870},
                {'steps': 2, 'val_loss': 1.2639, 'val_f1': 86.56, 'test_accuracy': 84.53, 'test_f1': 83.36,
                 'duration': 78.367},
                {'steps': 3, 'val_loss': 1.2619, 'val_f1': 85.53, 'test_accuracy': 85.82, 'test_f1': 84.64,
                 'duration': 61.508},
                {'steps': 4, 'val_loss': 1.2601, 'val_f1': 85.54, 'test_accuracy': 84.16, 'test_f1': 83.18,
                 'duration': 67.351},
                {'steps': 5, 'val_loss': 1.2603, 'val_f1': 84.54, 'test_accuracy': 83.61, 'test_f1': 82.32,
                 'duration': 71.850},
                {'steps': 7, 'val_loss': 1.2618, 'val_f1': 86.44, 'test_accuracy': 85.45, 'test_f1': 84.77,
                 'duration': 85.947},
                {'steps': 10, 'val_loss': 1.2588, 'val_f1': 87.44, 'test_accuracy': 84.90, 'test_f1': 83.85,
                 'duration': 101.341},
                {'steps': 15, 'val_loss': 1.2615, 'val_f1': 86.18, 'test_accuracy': 83.06, 'test_f1': 82.00,
                 'duration': 130.009},
                {'steps': 20, 'val_loss': 1.2621, 'val_f1': 86.02, 'test_accuracy': 84.71, 'test_f1': 83.38,
                 'duration': 158.724},
                {'steps': 25, 'val_loss': 1.2651, 'val_f1': 86.52, 'test_accuracy': 84.53, 'test_f1': 83.13,
                 'duration': 188.220}
            ],
            'Public': [
                {'steps': 1, 'val_loss': 1.3724, 'val_f1': 77.98, 'test_accuracy': 78.30, 'test_f1': 77.16,
                 'duration': 52.432},
                {'steps': 2, 'val_loss': 1.3613, 'val_f1': 77.91, 'test_accuracy': 77.80, 'test_f1': 76.64,
                 'duration': 62.481},
                {'steps': 3, 'val_loss': 1.3728, 'val_f1': 75.45, 'test_accuracy': 79.30, 'test_f1': 78.11,
                 'duration': 62.681},
                {'steps': 4, 'val_loss': 1.3763, 'val_f1': 76.62, 'test_accuracy': 79.10, 'test_f1': 78.57,
                 'duration': 66.106},
                {'steps': 5, 'val_loss': 1.3683, 'val_f1': 78.29, 'test_accuracy': 81.30, 'test_f1': 79.93,
                 'duration': 71.591},
                {'steps': 7, 'val_loss': 1.3806, 'val_f1': 75.72, 'test_accuracy': 77.40, 'test_f1': 76.71,
                 'duration': 81.074},
                {'steps': 10, 'val_loss': 1.3728, 'val_f1': 76.50, 'test_accuracy': 78.20, 'test_f1': 77.18,
                 'duration': 98.142},
                {'steps': 15, 'val_loss': 1.3712, 'val_f1': 74.74, 'test_accuracy': 79.80, 'test_f1': 78.55,
                 'duration': 126.408},
                {'steps': 20, 'val_loss': 1.3727, 'val_f1': 76.33, 'test_accuracy': 76.10, 'test_f1': 75.72,
                 'duration': 153.507},
                {'steps': 25, 'val_loss': 1.3724, 'val_f1': 76.58, 'test_accuracy': 77.70, 'test_f1': 76.27,
                 'duration': 198.301}
            ],
            'Random': [
                {'steps': 1, 'val_loss': 1.3151, 'val_f1': 72.90, 'test_accuracy': 78.80, 'test_f1': 70.92,
                 'duration': 78.349},
                {'steps': 2, 'val_loss': 1.2668, 'val_f1': 86.47, 'test_accuracy': 85.90, 'test_f1': 84.85,
                 'duration': 63.251},
                {'steps': 3, 'val_loss': 1.2695, 'val_f1': 87.33, 'test_accuracy': 84.90, 'test_f1': 83.84,
                 'duration': 63.726},
                {'steps': 4, 'val_loss': 1.2687, 'val_f1': 87.01, 'test_accuracy': 84.90, 'test_f1': 83.76,
                 'duration': 67.727},
                {'steps': 5, 'val_loss': 1.2676, 'val_f1': 85.73, 'test_accuracy': 86.30, 'test_f1': 85.10,
                 'duration': 72.184},
                {'steps': 7, 'val_loss': 1.3141, 'val_f1': 74.51, 'test_accuracy': 80.30, 'test_f1': 73.13,
                 'duration': 81.419},
                {'steps': 10, 'val_loss': 1.2700, 'val_f1': 87.26, 'test_accuracy': 85.30, 'test_f1': 83.80,
                 'duration': 97.647},
                {'steps': 15, 'val_loss': 1.2687, 'val_f1': 85.21, 'test_accuracy': 85.30, 'test_f1': 84.09,
                 'duration': 125.266},
                {'steps': 20, 'val_loss': 1.2682, 'val_f1': 85.55, 'test_accuracy': 85.30, 'test_f1': 84.13,
                 'duration': 152.736},
                {'steps': 25, 'val_loss': 1.2661, 'val_f1': 86.03, 'test_accuracy': 84.70, 'test_f1': 83.35,
                 'duration': 183.008}
            ]
        }

        # Convert to DataFrame for easier manipulation
        self.df = self._create_dataframe()

        # Color scheme selection
        color_map = COLORBLIND_PALETTE if use_colorblind_palette else ACADEMIC_COLORS

        # The palette keys are the lower-cased split names, so look each
        # color up by name instead of relying on both dicts happening to
        # share the same insertion order.
        self.colors = {split: color_map[split.lower()] for split in self.data}

        # Line styles and markers for better distinction
        self.line_styles = {
            'Geom-GCN': '-',
            'Public': '--',
            'Random': ':'
        }

        self.markers = {
            'Geom-GCN': 'o',
            'Public': 's',
            'Random': '^'
        }

    def _create_dataframe(self):
        """Convert the nested data structure to a pandas DataFrame

        Returns:
            A DataFrame with one row per record in ``self.data`` plus a
            'split' column holding the key the record was stored under.
        """
        # Build fresh dicts: the records held in self.data must not gain a
        # 'split' key as a side effect of flattening them.
        rows = [
            {**record, 'split': split}
            for split, records in self.data.items()
            for record in records
        ]
        return pd.DataFrame(rows)

    def plot_performance_vs_steps(self, metric='test_f1', splits=None, figsize=(8, 6),
                                  show_confidence=False, add_annotations=False):
        """
        Plot performance metric vs k-steps with academic formatting

        Args:
            metric: Metric to plot ('test_f1', 'test_accuracy', 'val_f1', 'val_loss', 'duration')
            splits: List of splits to include (default: all)
            figsize: Figure size tuple (width, height) in inches
            show_confidence: Whether to show confidence intervals (placeholder for future;
                currently ignored)
            add_annotations: Whether to add value annotations on points

        Returns:
            The ``(fig, ax)`` pair of the created figure.
        """
        if splits is None:
            splits = list(self.data.keys())

        fig, ax = plt.subplots(figsize=figsize, facecolor='white')
        ax.set_facecolor('white')

        # Plot each split
        for split in splits:
            split_data = self.df[self.df['split'] == split]

            ax.plot(split_data['steps'], split_data[metric],
                    color=self.colors[split],
                    linestyle=self.line_styles[split],
                    marker=self.markers[split],
                    linewidth=2,
                    markersize=6,
                    markerfacecolor=self.colors[split],
                    markeredgecolor='white',
                    markeredgewidth=1,
                    label=split,
                    alpha=0.8)

            # Add annotations if requested
            if add_annotations:
                for _, row in split_data.iterrows():
                    ax.annotate(f'{row[metric]:.1f}',
                                (row['steps'], row[metric]),
                                xytext=(0, 10), textcoords='offset points',
                                ha='center', va='bottom',
                                fontsize=8, alpha=0.7)

        # Formatting
        ax.set_xlabel('Number of Steps (k)', fontweight='bold')

        # Set y-label based on metric with proper formatting; unknown metrics
        # fall back to a title-cased version of the column name.
        ylabel_map = {
            'test_f1': 'Test F1-Score (%)',
            'test_accuracy': 'Test Accuracy (%)',
            'val_f1': 'Validation F1-Score (%)',
            'val_loss': 'Validation Loss',
            'duration': 'Training Duration (seconds)'
        }
        ax.set_ylabel(ylabel_map.get(metric, metric.replace('_', ' ').title()), fontweight='bold')

        # Title
        title_map = {
            'test_f1': 'Effect of K-Steps on Test F1-Score Performance',
            'test_accuracy': 'Effect of K-Steps on Test Accuracy',
            'val_f1': 'Effect of K-Steps on Validation F1-Score',
            'val_loss': 'Effect of K-Steps on Validation Loss',
            'duration': 'Effect of K-Steps on Training Duration'
        }
        ax.set_title(title_map.get(metric, f'K-Steps vs {metric.replace("_", " ").title()}'),
                     fontweight='bold', pad=15)

        # Legend with better positioning
        legend = ax.legend(loc='best', frameon=True, fancybox=True, shadow=True,
                           framealpha=0.9, edgecolor='gray')
        legend.get_frame().set_facecolor('white')

        # Grid styling
        ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
        ax.set_axisbelow(True)

        # Spines styling
        for spine in ax.spines.values():
            spine.set_color('black')
            spine.set_linewidth(0.8)

        # Tick formatting
        ax.tick_params(colors='black', direction='in', length=4)

        # Set x-axis to show all step values
        ax.set_xticks(sorted(self.df['steps'].unique()))

        plt.tight_layout()
        return fig, ax

    def plot_efficiency_analysis(self, figsize=(8, 6)):
        """
        Create an efficiency plot showing performance vs computational cost

        Each split is drawn as a scatter of test F1 against training duration,
        with every point labelled by its k value.

        Args:
            figsize: Figure size tuple (width, height) in inches

        Returns:
            The ``(fig, ax)`` pair of the created figure.
        """
        fig, ax = plt.subplots(figsize=figsize, facecolor='white')
        ax.set_facecolor('white')

        # Plot each split
        for split in self.data:
            split_data = self.df[self.df['split'] == split]

            # Create scatter plot
            ax.scatter(split_data['duration'], split_data['test_f1'],
                       c=self.colors[split],
                       marker=self.markers[split],
                       s=80, alpha=0.7,
                       edgecolors='white',
                       linewidth=1,
                       label=split)

            # Add step annotations
            for _, row in split_data.iterrows():
                ax.annotate(f'k={int(row["steps"])}',
                            (row['duration'], row['test_f1']),
                            xytext=(5, 5), textcoords='offset points',
                            fontsize=8, alpha=0.8,
                            bbox=dict(boxstyle='round,pad=0.2',
                                      facecolor=self.colors[split], alpha=0.3))

        # Formatting
        ax.set_xlabel('Training Duration (seconds)', fontweight='bold')
        ax.set_ylabel('Test F1-Score (%)', fontweight='bold')
        ax.set_title('Performance-Efficiency Trade-off Analysis', fontweight='bold', pad=15)

        # Legend
        legend = ax.legend(loc='best', frameon=True, fancybox=True, shadow=True,
                           framealpha=0.9, edgecolor='gray')
        legend.get_frame().set_facecolor('white')

        # Grid and styling
        ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
        ax.set_axisbelow(True)

        for spine in ax.spines.values():
            spine.set_color('black')
            spine.set_linewidth(0.8)

        ax.tick_params(colors='black', direction='in', length=4)

        plt.tight_layout()
        return fig, ax

    def plot_comprehensive_analysis(self, figsize=(12, 10)):
        """
        Create a comprehensive multi-panel visualization for academic papers

        Draws a 2x2 grid (test F1, test accuracy, validation loss, duration),
        each panel plotting the metric against k for every split.

        Args:
            figsize: Figure size tuple (width, height) in inches

        Returns:
            The created figure (note: unlike the other plot methods, no axes
            object is returned).
        """
        fig, axes = plt.subplots(2, 2, figsize=figsize, facecolor='white')
        fig.suptitle('K-Steps Ablation Study: Comprehensive Performance Analysis on Cora Dataset',
                     fontsize=14, fontweight='bold', y=0.99)

        # Define metrics and their properties
        metrics_config = [
            ('test_f1', 'Test F1-Score (%)', axes[0, 0]),
            ('test_accuracy', 'Test Accuracy (%)', axes[0, 1]),
            ('val_loss', 'Validation Loss', axes[1, 0]),
            ('duration', 'Training Duration (s)', axes[1, 1])
        ]
        legend_ax = axes[0, 0]

        for metric, ylabel, ax in metrics_config:
            ax.set_facecolor('white')

            # Plot each split
            for split in self.data:
                split_data = self.df[self.df['split'] == split]

                ax.plot(split_data['steps'], split_data[metric],
                        color=self.colors[split],
                        linestyle=self.line_styles[split],
                        marker=self.markers[split],
                        linewidth=1.5,
                        markersize=4,
                        markerfacecolor=self.colors[split],
                        markeredgecolor='white',
                        markeredgewidth=0.5,
                        label=split,
                        alpha=0.8)

            # Formatting
            ax.set_xlabel('Number of Steps (k)', fontsize=10, fontweight='bold')
            ax.set_ylabel(ylabel, fontsize=10, fontweight='bold')
            ax.set_title(f'{metric.replace("_", " ").title()} Performance',
                         fontsize=11, fontweight='bold')

            # Only show legend on first subplot (identity check: we want that
            # exact axes object, not an equality comparison).
            if ax is legend_ax:
                legend = ax.legend(loc='best', fontsize=9, frameon=True,
                                   fancybox=True, shadow=True, framealpha=0.9)
                legend.get_frame().set_facecolor('white')

            # Grid and styling
            ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
            ax.set_axisbelow(True)

            for spine in ax.spines.values():
                spine.set_color('black')
                spine.set_linewidth(0.8)

            ax.tick_params(colors='black', direction='in', length=3, labelsize=9)
            ax.set_xticks(sorted(self.df['steps'].unique()))

        plt.tight_layout()
        return fig

    def plot_performance_heatmap(self, metric='test_f1', figsize=(10, 4)):
        """
        Create an academic-style heatmap

        Rows are splits, columns are k values, and each cell holds *metric*.

        Args:
            metric: Column of ``self.df`` to display
            figsize: Figure size tuple (width, height) in inches

        Returns:
            The ``(fig, ax)`` pair of the created figure.
        """
        # Pivot the data for heatmap
        pivot_data = self.df.pivot(index='split', columns='steps', values=metric)

        fig, ax = plt.subplots(figsize=figsize, facecolor='white')

        # Create heatmap with academic styling
        heatmap = sns.heatmap(pivot_data,
                              annot=True,
                              fmt='.2f',
                              cmap='RdYlBu_r',
                              center=None,
                              square=False,
                              linewidths=0.5,
                              cbar_kws={'label': f'{metric.replace("_", " ").title()}'},
                              annot_kws={'fontsize': 9},
                              ax=ax)

        # Formatting
        ax.set_title(f'{metric.replace("_", " ").title()} Performance Across Different K-Steps',
                     fontweight='bold', pad=15)
        ax.set_xlabel('Number of Steps (k)', fontweight='bold')
        ax.set_ylabel('Data Split', fontweight='bold')

        # Style the colorbar
        cbar = heatmap.collections[0].colorbar
        cbar.ax.tick_params(labelsize=9)

        plt.tight_layout()
        return fig, ax

    def generate_academic_report(self):
        """
        Generate a comprehensive academic-style report

        Prints to stdout: a header, per-split statistics (best value, the k
        that achieved it, mean and standard deviation for each metric), a
        ranking of splits by best test F1, and the k with the highest
        test-F1 / duration ratio for each split. Returns nothing.
        """
        print("=" * 90)
        print("K-STEPS ABLATION STUDY - ACADEMIC ANALYSIS REPORT")
        print("=" * 90)
        print("Dataset: Cora Citation Network")
        print(f"Node splits analyzed: {', '.join(self.data.keys())}")
        print(f"K-steps evaluated: {', '.join(map(str, self.steps))}")
        print(f"Total experiments: {len(self.df)} configurations")
        print("\n" + "=" * 90)

        # Statistical summary
        print("\nSTATISTICAL SUMMARY:")
        print("-" * 50)

        metrics = ['test_f1', 'test_accuracy', 'val_f1', 'val_loss', 'duration']

        for split in self.data:
            split_data = self.df[self.df['split'] == split]

            print(f"\n{split} Split:")
            print("  Metric                 Best Value    Optimal k    Mean ± Std")
            print("  " + "-" * 60)

            for metric in metrics:
                values = split_data[metric].values

                # Loss and training time are best when smallest; reporting
                # their maximum as "best" would highlight the worst run.
                if metric in self._LOWER_IS_BETTER:
                    best_idx = np.argmin(values)
                else:
                    best_idx = np.argmax(values)
                best_value = values[best_idx]

                # best_idx is positional within this split, hence iloc.
                best_k = split_data.iloc[best_idx]['steps']
                mean_val = np.mean(values)
                std_val = np.std(values)

                unit = "%" if metric in ['test_f1', 'test_accuracy', 'val_f1'] else ""
                if metric == 'duration':
                    unit = "s"

                print(f"  {metric.replace('_', ' ').title():<20} "
                      f"{best_value:>8.2f}{unit:<2} "
                      f"{int(best_k):>8}      "
                      f"{mean_val:>6.2f} ± {std_val:.2f}{unit}")

        # Performance ranking
        print("\n\nPERFORMANCE RANKING (by Test F1-Score):")
        print("-" * 50)

        performance_summary = []
        for split in self.data:
            split_data = self.df[self.df['split'] == split]
            best_f1 = split_data['test_f1'].max()
            best_k = split_data.loc[split_data['test_f1'].idxmax(), 'steps']
            performance_summary.append((split, best_f1, best_k))

        # Sort by performance
        performance_summary.sort(key=lambda x: x[1], reverse=True)

        for rank, (split, f1, k) in enumerate(performance_summary, 1):
            print(f"{rank}. {split:<12}: {f1:.2f}% F1-Score (k={int(k)})")

        # Efficiency analysis
        print("\n\nEFFICIENCY ANALYSIS:")
        print("-" * 50)
        print("Best performance-to-time ratio for each split:")

        for split in self.data:
            # Work on a copy so the derived column never leaks into self.df.
            split_data_copy = self.df[self.df['split'] == split].copy()
            # Calculate efficiency as F1/duration ratio
            split_data_copy['efficiency'] = split_data_copy['test_f1'] / split_data_copy['duration']
            best_eff_idx = split_data_copy['efficiency'].idxmax()
            best_eff_row = split_data_copy.loc[best_eff_idx]

            print(f"  {split:<12}: k={int(best_eff_row['steps']):<2} "
                  f"({best_eff_row['test_f1']:.2f}% F1, {best_eff_row['duration']:.1f}s, "
                  f"ratio={best_eff_row['efficiency']:.3f})")

        print("\n" + "=" * 90)

    def _save_figure(self, fig, output_dir, filename, formats, dpi):
        """Write *fig* once per format, close it, and return the paths written."""
        saved = []
        try:
            for fmt in formats:
                filepath = f"{output_dir}/{filename}.{fmt}"
                fig.savefig(filepath, dpi=dpi, bbox_inches='tight',
                            facecolor='white', edgecolor='none')
                saved.append(filepath)
        finally:
            # Release the figure even if one of the writes fails.
            plt.close(fig)
        return saved

    def save_publication_figures(self, output_dir='./vis/ablation', formats=('pdf', 'png'), dpi=300):
        """
        Generate and save publication-ready figures

        Creates the four figures (performance, efficiency, comprehensive,
        heatmap), writes each one in every requested format, and prints a
        progress log.

        Args:
            output_dir: Directory to write into; created if it does not exist
            formats: File extensions to save each figure as. A single string
                such as 'pdf' is accepted and treated as one format.
            dpi: Resolution passed to ``savefig``

        Returns:
            List of the file paths written, in the order they were saved.
        """
        import os

        # A bare string would otherwise be iterated character by character.
        formats = [formats] if isinstance(formats, str) else list(formats)

        os.makedirs(output_dir, exist_ok=True)

        print("Generating publication-ready figures...")
        print(f"Output directory: {output_dir}")
        print(f"Formats: {', '.join(formats)}")
        print(f"Resolution: {dpi} DPI")

        figures_generated = []

        # 1. Main performance plot (Test F1)
        print("  • Generating main performance figure...")
        fig, _ = self.plot_performance_vs_steps(metric='test_f1', figsize=(8, 6))
        figures_generated += self._save_figure(
            fig, output_dir, 'k_steps_performance_analysis', formats, dpi)

        # 2. Efficiency analysis
        print("  • Generating efficiency analysis figure...")
        fig, _ = self.plot_efficiency_analysis(figsize=(8, 6))
        figures_generated += self._save_figure(
            fig, output_dir, 'k_steps_efficiency_analysis', formats, dpi)

        # 3. Comprehensive analysis
        print("  • Generating comprehensive analysis figure...")
        fig = self.plot_comprehensive_analysis(figsize=(12, 10))
        figures_generated += self._save_figure(
            fig, output_dir, 'k_steps_comprehensive_analysis', formats, dpi)

        # 4. Performance heatmap
        print("  • Generating performance heatmap...")
        fig, _ = self.plot_performance_heatmap(metric='test_f1', figsize=(10, 4))
        figures_generated += self._save_figure(
            fig, output_dir, 'k_steps_performance_heatmap', formats, dpi)

        print(f"\nSuccessfully generated {len(figures_generated)} figure files:")
        for filepath in figures_generated:
            print(f"  • {filepath}")

        return figures_generated


def main():
    """
    Run the full demo: report, four interactive plots, then saved figures.

    Builds an ``AcademicKStepsVisualizer`` with the default palette, prints
    its report, shows each plot with ``plt.show()``, and finally writes the
    publication figures under ``./publication_figures``.
    """
    rule = "=" * 60

    print("Academic K-Steps Ablation Study Visualization")
    print(rule)

    # Initialize visualizer (set use_colorblind_palette=True for colorblind-friendly colors)
    viz = AcademicKStepsVisualizer(use_colorblind_palette=False)

    # Generate comprehensive report
    viz.generate_academic_report()

    # Create and display individual plots
    print("\n" + rule)
    print("GENERATING VISUALIZATIONS")
    print(rule)

    # (heading, plotting method, keyword arguments) in display order
    plot_plan = (
        ("\n1. Test F1-Score Performance Analysis:",
         viz.plot_performance_vs_steps,
         {'metric': 'test_f1', 'add_annotations': False}),
        ("\n2. Performance-Efficiency Trade-off Analysis:",
         viz.plot_efficiency_analysis, {}),
        ("\n3. Comprehensive Multi-Panel Analysis:",
         viz.plot_comprehensive_analysis, {}),
        ("\n4. Performance Heatmap:",
         viz.plot_performance_heatmap, {}),
    )
    for heading, draw, kwargs in plot_plan:
        print(heading)
        draw(**kwargs)
        plt.show()

    # Save publication-ready figures
    print("\n" + rule)
    print("SAVING PUBLICATION FIGURES")
    print(rule)

    saved_paths = viz.save_publication_figures(output_dir='./publication_figures')

    print(f"\nAnalysis complete! {len(saved_paths)} publication-ready figures generated.")
    closing_notes = (
        "\nRecommended usage in academic papers:",
        "• Use k_steps_performance_analysis.pdf for main results",
        "• Use k_steps_comprehensive_analysis.pdf for detailed analysis",
        "• Use k_steps_efficiency_analysis.pdf for computational cost discussion",
        "• Use k_steps_performance_heatmap.pdf for tabular comparison",
    )
    for note in closing_notes:
        print(note)

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()