import os
import numpy as np
import pandas as pd
from collections import defaultdict

from sklearn.metrics import confusion_matrix, classification_report
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def tabulate_events(dpath):
    """Collect every scalar series from the single event file found in ``dpath``.

    Args:
        dpath: Directory expected to hold exactly one entry whose name starts
            with ``events``.

    Returns:
        A ``(out, steps)`` tuple. ``out`` maps each scalar tag to a list with
        one item per logged event, each item being a list of the values read
        from every accumulator (one value, since exactly one accumulator is
        allowed). ``steps`` holds the step numbers of whichever tag was
        visited last, or an empty list when there are no scalar tags.

    Raises:
        AssertionError: If ``dpath`` does not contain exactly one event file,
            or if aligned events disagree on their step number.
    """
    accumulators = []
    for entry in os.listdir(dpath):
        if entry.startswith('events'):
            accumulators.append(EventAccumulator(os.path.join(dpath, entry)).Reload())
    assert len(accumulators) == 1

    # With exactly one accumulator the star-unpacking hands set() one tag list.
    scalar_tags = set(*(acc.Tags()['scalars'] for acc in accumulators))

    out = defaultdict(list)
    steps = []

    for tag in scalar_tags:
        # Overwritten on every iteration: the caller gets the last tag's steps.
        steps = [event.step for event in accumulators[0].Scalars(tag)]
        series_per_file = [acc.Scalars(tag) for acc in accumulators]
        for aligned in zip(*series_per_file):
            assert len({event.step for event in aligned}) == 1
            out[tag].append([event.value for event in aligned])
    return out, steps

def to_csv(dpath):
    """Export the scalars of the event file in ``dpath`` to ``<dpath>/logger.csv``.

    One column is written per scalar tag, and the rows are indexed by the step
    numbers returned by ``tabulate_events``.

    Args:
        dpath: Directory containing the event file; the CSV is written into
            the same directory.

    Raises:
        ValueError: If the event file holds no scalar data, or if the tags do
            not all share the same number of logged steps.
    """
    d, steps = tabulate_events(dpath)
    if not d:
        raise ValueError(f"No scalar data found in event file under {dpath!r}")

    # tabulate_events stores one single-element list per event; unwrap it
    # directly instead of going through an intermediate 3-D NumPy array.
    columns = {f"{tag}": [row[0] for row in rows] for tag, rows in d.items()}
    df = pd.DataFrame(columns, index=steps, columns=list(columns))
    df.to_csv(os.path.join(dpath, "logger.csv"))

def read_event(path):
    """Regenerate ``logger.csv`` inside ``path`` and load it as a DataFrame.

    Args:
        path: Directory handed to ``to_csv``; the CSV is read back from it.

    Returns:
        The DataFrame read from ``<path>/logger.csv`` with the first CSV
        column used as the index.
    """
    to_csv(path)
    csv_file = os.path.join(path, "logger.csv")
    return pd.read_csv(csv_file, index_col=0)

def empty_dir(folder):
    """Delete everything inside ``folder`` while keeping the folder itself.

    Files and symbolic links are unlinked; real sub-directories are removed
    recursively. Deletion is best-effort: a failure on one entry is reported
    on stdout and the remaining entries are still processed. Nothing happens
    when ``folder`` does not exist.

    Args:
        folder: Path of the directory to empty.
    """
    # Imported here because shutil is not among the module-level imports
    # visible in this file; without it rmtree fails with NameError.
    import shutil

    if not os.path.exists(folder):
        return

    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            # Links are tested before directories so that a symlink pointing
            # at a directory is unlinked instead of being followed.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            # Deliberately broad: keep going and report the entry that failed.
            print('Failed to delete %s. Reason: %s' % (file_path, e))



def log_run_results_to_file(
    output_path,
    dataset_name,
    split_type,
    val_losses,
    accs,
    durations,
    str_optimizer,
    str_preconditioner,
    lr,
    weight_decay,
    epochs,
    runs,
    early_stopping,
    momentum=None,
    eps=None,
    update_freq=None,
    gamma=None,
    alpha=None,
    hyperparam=None
):
    """Append one row of aggregated run results to a Markdown table file.

    The row holds the mean accuracy (as a percentage), its sample standard
    deviation, the mean validation loss and the mean duration, each wrapped in
    an inline-styled HTML ``<span>``, followed by the run configuration. The
    table header is written first when the file does not exist yet or is
    empty. The file is always written as UTF-8 because the header contains
    the non-ASCII character "±".

    Args:
        output_path: Path of the Markdown file to create or append to.
        dataset_name: Value for the "Dataset" column.
        split_type: Value for the "Split" column.
        val_losses: Non-empty sequence of validation losses, one per run.
        accs: Non-empty sequence of accuracies; multiplied by 100 for display.
        durations: Non-empty sequence of run durations.
        str_optimizer: Value for the "Optimizer" column.
        str_preconditioner: Value for the "Preconditioner" column.
        lr: Learning rate; floats are rendered with four decimals.
        weight_decay: Weight decay; floats are rendered with four decimals.
        epochs: Value for the "Epochs" column.
        runs: Value for the "Runs" column.
        early_stopping: Value for the "Early Stop" column.
        momentum: Optional value; floats are rendered with four decimals.
        eps: Optional value; floats are rendered with four decimals.
        update_freq: Optional value; floats are rendered with four decimals.
        gamma: Optional value; floats are rendered with four decimals.
        alpha: Optional value; floats are rendered with four decimals.
        hyperparam: Optional value, rendered with ``str``.

    Raises:
        statistics.StatisticsError: If ``val_losses``, ``accs`` or
            ``durations`` is empty.
    """
    from statistics import mean, stdev

    def format_value(val):
        # Only real floats get fixed precision; None, ints, etc. use str().
        return f"{val:.4f}" if isinstance(val, float) else str(val)

    def as_row(cells):
        return "| " + " | ".join(map(str, cells)) + " |\n"

    # Aggregate stats
    mean_val_loss = mean(val_losses)
    mean_acc = 100 * mean(accs)
    # stdev needs at least two samples; report 0.0 for a single run.
    std_acc = 100 * stdev(accs) if len(accs) > 1 else 0.0
    mean_time = mean(durations)

    # Apply colors (HTML inline styles)
    colored_acc = f"<span style='color:green;font-weight:bold'>{mean_acc:.2f}</span>"
    colored_std = f"<span style='color:gray'>{std_acc:.2f}</span>"
    colored_val_loss = f"<span style='color:red'>{mean_val_loss:.4f}</span>"
    colored_time = f"<span style='color:blue'>{mean_time:.3f}</span>"

    headers = [
        "Test Acc (%)", "± Std", "Val Loss", "Time (s)",
        "Dataset", "Split", "Runs", "Epochs", "Optimizer", "Preconditioner",
        "LR", "Weight Decay", "Early Stop", "Momentum", "Eps", "Update Freq",
        "Gamma", "Alpha", "Hyperparam"
    ]

    values = [
        colored_acc, colored_std, colored_val_loss, colored_time,
        dataset_name, split_type, runs, epochs, str_optimizer, str_preconditioner,
        format_value(lr), format_value(weight_decay), early_stopping, format_value(momentum),
        format_value(eps), format_value(update_freq), format_value(gamma),
        format_value(alpha), str(hyperparam)
    ]

    # A missing file and an existing-but-empty file both need the header.
    needs_header = (
        not os.path.exists(output_path) or os.path.getsize(output_path) == 0
    )

    chunks = []
    if needs_header:
        chunks.append(as_row(headers))
        chunks.append(as_row(["---"] * len(headers)))
    chunks.append(as_row(values))

    # Append mode also creates the file when it does not exist yet.
    with open(output_path, "a", encoding="utf-8") as f:
        f.write("".join(chunks))

    print(f"✅ Results written to: {output_path}")



import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import torch
import os


def visualize_class_distribution(
        model,
        data,
        dataset_name,
        split_type,
        save_dir='vis',
        class_names=None,
        show_predictions=True
):
    """
    Plot class-distribution diagnostics for a trained model and save them to disk.

    The following artifacts are written into ``save_dir`` (figures as both
    PNG and PDF):
    1. Grouped bar charts of true vs. predicted class counts per split
    2. Pie charts of the true class distribution per split
    3. Confusion matrix heatmap and per-class precision/recall/F1 bar charts
       (plus ``classification_report.csv``)
    4. A class-imbalance bar chart over all evaluated labels

    NOTE(review): a second function with the same name is defined further
    down in this file; at import time that later definition replaces this
    one, so this version is only reachable if the later one is removed or
    renamed.

    Side effects: creates ``save_dir`` if needed, switches the model to eval
    mode, applies a matplotlib style and updates ``plt.rcParams`` without
    restoring them, and prints a summary of the written files.

    Args:
        model: Trained PyTorch model, called as ``model(data.x, data.edge_index)``
        data: Dataset object providing ``x``, ``edge_index`` and ``y``; when it
            also has ``train_mask``, ``val_mask`` and ``test_mask`` the plots
            are produced per split, otherwise over all samples
        dataset_name: Name of the dataset (used in figure titles)
        split_type: Type of split ('public', 'complete', etc.; used in titles)
        save_dir: Directory to save visualizations
        class_names: List of class names (optional); defaults to
            ``Class 0`` ... ``Class n-1``
        show_predictions: Whether to show prediction distributions

    Returns:
        The classification report as a pandas DataFrame (one row per report
        entry, transposed from sklearn's ``output_dict`` form).
    """

    # Create save directory
    os.makedirs(save_dir, exist_ok=True)

    # Set model to evaluation mode
    model.eval()

    # Get predictions
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        # The predicted class is the index of the largest output along the last axis.
        pred = out.argmax(dim=-1)

    # Extract labels and predictions based on split
    if hasattr(data, 'train_mask') and hasattr(data, 'val_mask') and hasattr(data, 'test_mask'):
        # Graph-based data with masks
        train_labels = data.y[data.train_mask].cpu().numpy()
        val_labels = data.y[data.val_mask].cpu().numpy()
        test_labels = data.y[data.test_mask].cpu().numpy()

        train_preds = pred[data.train_mask].cpu().numpy()
        val_preds = pred[data.val_mask].cpu().numpy()
        test_preds = pred[data.test_mask].cpu().numpy()

        # Insertion order of this dict decides the subplot order below.
        splits_data = {
            'Train': (train_labels, train_preds),
            'Validation': (val_labels, val_preds),
            'Test': (test_labels, test_preds)
        }
    else:
        # Fallback: use all data
        all_labels = data.y.cpu().numpy()
        all_preds = pred.cpu().numpy()
        splits_data = {
            'All Data': (all_labels, all_preds)
        }

    # Determine number of classes
    # Counted from the distinct TRUE labels only; the lookups below index
    # classes as range(n_classes), i.e. they assume ids 0..n_classes-1.
    all_labels_combined = np.concatenate([labels for labels, _ in splits_data.values()])
    n_classes = len(np.unique(all_labels_combined))

    # Generate class names if not provided
    if class_names is None:
        class_names = [f'Class {i}' for i in range(n_classes)]

    # Color palette for classes
    colors = plt.cm.Set3(np.linspace(0, 1, n_classes))

    # Set style
    # NOTE(review): the style name below presumably requires a matplotlib
    # release that ships the 'seaborn-v0_8-*' styles - confirm the minimum version.
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.size': 10,
        'axes.labelsize': 12,
        'axes.titlesize': 14,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10
    })

    # 1. Overall Class Distribution Comparison
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle(f'Class Distribution Analysis - {dataset_name} ({split_type})',
                 fontsize=16, fontweight='bold')

    # Plot for each split
    for idx, (split_name, (labels, preds)) in enumerate(splits_data.items()):
        if idx < 4:  # Maximum 4 subplots
            # Map the running index onto the 2x2 grid row by row.
            ax = axes[idx // 2, idx % 2]

            # Count true labels
            label_counts = Counter(labels)
            pred_counts = Counter(preds) if show_predictions else None

            classes = list(range(n_classes))
            true_counts = [label_counts.get(i, 0) for i in classes]
            pred_counts_list = [pred_counts.get(i, 0) for i in classes] if pred_counts else None

            x = np.arange(len(classes))
            width = 0.35

            # Plot true distribution
            bars1 = ax.bar(x - width / 2, true_counts, width,
                           label='True Labels', alpha=0.8,
                           color=colors)

            # Plot predicted distribution if requested
            if show_predictions and pred_counts_list:
                bars2 = ax.bar(x + width / 2, pred_counts_list, width,
                               label='Predictions', alpha=0.8,
                               color=colors, hatch='///')

            # Add value labels on bars
            for bar in bars1:
                height = bar.get_height()
                ax.annotate(f'{int(height)}',
                            xy=(bar.get_x() + bar.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=9)

            # Same guard as for bars2 above, so bars2 is defined whenever it is read.
            if show_predictions and pred_counts_list:
                for bar in bars2:
                    height = bar.get_height()
                    ax.annotate(f'{int(height)}',
                                xy=(bar.get_x() + bar.get_width() / 2, height),
                                xytext=(0, 3),
                                textcoords="offset points",
                                ha='center', va='bottom', fontsize=9)

            ax.set_title(f'{split_name} Set Distribution', fontweight='bold')
            ax.set_xlabel('Class')
            ax.set_ylabel('Count')
            ax.set_xticks(x)
            ax.set_xticklabels(class_names, rotation=45, ha='right')
            ax.legend()
            ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for idx in range(len(splits_data), 4):
        axes[idx // 2, idx % 2].set_visible(False)

    plt.tight_layout()
    plt.savefig(f'{save_dir}/class_distribution_comparison.png', dpi=300, bbox_inches='tight')
    plt.savefig(f'{save_dir}/class_distribution_comparison.pdf', dpi=300, bbox_inches='tight')
    plt.close()

    # 2. Pie Charts for Class Distribution
    n_splits = len(splits_data)
    fig, axes = plt.subplots(1, n_splits, figsize=(5 * n_splits, 5))
    # With a single subplot plt.subplots returns one Axes, not an array; wrap
    # it so that axes[idx] works in the loop below.
    if n_splits == 1:
        axes = [axes]

    fig.suptitle(f'Class Distribution Pie Charts - {dataset_name} ({split_type})',
                 fontsize=16, fontweight='bold')

    for idx, (split_name, (labels, preds)) in enumerate(splits_data.items()):
        ax = axes[idx]

        label_counts = Counter(labels)
        counts = [label_counts.get(i, 0) for i in range(n_classes)]

        # Create pie chart
        wedges, texts, autotexts = ax.pie(counts, labels=class_names, autopct='%1.1f%%',
                                          colors=colors, startangle=90)

        # Enhance text
        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontweight('bold')

        ax.set_title(f'{split_name} Set', fontweight='bold')

    plt.tight_layout()
    plt.savefig(f'{save_dir}/class_distribution_pie_charts.png', dpi=300, bbox_inches='tight')
    plt.savefig(f'{save_dir}/class_distribution_pie_charts.pdf', dpi=300, bbox_inches='tight')
    plt.close()

    # 3. Confusion Matrix (for test set if available)
    # Given how splits_data is built above (either three splits including
    # 'Test', or the single 'All Data' entry) this condition always holds.
    if 'Test' in splits_data or len(splits_data) == 1:
        test_data = splits_data.get('Test', list(splits_data.values())[0])
        test_labels, test_preds = test_data

        from sklearn.metrics import confusion_matrix, classification_report

        # NOTE(review): no explicit `labels=` is passed, so the matrix and the
        # report presumably cover only classes present in this split; if that
        # differs from len(class_names) the tick labels / target_names may
        # not line up - confirm against the sklearn docs.
        cm = confusion_matrix(test_labels, test_preds)

        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title(f'Confusion Matrix - {dataset_name} ({split_type})',
                  fontweight='bold', pad=20)
        plt.xlabel('Predicted Class')
        plt.ylabel('True Class')
        plt.tight_layout()
        plt.savefig(f'{save_dir}/confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.savefig(f'{save_dir}/confusion_matrix.pdf', dpi=300, bbox_inches='tight')
        plt.close()

        # Generate and save classification report
        report = classification_report(test_labels, test_preds,
                                       target_names=class_names,
                                       output_dict=True)

        # Convert to DataFrame for better visualization
        report_df = pd.DataFrame(report).transpose()

        # Save classification report
        report_df.to_csv(f'{save_dir}/classification_report.csv')

        # Visualize classification metrics
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Precision, Recall, F1-Score by class
        metrics = ['precision', 'recall', 'f1-score']
        # Drops the last three report rows; presumably these are the
        # accuracy / macro avg / weighted avg summary rows - TODO confirm.
        class_metrics = report_df.iloc[:-3]  # Exclude avg rows

        for idx, metric in enumerate(metrics):
            ax = axes[idx]
            bars = ax.bar(range(len(class_names)), class_metrics[metric],
                          color=colors, alpha=0.8)

            # Add value labels
            for bar in bars:
                height = bar.get_height()
                ax.annotate(f'{height:.3f}',
                            xy=(bar.get_x() + bar.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=9)

            ax.set_title(f'{metric.capitalize()} by Class', fontweight='bold')
            ax.set_xlabel('Class')
            ax.set_ylabel(metric.capitalize())
            ax.set_xticks(range(len(class_names)))
            ax.set_xticklabels(class_names, rotation=45, ha='right')
            # Headroom above 1.0 leaves space for the value annotations.
            ax.set_ylim(0, 1.1)
            ax.grid(True, alpha=0.3)

        plt.suptitle(f'Classification Metrics by Class - {dataset_name} ({split_type})',
                     fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.savefig(f'{save_dir}/classification_metrics.png', dpi=300, bbox_inches='tight')
        plt.savefig(f'{save_dir}/classification_metrics.pdf', dpi=300, bbox_inches='tight')
        plt.close()

    # 4. Class Imbalance Analysis
    plt.figure(figsize=(12, 6))

    # Calculate class imbalance ratio
    # Same concatenation as all_labels_combined above: every label that
    # belongs to one of the evaluated splits.
    all_labels = np.concatenate([labels for labels, _ in splits_data.values()])
    label_counts = Counter(all_labels)
    total_samples = len(all_labels)

    class_ratios = [label_counts.get(i, 0) / total_samples for i in range(n_classes)]

    bars = plt.bar(range(n_classes), class_ratios, color=colors, alpha=0.8)

    # Add percentage labels
    for bar in bars:
        height = bar.get_height()
        plt.annotate(f'{height * 100:.1f}%',
                     xy=(bar.get_x() + bar.get_width() / 2, height),
                     xytext=(0, 3),
                     textcoords="offset points",
                     ha='center', va='bottom', fontsize=10)

    plt.title(f'Class Imbalance Analysis - {dataset_name} ({split_type})',
              fontweight='bold', pad=20)
    plt.xlabel('Class')
    plt.ylabel('Proportion of Total Samples')
    plt.xticks(range(n_classes), class_names, rotation=45, ha='right')

    # Add horizontal line for balanced distribution
    balanced_ratio = 1.0 / n_classes
    plt.axhline(y=balanced_ratio, color='red', linestyle='--', alpha=0.7,
                label=f'Balanced ({balanced_ratio * 100:.1f}%)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'{save_dir}/class_imbalance_analysis.png', dpi=300, bbox_inches='tight')
    plt.savefig(f'{save_dir}/class_imbalance_analysis.pdf', dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Class distribution visualizations saved in '{save_dir}' directory:")
    print("  - class_distribution_comparison.png/pdf")
    print("  - class_distribution_pie_charts.png/pdf")
    print("  - confusion_matrix.png/pdf")
    print("  - classification_metrics.png/pdf")
    print("  - class_imbalance_analysis.png/pdf")
    print("  - classification_report.csv")

    # Mirrors the condition guarding section 3, so report_df is only read
    # when it was assigned.
    return report_df if 'Test' in splits_data or len(splits_data) == 1 else None


# Example usage function to integrate with your training loop
def get_best_model_and_visualize(
        models_results,  # List of (model, metrics) tuples from your runs
        data,
        dataset_name,
        split_type,
        metric='test_f1',  # Metric to determine best model
        class_names=None
):
    """
    Pick the run with the highest value of ``metric`` and visualize its model.

    Args:
        models_results: List of ``(model, metrics_dict)`` tuples, e.g.
            ``[(model, {'test_f1': score, 'test_acc': acc, ...})]``
        data: Dataset object forwarded to ``visualize_class_distribution``
        dataset_name: Name of dataset
        split_type: Split type
        metric: Key of the metrics dict used for ranking ('test_f1', 'test_acc', etc.)
        class_names: Optional list of class names

    Returns:
        Tuple ``(best_model, best_metrics, report)`` where ``report`` is
        whatever ``visualize_class_distribution`` returns.

    Raises:
        ValueError: If ``models_results`` is empty.
        KeyError: If a metrics dict does not contain ``metric``.
    """

    def _score(entry):
        # entry is a (model, metrics) pair; rank by the requested metric.
        return entry[1][metric]

    best_model, best_metrics = max(models_results, key=_score)

    best_score = best_metrics[metric]
    print(f"Best model selected based on {metric}: {best_score:.4f}")

    # Forward everything by keyword, exactly as the plotting function expects.
    plot_kwargs = {
        'model': best_model,
        'data': data,
        'dataset_name': dataset_name,
        'split_type': split_type,
        'class_names': class_names,
    }
    report = visualize_class_distribution(**plot_kwargs)

    return best_model, best_metrics, report


# Integration helper function for your existing training loop
def integrate_with_training_loop():
    """
    Print a commented code snippet showing how to wire the visualization
    helpers into an existing training loop.

    The snippet stores the model object for each run, because
    ``get_best_model_and_visualize`` hands the stored object on to code that
    calls ``model.eval()`` and ``model(...)``; a bare ``state_dict`` would not
    support either call.

    Returns:
        None. The snippet is only written to stdout.
    """
    print("""
    # Integration example - add this to your training loop:

    # After your training loop, collect models and results
    models_and_results = []

    # In your training loop, after each run
    # (keep the model object itself - the visualization calls model.eval()):
    # models_and_results.append((copy.deepcopy(model), {
    #     'test_f1': t_f1.mean().item(),
    #     'test_acc': acc.mean().item(),
    #     'val_f1': v_f1.mean().item()
    # }))

    # After all runs, find best model and visualize:
    # best_model, best_metrics, report = get_best_model_and_visualize(
    #     models_and_results,
    #     data,
    #     dataset.name,
    #     split,
    #     metric='test_f1'
    # )
    """)


# Run integration example
#integrate_with_training_loop()


def visualize_class_distribution(
        model,
        data,
        dataset_name,
        split_type,
        save_dir='vis',
        class_names=None,
        show_predictions=True,
        use_tsne=True,
        use_pca=True,
        perplexity=30,
        random_state=42
):
    """
    Visualize class distribution for the best model including:
    1. True class distribution
    2. Predicted class distribution
    3. Confusion matrix
    4. Class-wise performance metrics
    5. Clustering-style scatter plots (t-SNE/PCA) with class coloring

    Args:
        model: Trained PyTorch model
        data: Dataset object with node features and labels
        dataset_name: Name of the dataset
        split_type: Type of split ('public', 'complete', etc.)
        save_dir: Directory to save visualizations
        class_names: List of class names (optional)
        show_predictions: Whether to show prediction distributions
        use_tsne: Whether to create t-SNE visualization
        use_pca: Whether to create PCA visualization
        perplexity: t-SNE perplexity parameter
        random_state: Random state for reproducibility
    """
    # Import required libraries
    from sklearn.manifold import TSNE
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler

    # Initialize configuration values at the beginning
    BAR_WIDTH = 0.35
    FIGURE_DPI = 300
    FONT_SIZES = {
        'base': 14,
        'axes_label': 18,
        'title': 16,
        'suptitle': 18,
        'tick': 16,
        'legend': 14,
        'annotation': 15
    }

    # t-SNE configuration values
    TSNE_MAX_SAMPLES = 1000  # Maximum samples for t-SNE to ensure performance
    TSNE_ITERATIONS = 1000

    # Plot styling values
    SCATTER_SIZE = 60
    SCATTER_ALPHA = 0.7
    EDGE_LINE_WIDTH = 0.5
    INCORRECT_EDGE_WIDTH = 2
    GRID_ALPHA = 0.3
    BAR_ALPHA = 0.8

    # Create save directory
    os.makedirs(save_dir, exist_ok=True)

    # Set model to evaluation mode
    model.eval()

    # Get predictions and feature embeddings
    with torch.no_grad():
        model_output = model(data)
        predictions = model_output.argmax(dim=-1)

        # Get intermediate features for clustering visualization
        # Try to get features from the model's last hidden layer
        if hasattr(model, 'get_embeddings'):
            feature_embeddings = model.get_embeddings(data).cpu().numpy()
        elif hasattr(model, 'features'):
            feature_embeddings = model.features(data).cpu().numpy()
        else:
            # Fallback: use the raw features or model output
            if hasattr(data, 'x') and data.x is not None:
                feature_embeddings = data.x.cpu().numpy()
            else:
                feature_embeddings = model_output.cpu().numpy()

    # Extract labels and predictions based on split
    if hasattr(data, 'train_mask') and hasattr(data, 'val_mask') and hasattr(data, 'test_mask'):
        # Graph-based data with masks
        train_labels = data.y[data.train_mask].cpu().numpy()
        val_labels = data.y[data.val_mask].cpu().numpy()
        test_labels = data.y[data.test_mask].cpu().numpy()

        train_predictions = predictions[data.train_mask].cpu().numpy()
        val_predictions = predictions[data.val_mask].cpu().numpy()
        test_predictions = predictions[data.test_mask].cpu().numpy()

        train_features = feature_embeddings[data.train_mask.cpu().numpy()]
        val_features = feature_embeddings[data.val_mask.cpu().numpy()]
        test_features = feature_embeddings[data.test_mask.cpu().numpy()]

        splits_data = {
            'Train': (train_labels, train_predictions, train_features),
            'Validation': (val_labels, val_predictions, val_features),
            'Test': (test_labels, test_predictions, test_features)
        }
    else:
        # Fallback: use all data
        all_labels = data.y.cpu().numpy()
        all_predictions = predictions.cpu().numpy()
        splits_data = {
            'All Data': (all_labels, all_predictions, feature_embeddings)
        }

    # Determine number of classes
    all_labels_combined = np.concatenate([labels for labels, _, _ in splits_data.values()])
    num_classes = len(np.unique(all_labels_combined))

    # Generate class names if not provided
    if class_names is None:
        class_names = [f'Class {i}' for i in range(num_classes)]

    # Color palette for classes
    class_colors = plt.cm.Set3(np.linspace(0, 1, num_classes))

    # Set matplotlib style and configuration
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.size': FONT_SIZES['base'],
        'axes.labelsize': FONT_SIZES['axes_label'],
        'axes.titlesize': FONT_SIZES['title'],
        'xtick.labelsize': FONT_SIZES['tick'],
        'ytick.labelsize': FONT_SIZES['tick'],
        'legend.fontsize': FONT_SIZES['legend']
    })

    # 1. Overall Class Distribution Comparison
    fig, axes = plt.subplots(2, 3, figsize=(15, 12))
    fig.suptitle(f'Class Distribution Analysis - {dataset_name} ({split_type})',
                 fontsize=FONT_SIZES['suptitle'], fontweight='bold')

    # Plot for each split
    for idx, (split_name, (labels, preds, _)) in enumerate(splits_data.items()):
        if idx < 4:  # Maximum 4 subplots
            ax = axes[idx // 3, idx % 3]

            # Count true labels
            label_counts = Counter(labels)
            pred_counts = Counter(preds) if show_predictions else None

            classes = list(range(num_classes))
            true_counts = [label_counts.get(i, 0) for i in classes]
            pred_counts_list = [pred_counts.get(i, 0) for i in classes] if pred_counts else None

            x_positions = np.arange(len(classes))

            # Plot true distribution
            bars1 = ax.bar(x_positions - BAR_WIDTH / 2, true_counts, BAR_WIDTH,
                           label='True Labels', alpha=BAR_ALPHA,
                           color=class_colors)

            # Plot predicted distribution if requested
            if show_predictions and pred_counts_list:
                bars2 = ax.bar(x_positions + BAR_WIDTH / 2, pred_counts_list, BAR_WIDTH,
                               label='Predictions', alpha=BAR_ALPHA,
                               color=class_colors, hatch='///')

            # Add value labels on bars
            for bar in bars1:
                height = bar.get_height()
                ax.annotate(f'{int(height)}',
                            xy=(bar.get_x() + bar.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=FONT_SIZES['annotation'])

            if show_predictions and pred_counts_list:
                for bar in bars2:
                    height = bar.get_height()
                    ax.annotate(f'{int(height)}',
                                xy=(bar.get_x() + bar.get_width() / 2, height),
                                xytext=(0, 3),
                                textcoords="offset points",
                                ha='center', va='bottom', fontsize=FONT_SIZES['annotation'])

            ax.set_title(f'{split_name} Set Distribution', fontweight='bold')
            ax.set_xlabel('Class')
            ax.set_ylabel('Count')
            ax.set_xticks(x_positions)
            ax.set_xticklabels(class_names, rotation=45, ha='right')
            ax.legend()
            ax.grid(True, alpha=GRID_ALPHA)

    # Hide unused subplots
    for idx in range(len(splits_data), 4):
        axes[idx // 3, idx % 3].set_visible(False)

    plt.tight_layout()
    plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-class_distribution_comparison.png',
                dpi=FIGURE_DPI, bbox_inches='tight')
    plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-class_distribution_comparison.pdf',
                dpi=FIGURE_DPI, bbox_inches='tight')
    plt.close()

    # Helper function for clustering visualization
    def create_clustering_visualization(features, labels, preds, method_name, split_name):
        """Create clustering-style scatter plots for the given features and labels."""

        # Standardize features
        scaler = StandardScaler()
        features_scaled = scaler.fit_transform(features)

        # Apply dimensionality reduction
        if method_name == 't-SNE':
            if len(features) > TSNE_MAX_SAMPLES:  # Subsample for large datasets for t-SNE performance
                indices = np.random.choice(len(features), TSNE_MAX_SAMPLES, replace=False)
                features_subset = features_scaled[indices]
                labels_subset = labels[indices]
                preds_subset = preds[indices]
            else:
                features_subset = features_scaled
                labels_subset = labels
                preds_subset = preds

            reducer = TSNE(n_components=2, perplexity=min(perplexity, len(features_subset) // 4),
                           random_state=random_state, max_iter=TSNE_ITERATIONS)
            features_2d = reducer.fit_transform(features_subset)
        else:  # PCA
            reducer = PCA(n_components=2, random_state=random_state)
            features_2d = reducer.fit_transform(features_scaled)
            features_subset = features_scaled
            labels_subset = labels
            preds_subset = preds

        # Create subplots for true vs predicted
        fig, axes = plt.subplots(1, 2, figsize=(16, 7))

        # True labels plot
        ax1 = axes[0]
        scatter = ax1.scatter(features_2d[:, 0], features_2d[:, 1],
                              c=[class_colors[label] for label in labels_subset],
                              alpha=SCATTER_ALPHA, s=SCATTER_SIZE,
                              edgecolors='black', linewidth=EDGE_LINE_WIDTH)
        ax1.set_title(f'True Labels - {method_name} ({split_name})', fontweight='bold')
        ax1.set_xlabel(f'{method_name} Component 1')
        ax1.set_ylabel(f'{method_name} Component 2')
        ax1.grid(True, alpha=GRID_ALPHA)

        # Create legend for true labels
        legend_elements = [plt.Line2D([0], [0], marker='o', color='w',
                                      markerfacecolor=class_colors[i], markersize=10,
                                      label=class_names[i]) for i in range(num_classes)]
        ax1.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.05, 1))

        # Predicted labels plot
        ax2 = axes[1]
        scatter2 = ax2.scatter(features_2d[:, 0], features_2d[:, 1],
                               c=[class_colors[pred] for pred in preds_subset],
                               alpha=SCATTER_ALPHA, s=SCATTER_SIZE,
                               edgecolors='black', linewidth=EDGE_LINE_WIDTH)
        ax2.set_title(f'Predicted Labels - {method_name} ({split_name})', fontweight='bold')
        ax2.set_xlabel(f'{method_name} Component 1')
        ax2.set_ylabel(f'{method_name} Component 2')
        ax2.grid(True, alpha=GRID_ALPHA)

        # Create legend for predicted labels
        ax2.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.05, 1))

        # Add accuracy information
        accuracy = np.mean(labels_subset == preds_subset)
        fig.suptitle(f'{method_name} Clustering Visualization - {split_name} Set\n'
                     f'Accuracy: {accuracy:.3f}', fontsize=FONT_SIZES['suptitle'], fontweight='bold')

        plt.tight_layout()

        # Save the plot
        method_str = method_name.lower().replace('-', '')
        plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-{method_str}_clustering_{split_name.lower()}.png',
                    dpi=FIGURE_DPI, bbox_inches='tight')
        plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-{method_str}_clustering_{split_name.lower()}.pdf',
                    dpi=FIGURE_DPI, bbox_inches='tight')
        plt.close()

        return features_2d, accuracy

    # Helper function for combined clustering plot
    def create_combined_clustering_plot(features, labels, preds, method_name):
        """Create a combined plot showing correct and incorrect predictions."""

        # Standardize features
        scaler = StandardScaler()
        features_scaled = scaler.fit_transform(features)

        # Apply dimensionality reduction
        if method_name == 't-SNE':
            if len(features) > TSNE_MAX_SAMPLES:
                indices = np.random.choice(len(features), TSNE_MAX_SAMPLES, replace=False)
                features_subset = features_scaled[indices]
                labels_subset = labels[indices]
                preds_subset = preds[indices]
            else:
                features_subset = features_scaled
                labels_subset = labels
                preds_subset = preds

            reducer = TSNE(n_components=2, perplexity=min(perplexity, len(features_subset) // 4),
                           random_state=random_state, max_iter=TSNE_ITERATIONS)
            features_2d = reducer.fit_transform(features_subset)
        else:  # PCA
            reducer = PCA(n_components=2, random_state=random_state)
            features_2d = reducer.fit_transform(features_scaled)
            features_subset = features_scaled
            labels_subset = labels
            preds_subset = preds

        # Identify correct and incorrect predictions
        correct_mask = labels_subset == preds_subset
        incorrect_mask = ~correct_mask

        fig, ax = plt.subplots(1, 1, figsize=(12, 9))

        # Plot correct predictions
        if np.sum(correct_mask) > 0:
            scatter_correct = ax.scatter(features_2d[correct_mask, 0], features_2d[correct_mask, 1],
                                         c=[class_colors[label] for label in labels_subset[correct_mask]],
                                         alpha=SCATTER_ALPHA, s=SCATTER_SIZE,
                                         edgecolors='black', linewidth=EDGE_LINE_WIDTH,
                                         marker='o', label='Correct Predictions')

        # Plot incorrect predictions
        if np.sum(incorrect_mask) > 0:
            scatter_incorrect = ax.scatter(features_2d[incorrect_mask, 0], features_2d[incorrect_mask, 1],
                                           c=[class_colors[label] for label in labels_subset[incorrect_mask]],
                                           alpha=SCATTER_ALPHA, s=SCATTER_SIZE,
                                           edgecolors='red', linewidth=INCORRECT_EDGE_WIDTH,
                                           marker='X', label='Incorrect Predictions')

        ax.set_title(f'{method_name} Clustering: Prediction Accuracy Visualization', fontweight='bold')
        ax.set_xlabel(f'{method_name} Component 1')
        ax.set_ylabel(f'{method_name} Component 2')
        ax.grid(True, alpha=GRID_ALPHA)

        # Create custom legend combining class colors and prediction accuracy
        legend_elements = []
        for i in range(num_classes):
            legend_elements.append(plt.Line2D([0], [0], marker='o', color='w',
                                              markerfacecolor=class_colors[i], markersize=10,
                                              label=f'{class_names[i]}'))

        legend_elements.append(plt.Line2D([0], [0], marker='o', color='w',
                                          markerfacecolor='gray', markeredgecolor='black',
                                          markersize=10, label='Correct'))
        legend_elements.append(plt.Line2D([0], [0], marker='X', color='w',
                                          markerfacecolor='gray', markeredgecolor='red',
                                          markersize=10, label='Incorrect'))

        ax.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1.05, 0.5))

        accuracy = np.mean(correct_mask)
        plt.figtext(0.02, 0.02, f'Overall Accuracy: {accuracy:.3f} '
                                f'({np.sum(correct_mask)}/{len(labels_subset)} correct)',
                    fontsize=FONT_SIZES['axes_label'], fontweight='bold')

        plt.tight_layout()

        method_str = method_name.lower().replace('-', '')
        plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-{method_str}_prediction_accuracy.png',
                    dpi=FIGURE_DPI, bbox_inches='tight')
        plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-{method_str}_prediction_accuracy.pdf',
                    dpi=FIGURE_DPI, bbox_inches='tight')
        plt.close()

        return features_2d, accuracy

    # Generate clustering visualizations for each split and method
    clustering_results = {}

    for split_name, (labels, preds, split_features) in splits_data.items():
        if len(split_features) < 10:  # Skip if too few samples
            continue

        clustering_results[split_name] = {}

        if use_tsne:
            print(f"Generating t-SNE visualization for {split_name} set...")
            features_2d, accuracy = create_clustering_visualization(
                split_features, labels, preds, 't-SNE', split_name)
            clustering_results[split_name]['tsne'] = (features_2d, accuracy)

            # Create combined accuracy plot
            create_combined_clustering_plot(split_features, labels, preds, 't-SNE')

        if use_pca:
            print(f"Generating PCA visualization for {split_name} set...")
            features_2d, accuracy = create_clustering_visualization(
                split_features, labels, preds, 'PCA', split_name)
            clustering_results[split_name]['pca'] = (features_2d, accuracy)

            # Create combined accuracy plot
            create_combined_clustering_plot(split_features, labels, preds, 'PCA')

    # 2. Pie Charts for Class Distribution
    n_splits = len(splits_data)
    fig, axes = plt.subplots(1, n_splits, figsize=(5 * n_splits, 5))
    if n_splits == 1:
        axes = [axes]

    fig.suptitle(f'Class Distribution Pie Charts - {dataset_name} ({split_type})',
                 fontsize=FONT_SIZES['suptitle'], fontweight='bold')

    for idx, (split_name, (labels, preds, _)) in enumerate(splits_data.items()):
        ax = axes[idx]

        label_counts = Counter(labels)
        counts = [label_counts.get(i, 0) for i in range(num_classes)]

        # Create pie chart
        wedges, texts, autotexts = ax.pie(counts, labels=class_names, autopct='%1.1f%%',
                                          colors=class_colors, startangle=90)

        # Enhance text
        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontweight('bold')

        ax.set_title(f'{split_name} Set', fontweight='bold')

    plt.tight_layout()
    plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-class_distribution_pie_charts.png',
                dpi=FIGURE_DPI, bbox_inches='tight')
    plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-class_distribution_pie_charts.pdf',
                dpi=FIGURE_DPI, bbox_inches='tight')
    plt.close()

    # 3. Confusion Matrix (uses the 'Test' split, or the only split when exactly one is provided)
    report_df = None
    if 'Test' in splits_data or len(splits_data) == 1:
        test_data = splits_data.get('Test', list(splits_data.values())[0])
        test_labels, test_preds, _ = test_data

        cm = confusion_matrix(test_labels, test_preds)

        plt.figure(figsize=(10, 8))
        sns.heatmap(cm,
                    annot=True,  # show values
                    fmt='d',  # integer format
                    cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names,
                    annot_kws={
                        "size": FONT_SIZES['axes_label'],
                        "weight": "bold"
                    })
        plt.title(f'Confusion Matrix - {dataset_name} ({split_type})',
                  fontweight='bold', pad=20)
        plt.xlabel('Predicted Class')
        plt.ylabel('True Class')
        plt.tight_layout()
        plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-confusion_matrix.png',
                    dpi=FIGURE_DPI, bbox_inches='tight')
        plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-confusion_matrix.pdf',
                    dpi=FIGURE_DPI, bbox_inches='tight')
        plt.close()

        # Generate and save classification report
        report = classification_report(test_labels, test_preds,
                                       target_names=class_names,
                                       output_dict=True)

        # Convert to DataFrame for better visualization
        report_df = pd.DataFrame(report).transpose()

        # Save classification report
        report_df.to_csv(f'{save_dir}/{dataset_name}-{split_type}-classification_report.csv')

        # Visualize classification metrics
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Precision, Recall, F1-Score by class
        metrics = ['precision', 'recall', 'f1-score']
        class_metrics = report_df.iloc[:-3]  # Exclude avg rows

        for idx, metric in enumerate(metrics):
            ax = axes[idx]
            bars = ax.bar(range(len(class_names)), class_metrics[metric],
                          color=class_colors, alpha=BAR_ALPHA)

            # Add value labels
            for bar in bars:
                height = bar.get_height()
                ax.annotate(f'{height:.3f}',
                            xy=(bar.get_x() + bar.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=FONT_SIZES['annotation'])

            ax.set_title(f'{metric.capitalize()} by Class', fontweight='bold')
            ax.set_xlabel('Class')
            ax.set_ylabel(metric.capitalize())
            ax.set_xticks(range(len(class_names)))
            ax.set_xticklabels(class_names, rotation=45, ha='right')
            ax.set_ylim(0, 1.1)
            ax.grid(True, alpha=GRID_ALPHA)

        plt.suptitle(f'Classification Metrics by Class - {dataset_name} ({split_type})',
                     fontsize=FONT_SIZES['suptitle'], fontweight='bold')
        plt.tight_layout()
        plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-classification_metrics.png',
                    dpi=FIGURE_DPI, bbox_inches='tight')
        plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-classification_metrics.pdf',
                    dpi=FIGURE_DPI, bbox_inches='tight')
        plt.close()

    # 4. Class Imbalance Analysis
    plt.figure(figsize=(12, 6))

    # Compute each class's share of the samples pooled across all splits
    all_labels = np.concatenate([labels for labels, _, _ in splits_data.values()])
    label_counts = Counter(all_labels)
    total_samples = len(all_labels)

    class_ratios = [label_counts.get(i, 0) / total_samples for i in range(num_classes)]

    bars = plt.bar(range(num_classes), class_ratios, color=class_colors, alpha=BAR_ALPHA)

    # Add percentage labels
    for bar in bars:
        height = bar.get_height()
        plt.annotate(f'{height * 100:.1f}%',
                     xy=(bar.get_x() + bar.get_width() / 2, height),
                     xytext=(0, 3),
                     textcoords="offset points",
                     ha='center', va='bottom', fontsize=FONT_SIZES['base'])

    plt.title(f'Class Imbalance Analysis - {dataset_name} ({split_type})',
              fontweight='bold', pad=20)
    plt.xlabel('Class')
    plt.ylabel('Proportion of Total Samples')
    plt.xticks(range(num_classes), class_names, rotation=45, ha='right')

    # Add horizontal line for balanced distribution
    balanced_ratio = 1.0 / num_classes
    plt.axhline(y=balanced_ratio, color='red', linestyle='--', alpha=0.7,
                label=f'Balanced ({balanced_ratio * 100:.1f}%)')
    plt.legend()
    plt.grid(True, alpha=GRID_ALPHA)
    plt.tight_layout()
    plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-class_imbalance_analysis.png',
                dpi=FIGURE_DPI, bbox_inches='tight')
    plt.savefig(f'{save_dir}/{dataset_name}-{split_type}-class_imbalance_analysis.pdf',
                dpi=FIGURE_DPI, bbox_inches='tight')
    plt.close()

    print(f"\n✅ Class distribution visualizations saved in '{save_dir}' directory:")
    print("📊 Basic Visualizations:")
    print("   • class_distribution_comparison.png/pdf")
    print("   • class_distribution_pie_charts.png/pdf")
    print("   • confusion_matrix.png/pdf")
    print("   • classification_metrics.png/pdf")
    print("   • class_imbalance_analysis.png/pdf")
    print("   • classification_report.csv")

    if use_tsne or use_pca:
        print("\n🎯 Clustering-style Visualizations:")
        for split_name in clustering_results.keys():
            if use_tsne:
                print(f"   • tsne_clustering_{split_name.lower()}.png/pdf")
                print(f"   • tsne_prediction_accuracy.png/pdf")
            if use_pca:
                print(f"   • pca_clustering_{split_name.lower()}.png/pdf")
                print(f"   • pca_prediction_accuracy.png/pdf")

    return report_df, clustering_results