import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import warnings
import os

# NOTE(review): this suppresses every warning category for the whole process
# (including warnings raised by unrelated code that imports this module),
# not only plotting noise -- confirm that is intended.
warnings.filterwarnings('ignore')

# Start from Matplotlib's built-in default style and use seaborn's "husl"
# palette as the default colour cycle.
plt.style.use('default')
sns.set_palette("husl")

# Imported in addition to pyplot so that rcParams can be adjusted below.
import matplotlib

# Setting this rcParam to 0 is documented to disable Matplotlib's
# "too many open figures" warning (see the rcParams reference).
matplotlib.rcParams['figure.max_open_warning'] = 0


class SamplingMethodsAblationVisualizer:
    """Plot and summarise a hard-coded sampling-method ablation study.

    The results live in ``self.data`` (one list of per-method records for
    each data split) and are flattened into ``self.df`` for querying.  Every
    ``plot_*`` method returns the Matplotlib figure it created; closing the
    figure is the caller's responsibility (``save_all_plots`` closes the
    figures it creates).

    NOTE: ``suptitle_fontsize`` is stored on the instance but no method reads
    it; the two multi-panel "comprehensive" plots use fixed font sizes that
    are independent of the constructor configuration.
    """

    # Human-readable axis/title text for each metric column.
    _METRIC_LABELS = {
        'test_f1': 'Test F1 Score (%)',
        'test_accuracy': 'Test Accuracy (%)',
        'val_f1': 'Validation F1 Score (%)',
        'val_loss': 'Validation Loss',
        'duration': 'Duration (seconds)'
    }

    def __init__(self,
                 title_fontsize=26,
                 label_fontsize=24,
                 legend_fontsize=22,
                 tick_fontsize=20,
                 annotation_fontsize=20,
                 suptitle_fontsize=24):
        """
        Initialize the visualizer with the sampling methods ablation study data.

        Args:
            title_fontsize: Font size of single-plot titles.
            label_fontsize: Font size of axis labels on single plots.
            legend_fontsize: Font size of legends on single plots.
            tick_fontsize: Font size of tick labels on single plots.
            annotation_fontsize: Font size of bar/point annotations.
            suptitle_fontsize: Stored for configuration; not read by any
                method of this class.
        """
        # Font size configuration
        self.title_fontsize = title_fontsize
        self.label_fontsize = label_fontsize
        self.legend_fontsize = legend_fontsize
        self.tick_fontsize = tick_fontsize
        self.annotation_fontsize = annotation_fontsize
        self.suptitle_fontsize = suptitle_fontsize

        # Display order of methods. "spike_and_slab" has no records below, so
        # it is skipped by the plots and by the ranking.
        self.sampling_methods = [
            "bernoulli", "sigmoid", "gumbel_softmax", "straight_through",
            "gaussian", "relaxed_bernoulli", "spike_and_slab"
        ]

        # Data organized by split
        self.data = {
            'geom-gcn': [
                {'method': 'bernoulli', 'val_loss': 1.9448, 'val_f1': 4.26, 'test_accuracy': 11.60, 'test_f1': 3.27,
                 'duration': 71.866},
                {'method': 'sigmoid', 'val_loss': 1.8410, 'val_f1': 32.73, 'test_accuracy': 52.49, 'test_f1': 32.03,
                 'duration': 65.889},
                {'method': 'gumbel_softmax', 'val_loss': 1.2603, 'val_f1': 84.54, 'test_accuracy': 83.61,
                 'test_f1': 82.32, 'duration': 69.900},
                {'method': 'straight_through', 'val_loss': 1.2681, 'val_f1': 83.09, 'test_accuracy': 81.58,
                 'test_f1': 79.83, 'duration': 67.368},
                {'method': 'gaussian', 'val_loss': 1.3106, 'val_f1': 87.24, 'test_accuracy': 85.27, 'test_f1': 84.55,
                 'duration': 69.273},
                {'method': 'relaxed_bernoulli', 'val_loss': 1.2903, 'val_f1': 86.04, 'test_accuracy': 83.79,
                 'test_f1': 82.67, 'duration': 76.451},
            ],
            'public': [
                {'method': 'bernoulli', 'val_loss': 1.9436, 'val_f1': 3.82, 'test_accuracy': 12.80, 'test_f1': 3.60,
                 'duration': 66.772},
                {'method': 'sigmoid', 'val_loss': 1.8779, 'val_f1': 21.72, 'test_accuracy': 39.70, 'test_f1': 21.82,
                 'duration': 66.104},
                {'method': 'gumbel_softmax', 'val_loss': 1.3683, 'val_f1': 78.29, 'test_accuracy': 81.30,
                 'test_f1': 79.93, 'duration': 68.488},
                {'method': 'straight_through', 'val_loss': 1.4225, 'val_f1': 65.12, 'test_accuracy': 65.10,
                 'test_f1': 64.47, 'duration': 66.422},
                {'method': 'gaussian', 'val_loss': 1.4615, 'val_f1': 66.71, 'test_accuracy': 73.40, 'test_f1': 67.51,
                 'duration': 66.770},
                {'method': 'relaxed_bernoulli', 'val_loss': 1.9435, 'val_f1': 5.94, 'test_accuracy': 14.80,
                 'test_f1': 5.77, 'duration': 76.254},
            ],
            'random': [
                {'method': 'bernoulli', 'val_loss': 1.9436, 'val_f1': 3.82, 'test_accuracy': 12.80, 'test_f1': 3.60,
                 'duration': 65.860},
                {'method': 'sigmoid', 'val_loss': 1.9049, 'val_f1': 6.86, 'test_accuracy': 31.90, 'test_f1': 6.91,
                 'duration': 65.020},
                {'method': 'gumbel_softmax', 'val_loss': 1.2676, 'val_f1': 85.73, 'test_accuracy': 86.30,
                 'test_f1': 85.10, 'duration': 69.391},
                {'method': 'straight_through', 'val_loss': 1.3982, 'val_f1': 54.03, 'test_accuracy': 68.10,
                 'test_f1': 52.53, 'duration': 67.283},
                {'method': 'gaussian', 'val_loss': 1.4646, 'val_f1': 56.03, 'test_accuracy': 67.50, 'test_f1': 50.65,
                 'duration': 67.399},
                {'method': 'relaxed_bernoulli', 'val_loss': 1.8370, 'val_f1': 8.27, 'test_accuracy': 30.50,
                 'test_f1': 7.69, 'duration': 77.584},
            ]
        }

        # Convert to DataFrame for easier manipulation
        self.df = self._create_dataframe()

        # Color scheme for splits ('complete' has no records in self.data)
        self.colors = {
            'geom-gcn': '#8884d8',
            'public': '#82ca9d',
            'random': '#ffc658',
            'complete': '#ff7c7c'
        }

        # Method display names
        self.method_names = {
            'bernoulli': 'Bernoulli',
            'sigmoid': 'Sigmoid',
            'gumbel_softmax': 'Gumbel Softmax',
            'straight_through': 'Straight Through',
            'gaussian': 'Gaussian',
            'relaxed_bernoulli': 'Relaxed Bernoulli',
            'spike_and_slab': 'Spike and Slab'
        }

    def _create_dataframe(self):
        """Flatten ``self.data`` into one DataFrame with a ``split`` column.

        Each record is copied before the ``split`` key is added, so the
        dictionaries stored in ``self.data`` are left untouched.
        """
        rows = [
            {**record, 'split': split}
            for split, records in self.data.items()
            for record in records
        ]
        return pd.DataFrame(rows)

    def _collect_metric_values(self, metric, splits):
        """Gather ``metric`` per method and split for a grouped bar chart.

        Returns:
            ``(methods, values_by_split)`` where ``methods`` lists, in
            ``self.sampling_methods`` order, every method that has a record in
            at least one requested split, and ``values_by_split[split]`` holds
            one value per listed method.  A missing (split, method) record is
            represented by ``0``.
        """
        methods = []
        values_by_split = {split: [] for split in splits}

        for method in self.sampling_methods:
            values = []
            found = False
            for split in splits:
                match = self.df[(self.df['split'] == split) & (self.df['method'] == method)]
                if match.empty:
                    values.append(0)
                else:
                    values.append(match.iloc[0][metric])
                    found = True

            if not found:
                continue

            methods.append(method)
            for split, value in zip(splits, values):
                values_by_split[split].append(value)

        return methods, values_by_split

    def _draw_grouped_bars(self, ax, metric, splits, annotate=False):
        """Draw one bar group per method on ``ax`` and position the x ticks.

        When ``annotate`` is true, bars taller than 10% of their split's
        maximum are labelled with their value.  Returns the plotted methods.
        """
        methods, values_by_split = self._collect_metric_values(metric, splits)

        x = np.arange(len(methods))
        width = 0.8 / len(splits)

        for offset, split in enumerate(splits):
            values = values_by_split[split]
            bars = ax.bar(x + offset * width, values, width,
                          label=split.upper(), color=self.colors[split], alpha=0.8)

            if not annotate:
                continue

            # Computed once per split; ``default`` covers the no-bars case.
            threshold = max(values, default=0) * 0.1
            for bar in bars:
                height = bar.get_height()
                if height > threshold:  # Only label if > 10% of max
                    ax.annotate(f'{height:.1f}',
                                xy=(bar.get_x() + bar.get_width() / 2, height),
                                xytext=(0, 3),
                                textcoords="offset points",
                                ha='center', va='bottom', fontsize=self.annotation_fontsize)

        # Centre each tick under its group of bars.
        ax.set_xticks(x + width * (len(splits) - 1) / 2)
        return methods

    @staticmethod
    def _style_axes(ax, labelsize):
        """Colour all four spines and the ticks black; set tick label size."""
        for side in ('bottom', 'top', 'right', 'left'):
            ax.spines[side].set_color('black')
        ax.tick_params(colors='black', labelsize=labelsize)

    def plot_method_comparison(self, metric='test_f1', splits=None, figsize=(15, 8)):
        """Plot performance comparison across sampling methods using bar chart.

        Args:
            metric: Column of ``self.df`` to plot.
            splits: Split names to include; defaults to every split in
                ``self.data``.
            figsize: Figure size passed to Matplotlib.

        Returns:
            The created figure.

        Raises:
            ValueError: If ``splits`` is empty (bar width would be undefined).
        """
        splits = list(self.data) if splits is None else list(splits)
        if not splits:
            raise ValueError("splits must contain at least one split name")

        metric_label = self._METRIC_LABELS.get(metric, metric)

        # Create the plot
        fig, ax = plt.subplots(figsize=figsize, facecolor='white')
        ax.set_facecolor('white')

        methods = self._draw_grouped_bars(ax, metric, splits, annotate=True)

        # Formatting
        ax.set_xlabel('Sampling Methods', fontsize=self.label_fontsize, color='black')
        ax.set_ylabel(metric_label, fontsize=self.label_fontsize, color='black')
        ax.set_title(f'Sampling Methods Comparison: {metric_label}\nCORA Dataset',
                     fontsize=self.title_fontsize, fontweight='bold', color='black', pad=20)
        ax.set_xticklabels([self.method_names.get(m, m).replace('_', ' ') for m in methods],
                           rotation=45, ha='right', fontsize=self.tick_fontsize)

        ax.legend(fontsize=self.legend_fontsize, framealpha=0.9)
        ax.grid(True, alpha=0.3, color='gray', axis='y')

        self._style_axes(ax, self.tick_fontsize)

        fig.tight_layout()
        return fig

    def plot_performance_vs_duration(self, metric='test_f1', splits=None, figsize=(12, 8)):
        """Plot performance metric vs duration (efficiency analysis).

        One scatter series is drawn per split, and each point is annotated
        with the display name of its sampling method.

        Args:
            metric: Column of ``self.df`` plotted on the y axis.
            splits: Split names to include; defaults to every split in
                ``self.data``.
            figsize: Figure size passed to Matplotlib.

        Returns:
            The created figure.
        """
        splits = list(self.data) if splits is None else list(splits)

        fig, ax = plt.subplots(figsize=figsize, facecolor='white')
        ax.set_facecolor('white')

        # Plot each split
        for split in splits:
            rows = self.df[self.df['split'] == split]
            ax.scatter(rows['duration'], rows[metric],
                       s=120, alpha=0.7, color=self.colors[split],
                       label=split.upper(), edgecolors='white', linewidth=2)

            # Add method annotations
            for _, row in rows.iterrows():
                ax.annotate(self.method_names.get(row['method'], row['method']),
                            (row['duration'], row[metric]),
                            xytext=(5, 5), textcoords='offset points',
                            fontsize=self.annotation_fontsize, color='black', fontweight='bold',
                            bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.7))

        ax.set_xlabel('Duration (seconds)', fontsize=self.label_fontsize, color='black')
        ax.set_ylabel(self._METRIC_LABELS.get(metric, metric),
                      fontsize=self.label_fontsize, color='black')
        ax.set_title('Sampling Methods: Performance vs Duration\nCORA Dataset',
                     fontsize=self.title_fontsize, fontweight='bold', color='black', pad=20)
        ax.legend(fontsize=self.legend_fontsize, framealpha=0.9)
        ax.grid(True, alpha=0.3, color='gray')

        self._style_axes(ax, self.tick_fontsize)

        fig.tight_layout()
        return fig

    def _plot_metric_panels(self, metrics, figsize, suptitle, *, xlabel_fontsize,
                            ylabel_fontsize, title_fontsize, legend_fontsize,
                            tick_labelsize):
        """Draw one grouped-bar panel per ``(metric, title)`` pair in a single row.

        All splits in ``self.data`` are plotted in every panel; only the first
        panel carries a legend.  Returns the created figure.
        """
        splits = list(self.data)
        # squeeze=False keeps ``axes`` two-dimensional even for a single panel.
        fig, axes = plt.subplots(1, len(metrics), figsize=figsize,
                                 facecolor='white', squeeze=False)

        for index, (ax, (metric, title)) in enumerate(zip(axes[0], metrics)):
            ax.set_facecolor('white')

            methods = self._draw_grouped_bars(ax, metric, splits)

            ax.set_xlabel('Sampling Methods', fontsize=xlabel_fontsize, color='black')
            ax.set_ylabel(title, fontsize=ylabel_fontsize, color='black')
            ax.set_title(title, fontsize=title_fontsize, fontweight='bold', color='black')
            ax.set_xticklabels([self.method_names.get(m, m) for m in methods],
                               rotation=45, ha='right', fontsize=14)
            if index == 0:  # Only show legend on first subplot
                ax.legend(fontsize=legend_fontsize)
            ax.grid(True, alpha=0.3, color='gray', axis='y')

            self._style_axes(ax, tick_labelsize)

        fig.suptitle(suptitle, fontsize=16, fontweight='bold', color='black', y=1.02)
        fig.tight_layout()
        return fig

    def plot_comprehensive_analysis_row(self, figsize=(24, 6)):
        """Create a comprehensive single-row visualization for sampling methods.

        Four panels: test F1, test accuracy, validation F1 and duration.
        Font sizes are fixed and do not follow the constructor configuration.
        Returns the created figure.
        """
        metrics = [
            ('test_f1', 'Test F1 Score (%)'),
            ('test_accuracy', 'Test Accuracy (%)'),
            ('val_f1', 'Validation F1 Score (%)'),
            ('duration', 'Duration (seconds)')
        ]
        return self._plot_metric_panels(
            metrics, figsize,
            'Sampling Methods Ablation Study - Comprehensive Analysis\nCORA Dataset',
            xlabel_fontsize=12, ylabel_fontsize=10, title_fontsize=12,
            legend_fontsize=10, tick_labelsize=9)

    def plot_comprehensive_analysis_no_duration(self, figsize=(18, 6)):
        """Create a comprehensive single-row visualization without duration metric.

        Three panels: test F1, test accuracy and validation F1.  Font sizes
        are fixed and do not follow the constructor configuration.
        Returns the created figure.
        """
        metrics = [
            ('test_f1', 'Test F1 Score (%)'),
            ('test_accuracy', 'Test Accuracy (%)'),
            ('val_f1', 'Validation F1 Score (%)')
        ]
        return self._plot_metric_panels(
            metrics, figsize,
            'Sampling Methods Ablation Study - Performance Analysis\nCORA Dataset',
            xlabel_fontsize=14, ylabel_fontsize=12, title_fontsize=14,
            legend_fontsize=12, tick_labelsize=16)

    def plot_heatmap(self, metric='test_f1', figsize=(12, 6)):
        """Create a heatmap showing performance across methods and splits.

        Rows are splits, columns are the methods present in the data (in
        ``self.sampling_methods`` order), cells are ``metric`` values.
        Returns the created figure.
        """
        # Pivot the data for heatmap
        pivot_data = self.df.pivot(index='split', columns='method', values=metric)

        # Reorder columns to match our method order
        available_methods = [m for m in self.sampling_methods if m in pivot_data.columns]
        pivot_data = pivot_data[available_methods]

        fig, ax = plt.subplots(figsize=figsize, facecolor='white')
        metric_title = metric.replace("_", " ").title()

        # Draw on the explicit axes rather than relying on pyplot's current axes.
        sns.heatmap(pivot_data, annot=True, fmt='.1f', cmap='RdYlBu_r',
                    cbar_kws={'label': metric_title}, ax=ax)

        ax.set_title(f'Performance Heatmap: {metric_title}\nCORA Dataset',
                     fontsize=self.title_fontsize, fontweight='bold', color='black', pad=20)
        ax.set_xlabel('Sampling Methods', fontsize=self.label_fontsize, color='black')
        ax.set_ylabel('Data Split', fontsize=self.label_fontsize, color='black')

        # Update x-axis labels
        method_labels = [self.method_names.get(m, m) for m in available_methods]
        ax.set_xticklabels(method_labels, rotation=45, ha='right', fontsize=self.tick_fontsize)

        # Style
        ax.tick_params(colors='black', labelsize=self.tick_fontsize)
        fig.tight_layout()
        return fig

    def generate_statistics_report(self):
        """Print a statistics report to stdout.

        Sections: a header, the top five methods by mean test F1 across
        splits, best/mean/std of every metric per split, and per-method
        summaries.

        NOTE: the per-split standard deviation uses ``np.std`` (population),
        while the per-method one uses ``Series.std`` (sample).
        """
        rule = "=" * 80
        print(rule)
        print("SAMPLING METHODS ABLATION STUDY - COMPREHENSIVE STATISTICS REPORT")
        print(rule)
        print("Dataset: CORA")
        print(f"Splits analyzed: {', '.join(self.data.keys()).upper()}")
        print(f"Sampling methods: {len(self.sampling_methods)}")
        print("\n" + rule)

        metrics = ['test_f1', 'test_accuracy', 'val_f1', 'val_loss', 'duration']
        percent_metrics = ('test_f1', 'test_accuracy', 'val_f1')

        # Overall method ranking
        print("\nOVERALL METHOD RANKING (by Test F1 Score):")
        print("-" * 50)

        method_avg_scores = {}
        for method in self.sampling_methods:
            scores = []
            for split in self.data:
                match = self.df[(self.df['split'] == split) & (self.df['method'] == method)]
                if not match.empty:
                    scores.append(match.iloc[0]['test_f1'])
            if scores:  # methods without any record are left out of the ranking
                method_avg_scores[method] = np.mean(scores)

        sorted_methods = sorted(method_avg_scores.items(), key=lambda item: item[1], reverse=True)
        for rank, (method, avg_score) in enumerate(sorted_methods[:5], start=1):
            print(f"{rank:2d}. {self.method_names.get(method, method):<20}: {avg_score:.2f}% average F1")

        # Split-wise analysis
        for split in self.data:
            print(f"\n{split.upper()} SPLIT ANALYSIS:")
            print("-" * 40)

            split_rows = self.df[self.df['split'] == split]

            for metric in metrics:
                values = split_rows[metric].values
                if len(values) == 0:
                    continue

                # Lower is better only for the loss.
                best_idx = np.argmin(values) if metric == 'val_loss' else np.argmax(values)
                best_value = values[best_idx]
                best_method = split_rows.iloc[best_idx]['method']

                mean_val = np.mean(values)
                std_val = np.std(values)

                if metric in percent_metrics:
                    unit = "%"
                elif metric == 'duration':
                    unit = "s"
                else:
                    unit = ""

                print(f"{metric.replace('_', ' ').title():<20}: "
                      f"Best = {best_value:.2f}{unit} ({self.method_names.get(best_method, best_method)}), "
                      f"Mean = {mean_val:.2f}{unit}, "
                      f"Std = {std_val:.2f}{unit}")

        # Method analysis
        print("\n" + rule)
        print("METHOD-WISE ANALYSIS:")
        print("-" * 40)

        # Every ranked method has at least one row, so no emptiness check is needed.
        for method_name, _ in sorted_methods:
            print(f"\n{self.method_names.get(method_name, method_name).upper()}:")
            method_rows = self.df[self.df['method'] == method_name]

            avg_f1 = method_rows['test_f1'].mean()
            std_f1 = method_rows['test_f1'].std()
            avg_duration = method_rows['duration'].mean()

            print(f"  • Test F1: {avg_f1:.2f}% ± {std_f1:.2f}%")
            print(f"  • Duration: {avg_duration:.1f}s average")
            print(f"  • Best split: {method_rows.loc[method_rows['test_f1'].idxmax(), 'split'].upper()}")

        print("\n" + rule)

    @staticmethod
    def _save_figure(fig, output_dir, stem, formats):
        """Save ``fig`` as ``<output_dir>/<stem>.<fmt>`` per format, then close it.

        The figure is closed even if one of the saves raises.
        """
        try:
            for fmt in formats:
                fig.savefig(os.path.join(output_dir, f"{stem}.{fmt}"),
                            dpi=300, bbox_inches='tight', facecolor='white')
        finally:
            plt.close(fig)

    def save_all_plots(self, output_dir='./sampling_ablation_plots', formats=('png', 'pdf')):
        """Generate and save all visualization plots.

        Args:
            output_dir: Directory to write into; created if missing.
            formats: File extensions to save each figure as.  A single string
                such as ``'png'`` is treated as one format.
        """
        if isinstance(formats, str):
            # Avoid iterating a bare string character by character.
            formats = (formats,)

        os.makedirs(output_dir, exist_ok=True)

        print(f"Generating and saving plots to {output_dir}/")

        # 1. Method comparison plots
        for metric in ['test_f1', 'test_accuracy', 'val_f1', 'val_loss', 'duration']:
            print(f"  • Generating {metric} method comparison plot...")
            fig = self.plot_method_comparison(metric=metric)
            self._save_figure(fig, output_dir, f"{metric}_method_comparison", formats)

        # 2. Performance vs Duration plots
        for metric in ['test_f1', 'test_accuracy', 'val_f1']:
            print(f"  • Generating {metric} vs duration plot...")
            fig = self.plot_performance_vs_duration(metric=metric)
            self._save_figure(fig, output_dir, f"{metric}_vs_duration", formats)

        # 3. Comprehensive analysis
        print("  • Generating comprehensive analysis...")
        fig = self.plot_comprehensive_analysis_row()
        self._save_figure(fig, output_dir, "comprehensive_analysis_row", formats)

        # 4. Comprehensive analysis without duration
        print("  • Generating comprehensive analysis without duration...")
        fig = self.plot_comprehensive_analysis_no_duration()
        self._save_figure(fig, output_dir, "comprehensive_analysis_no_duration", formats)

        # 5. Heatmaps
        for metric in ['test_f1', 'test_accuracy', 'val_f1']:
            print(f"  • Generating {metric} heatmap...")
            fig = self.plot_heatmap(metric=metric)
            self._save_figure(fig, output_dir, f"{metric}_heatmap", formats)

        print("All plots saved successfully!")


def run_analysis():
    """Print the statistics report, show each plot once, then save all plots.

    The five preview figures are shown one after another via ``plt.show()``;
    ``save_all_plots`` afterwards regenerates every figure and writes it to
    the default output directory.
    """
    print("Sampling Methods Ablation Study Visualization")
    print("=" * 60)

    # Initialize visualizer
    visualizer = SamplingMethodsAblationVisualizer()

    # Generate statistics report
    visualizer.generate_statistics_report()

    print("\nGenerating visualizations...")

    # (announcement, plotting method, keyword arguments) for each preview
    previews = (
        ("\n1. Method comparison (Test F1 Score)...",
         visualizer.plot_method_comparison, {'metric': 'test_f1'}),
        ("\n2. Performance vs Duration...",
         visualizer.plot_performance_vs_duration, {'metric': 'test_f1'}),
        ("\n3. Comprehensive analysis with all metrics...",
         visualizer.plot_comprehensive_analysis_row, {}),
        ("\n4. Comprehensive analysis without duration...",
         visualizer.plot_comprehensive_analysis_no_duration, {}),
        ("\n5. Performance heatmap...",
         visualizer.plot_heatmap, {'metric': 'test_f1'}),
    )

    # Create and display individual plots
    for announcement, make_plot, kwargs in previews:
        print(announcement)
        make_plot(**kwargs)
        plt.show()

    # Save all plots
    print("\nSaving all plots...")
    visualizer.save_all_plots()

    print("\nAnalysis complete!")


# Run the full report-and-plot pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    run_analysis()