from __future__ import division

import time
import os
import shutil
import os
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch import tensor
from torch.utils.tensorboard import SummaryWriter
import utils as ut
import psgd
from tqdm import tqdm

from sklearn.metrics import f1_score, precision_score, recall_score

# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# ANSI escape sequences used to colour console output.
GREEN = "\033[92m"
RED = "\033[91m"
RESET = "\033[0m"

# Announce the selected device (green for GPU, red for CPU fallback).
print(
    f"{GREEN}Using GPU: {torch.cuda.get_device_name(0)}{RESET}"
    if device.type == 'cuda'
    else f"{RED}Using CPU{RESET}"
)

# Root directory under which TensorBoard run folders are created.
path_runs = "runs"

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import torch
import os
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm
import time
from torch import tensor
from torch.utils.tensorboard import SummaryWriter


def run(
    dataset,
    model,
    split,
    str_optimizer,
    str_preconditioner,
    runs,
    epochs,
    lr,
    loss,
    loss_type,
    lambda_rbm,
    weight_decay,
    early_stopping,
    logger,
    momentum,
    eps,
    update_freq,
    gamma,
    alpha,
    log_dict,
    hyperparam,
    save_dir='vis',
    class_names=None,
    visualize_classes=True
    ):
    """Train ``model`` for several independent runs and report aggregated metrics.

    For every run the model parameters are reset and an optimizer is built,
    optionally with a KFAC preconditioner from ``psgd``. The model is then trained
    for ``epochs`` epochs. When the masks have several columns (``split == 'geom-gcn'``),
    each epoch trains and evaluates once per column. Test metrics are recorded at
    the step with the best validation F1 seen so far.

    After all runs the mean/std of the metrics are printed and written via
    ``save_results`` and ``ut.log_run_results_to_file``. Optionally, class-distribution
    plots are produced with ``ut.visualize_class_distribution``.

    Args:
        dataset: Indexable dataset; ``dataset[0]`` is the graph that is trained on.
        model: Model exposing ``reset_parameters()``; called by ``train``/``evaluate``.
        split: Split name; ``'geom-gcn'`` means the masks already hold one column
            per predefined split, otherwise 1-D masks are promoted to ``[N, 1]``.
        str_optimizer: ``'Adam'`` or ``'SGD'``.
        str_preconditioner: ``'KFAC'`` enables ``psgd.KFAC``; anything else disables it.
        runs: Number of independent runs.
        epochs: Number of epochs per run.
        lr: Learning rate.
        loss: Unused here; kept for interface compatibility.
        loss_type: RBM loss variant, forwarded to ``train`` as ``rbm_loss_type``.
        lambda_rbm: Weight of the RBM loss in the total training loss.
        weight_decay: Adam weight decay (not used by SGD).
        early_stopping: Only forwarded to the results log.
        logger: Run name for a TensorBoard ``SummaryWriter`` under ``path_runs``, or None.
        momentum: SGD momentum (not used by Adam).
        eps, update_freq, alpha: KFAC preconditioner settings.
        gamma: Only forwarded to the results log.
        log_dict: Metadata forwarded to ``save_results``.
        hyperparam: Optional name of an argument whose value is appended to the
            logger name.
        save_dir, class_names, visualize_classes: Visualization options.

    Raises:
        ValueError: If ``str_optimizer`` is neither ``'Adam'`` nor ``'SGD'``.
    """
    # Validate before any side effects (log directory wipe, SummaryWriter creation),
    # instead of failing later with a NameError on an unbound optimizer.
    if str_optimizer not in ('Adam', 'SGD'):
        raise ValueError(f"Unsupported optimizer {str_optimizer!r}; expected 'Adam' or 'SGD'")

    if logger is not None:
        if hyperparam:
            # NOTE(review): eval() resolves the name in `hyperparam` (e.g. "lr") against
            # this function's locals/globals to tag the run name with its value. It
            # executes arbitrary expressions, so only pass trusted, developer-chosen names.
            logger += f"-{hyperparam}{eval(hyperparam)}"
        path_logger = os.path.join(path_runs, logger)
        print(f"path logger: {path_logger}")

        ut.empty_dir(path_logger)
        logger = SummaryWriter(log_dir=path_logger)

    val_losses, accs, val_f1s, test_f1s, durations = [], [], [], [], []
    models_and_results = []  # (weights snapshot, metrics) per run, used for visualization

    torch.manual_seed(42)
    for i_run in range(runs):
        data = dataset[0].to(device)

        model.to(device).reset_parameters()
        if str_preconditioner == 'KFAC':
            preconditioner = psgd.KFAC(
                model,
                eps,
                sua=False,
                pi=False,
                update_freq=update_freq,
                alpha=alpha if alpha is not None else 1.,
                constraint_norm=False
            )
        else:
            preconditioner = None

        if str_optimizer == 'Adam':
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=lr,
                weight_decay=weight_decay
            )
        else:  # 'SGD' (validated at the top)
            optimizer = torch.optim.SGD(
                model.parameters(),
                lr=lr,
                momentum=momentum,
            )

        # Synchronize so the timing below does not include queued GPU work.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_start = time.perf_counter()

        # Metrics captured at the step with the best validation F1 so far.
        # NOTE(review): `best_val_loss` is the *training* total loss at that step
        # (no validation loss is read here), although it is reported as "Val Loss".
        best_val_loss = float('inf')
        test_acc = 0
        test_f1 = 0
        val_f1 = 0

        if split == 'geom-gcn':
            train_mask, val_mask, test_mask = data.train_mask, data.val_mask, data.test_mask
        else:
            # Promote 1-D masks to [N, 1] so both cases share the column loop below.
            train_mask, val_mask, test_mask = (
                data.train_mask.unsqueeze(1),
                data.val_mask.unsqueeze(1),
                data.test_mask.unsqueeze(1),
            )

        for epoch in tqdm(range(1, epochs + 1)):
            for col_idx in range(train_mask.size(1)):
                t_mask = train_mask[:, col_idx].to(torch.bool)
                v_mask = val_mask[:, col_idx].to(torch.bool)
                e_mask = test_mask[:, col_idx].to(torch.bool)

                mask = {'train_mask': t_mask, 'val_mask': v_mask, 'test_mask': e_mask}

                # One optimisation step with the combined classification + RBM loss.
                train_info = train(model, optimizer, data, t_mask, preconditioner, lam=0.5,
                                   lambda_rbm=lambda_rbm, rbm_loss_type=loss_type)

                eval_info = evaluate(model, data, mask)

                # Model selection on validation F1.
                current_val_f1 = eval_info.get('val_f1', 0)
                if current_val_f1 > val_f1:
                    best_val_loss = train_info["total_loss"]
                    test_acc = eval_info.get('test_acc', 0)
                    test_f1 = eval_info.get('test_f1', 0)
                    val_f1 = current_val_f1

                # NOTE(review): eval_info is enriched here but is not currently written
                # to the SummaryWriter in this function.
                eval_info.update({
                    'epoch': int(epoch),
                    'run': int(i_run + 1),
                    'time': time.perf_counter() - t_start,
                    'eps': eps,
                    'update-freq': update_freq,
                    'train_loss': train_info["total_loss"],
                    'classification_loss': train_info["classification_loss"],
                    'rbm_loss': train_info["rbm_loss"],
                })

                if epoch % 1250 == 0:
                    print(f"Epoch {epoch}: Total Loss: {train_info['total_loss']:.4f}, "
                          f"Class Loss: {train_info['classification_loss']:.4f}, "
                          f"RBM Loss: {train_info['rbm_loss']:.4f}")

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_end = time.perf_counter()

        val_losses.append(best_val_loss)
        accs.append(test_acc)
        test_f1s.append(test_f1)
        val_f1s.append(val_f1)
        durations.append(t_end - t_start)

        # state_dict() returns tensors that share storage with the live parameters.
        # Clone them, otherwise the next run's reset_parameters()/optimizer steps
        # (in-place) would overwrite every stored snapshot. The snapshot holds the
        # end-of-training weights of this run (not the best-validation step's).
        weights_snapshot = {name: t.detach().clone() for name, t in model.state_dict().items()}
        models_and_results.append((
            weights_snapshot,
            {
                'test_f1': test_f1,
                'test_acc': test_acc,
                'val_f1': val_f1,
                'val_loss': best_val_loss
            }
        ))

    if logger is not None:
        logger.close()

    val_loss_t, acc, v_f1, t_f1, duration = (
        tensor(val_losses), tensor(accs), tensor(val_f1s), tensor(test_f1s), tensor(durations)
    )
    print(
        '\033[94mVal Loss:\033[0m {:.4f}, '
        '\033[92mVal F1:\033[0m {:.2f} ± {:.2f}, '
        '\033[93mTest Accuracy:\033[0m {:.2f} ± {:.2f}, '
        '\033[91mTest F1:\033[0m {:.2f} ± {:.2f}, '
        '\033[95mDuration:\033[0m {:.3f} sec'.format(
            val_loss_t.mean().item(),
            100 * v_f1.mean().item(),
            100 * v_f1.std().item(),
            100 * acc.mean().item(),
            100 * acc.std().item(),
            100 * t_f1.mean().item(),
            100 * t_f1.std().item(),
            duration.mean().item())
    )

    save_results(
        method=log_dict['model'],
        dataset=log_dict['dataset'],
        split=log_dict['split'],
        val_loss=val_loss_t.mean().item(),
        val_f1=v_f1.mean().item(),
        val_f1_std=v_f1.std().item(),
        test_accuracy=acc.mean().item(),
        test_accuracy_std=acc.std().item(),
        test_f1=t_f1.mean().item(),
        test_f1_std=t_f1.std().item(),
        duration=duration.mean().item(),
        optimizer=log_dict['optimizer'],
        epochs=log_dict['epochs'],
        runs=log_dict['runs'],
        lr=log_dict['lr'],
        k=log_dict['k'],
        residual=log_dict['residual'],
        forward_sampling=log_dict['forward_sampling'],
        backward_sampling=log_dict['backward_sampling'],
        num_layers=log_dict['num_layers'],
        lambda_rbm=log_dict['lambda_rbm'],
        loss_type=log_dict['loss_type'],
        md_file=log_dict['md_file'],
        csv_file=log_dict['csv_file']
    )

    print(f"accuracies: {acc.tolist()}")
    ut.log_run_results_to_file(
        output_path='results.md',
        dataset_name=dataset.name if hasattr(dataset, 'name') else 'unknown',
        split_type=split,  # e.g., 'public' or 'complete'
        val_losses=val_losses,
        accs=accs,
        durations=durations,
        str_optimizer=str_optimizer,
        str_preconditioner=str_preconditioner,
        lr=lr,
        weight_decay=weight_decay,
        epochs=epochs,
        runs=runs,
        early_stopping=early_stopping,
        momentum=momentum,
        eps=eps,
        update_freq=update_freq,
        gamma=gamma,
        alpha=alpha,
        hyperparam=hyperparam
    )

    # Visualize class distribution for the selected run.
    if visualize_classes and models_and_results:
        print("\n" + "=" * 50)
        print("GENERATING CLASS DISTRIBUTION VISUALIZATIONS")
        print("=" * 50)

        # NOTE(review): the run is chosen by *test* F1, so this is for inspection
        # only and must not be used for model selection.
        best_state_dict, best_metrics = max(models_and_results, key=lambda item: item[1]['test_f1'])

        print(f"Best model selected based on test F1: {best_metrics['test_f1']:.4f}")
        print(f"Best model test accuracy: {best_metrics['test_acc']:.4f}")

        model.load_state_dict(best_state_dict)
        model.eval()

        # `data` is dataset[0] from the last run; every run uses the same graph.
        ut.visualize_class_distribution(
            model=model,
            data=data,
            dataset_name=dataset.name if hasattr(dataset, 'name') else 'Unknown',
            split_type=split,
            class_names=class_names,
            save_dir=save_dir
        )


def visualize_class_distribution_backup(
        model,
        data,
        dataset_name,
        split_type,
        save_dir='vis',
        class_names=None,
        show_predictions=True,
        use_tsne=True,
        use_pca=True,
        perplexity=30,
        random_state=42
):
    """
    Visualize class distribution for the best model including:
    1. True (and optionally predicted) class distribution per split
    2. Pie charts of the true class distribution per split
    3. Confusion matrix and per-class precision/recall/F1 (test split)
    4. Class imbalance overview
    5. Clustering-style scatter plots (t-SNE/PCA) with class coloring

    Every figure is written as PNG and PDF to
    ``{save_dir}/{dataset_name}-{split_type}-<name>.{png,pdf}``.

    Args:
        model: Trained PyTorch model; called as ``model(data)`` and its output
            argmax'ed over the last dimension.
        data: Graph data object with ``y`` and, optionally, ``x`` and
            ``train_mask``/``val_mask``/``test_mask``. 2-D (multi-split) masks
            are reduced to their first column.
        dataset_name: Name of the dataset (used in titles and file names).
        split_type: Type of split ('public', 'complete', etc.).
        save_dir: Directory to save visualizations (created if missing).
        class_names: Optional list of class names; padded with ``'Class i'``
            when shorter than the number of classes.
        show_predictions: Whether to show prediction distributions.
        use_tsne: Whether to create t-SNE visualizations.
        use_pca: Whether to create PCA visualizations.
        perplexity: Upper bound on the t-SNE perplexity.
        random_state: Seed for t-SNE/PCA and the t-SNE subsampling.

    Returns:
        tuple: ``(report_df, clustering_results)``. ``report_df`` is the
        classification report as a DataFrame, or None when no test split is
        available. ``clustering_results[split_name][method]`` is
        ``(coords_2d, accuracy)`` for ``method`` in ``{'tsne', 'pca'}``.
    """
    from sklearn.manifold import TSNE
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler

    os.makedirs(save_dir, exist_ok=True)
    file_prefix = f'{save_dir}/{dataset_name}-{split_type}'

    def _save_current_figure(stem):
        """Save the current figure as ``<file_prefix>-<stem>.png/.pdf`` and close it."""
        for ext in ('png', 'pdf'):
            plt.savefig(f'{file_prefix}-{stem}.{ext}', dpi=300, bbox_inches='tight')
        plt.close()

    def _annotate_bars(ax, bars, label_for, fontsize=9):
        """Write ``label_for(height)`` just above every bar in ``bars``."""
        for bar in bars:
            height = bar.get_height()
            ax.annotate(label_for(height),
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=fontsize)

    def _to_1d_mask(mask):
        """Reduce a multi-split ``[N, S]`` mask to its first split; 1-D masks pass through."""
        if mask.dim() > 1:
            return mask[:, 0].to(torch.bool)
        return mask

    model.eval()
    with torch.no_grad():
        out = model(data)
        pred = out.argmax(dim=-1)

        # Features for the clustering plots: prefer model embeddings, then raw
        # node features, then the model output itself.
        if hasattr(model, 'get_embeddings'):
            features = model.get_embeddings(data).cpu().numpy()
        elif hasattr(model, 'features'):
            features = model.features(data).cpu().numpy()
        elif getattr(data, 'x', None) is not None:
            features = data.x.cpu().numpy()
        else:
            features = out.cpu().numpy()

    # split name -> (true labels, predictions, features), all numpy arrays.
    if all(hasattr(data, attr) for attr in ('train_mask', 'val_mask', 'test_mask')):
        splits_data = {}
        for split_name, attr in (('Train', 'train_mask'), ('Validation', 'val_mask'), ('Test', 'test_mask')):
            mask = _to_1d_mask(getattr(data, attr))
            splits_data[split_name] = (
                data.y[mask].cpu().numpy(),
                pred[mask].cpu().numpy(),
                features[mask.cpu().numpy()],
            )
    else:
        splits_data = {'All Data': (data.y.cpu().numpy(), pred.cpu().numpy(), features)}

    # Size the class axis by the largest class id among labels *and* predictions
    # (or by the supplied names). Every id is used as an index into `colors` and
    # `class_names`, so counting unique labels breaks when an id is missing from
    # the labels or only appears among the predictions.
    observed = np.concatenate([np.concatenate([labels, preds]) for labels, preds, _ in splits_data.values()])
    n_classes = int(observed.max()) + 1 if observed.size else 0
    if class_names is not None:
        n_classes = max(n_classes, len(class_names))
    if n_classes <= 0:
        print("No labelled samples available; skipping class distribution visualizations.")
        return None, {}

    class_names = list(class_names) if class_names is not None else []
    class_names += [f'Class {i}' for i in range(len(class_names), n_classes)]
    class_ids = np.arange(n_classes)

    # One colour per class id.
    colors = plt.cm.Set3(np.linspace(0, 1, n_classes))

    # Global matplotlib style (persists for the rest of the process).
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.size': 10,
        'axes.labelsize': 12,
        'axes.titlesize': 14,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10
    })

    # 1. Overall class distribution comparison (one panel per split).
    fig, axes = plt.subplots(2, 3, figsize=(15, 12))
    fig.suptitle(f'Class Distribution Analysis - {dataset_name} ({split_type})',
                 fontsize=16, fontweight='bold')
    panel_axes = axes.ravel()
    width = 0.35

    for ax, (split_name, (labels, preds, _)) in zip(panel_axes, splits_data.items()):
        label_counts = Counter(labels)
        true_bars = ax.bar(class_ids - width / 2,
                           [label_counts.get(i, 0) for i in range(n_classes)], width,
                           label='True Labels', alpha=0.8, color=colors)
        _annotate_bars(ax, true_bars, lambda h: f'{int(h)}')

        if show_predictions and len(preds) > 0:
            pred_counts = Counter(preds)
            pred_bars = ax.bar(class_ids + width / 2,
                               [pred_counts.get(i, 0) for i in range(n_classes)], width,
                               label='Predictions', alpha=0.8, color=colors, hatch='///')
            _annotate_bars(ax, pred_bars, lambda h: f'{int(h)}')

        ax.set_title(f'{split_name} Set Distribution', fontweight='bold')
        ax.set_xlabel('Class')
        ax.set_ylabel('Count')
        ax.set_xticks(class_ids)
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Hide every grid cell that received no split (the grid has 6 cells).
    for ax in panel_axes[len(splits_data):]:
        ax.set_visible(False)

    plt.tight_layout()
    _save_current_figure('class_distribution_comparison')

    # Clustering-style visualization with dimensionality reduction.
    def _reduce_to_2d(split_features, labels, preds, method_name):
        """Standardize features and project them to 2-D with t-SNE or PCA.

        For t-SNE, inputs with more than 1000 points are subsampled (seeded by
        ``random_state``) to keep runtime manageable. Returns the 2-D
        coordinates together with the labels/predictions aligned to them.
        """
        scaled = StandardScaler().fit_transform(split_features)
        if method_name == 't-SNE':
            if len(scaled) > 1000:
                keep = np.random.default_rng(random_state).choice(len(scaled), 1000, replace=False)
                scaled, labels, preds = scaled[keep], labels[keep], preds[keep]
            reducer = TSNE(n_components=2, perplexity=min(perplexity, len(scaled) // 4),
                           random_state=random_state, max_iter=1000)
        else:  # PCA
            reducer = PCA(n_components=2, random_state=random_state)
        return reducer.fit_transform(scaled), labels, preds

    def _class_legend_handles():
        """Legend handles with one coloured dot per class."""
        return [plt.Line2D([0], [0], marker='o', color='w',
                           markerfacecolor=colors[i], markersize=10,
                           label=class_names[i]) for i in range(n_classes)]

    def _plot_true_vs_pred(coords, labels, preds, method_name, split_name):
        """Side-by-side scatter coloured by true vs. predicted class; returns accuracy."""
        fig, axes = plt.subplots(1, 2, figsize=(16, 7))
        legend_elements = _class_legend_handles()
        for ax, values, kind in ((axes[0], labels, 'True'), (axes[1], preds, 'Predicted')):
            ax.scatter(coords[:, 0], coords[:, 1],
                       c=[colors[v] for v in values],
                       alpha=0.7, s=50, edgecolors='black', linewidth=0.5)
            ax.set_title(f'{kind} Labels - {method_name} ({split_name})', fontweight='bold')
            ax.set_xlabel(f'{method_name} Component 1')
            ax.set_ylabel(f'{method_name} Component 2')
            ax.grid(True, alpha=0.3)
            ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.05, 1))

        accuracy = np.mean(labels == preds)
        fig.suptitle(f'{method_name} Clustering Visualization - {split_name} Set\n'
                     f'Accuracy: {accuracy:.3f}', fontsize=16, fontweight='bold')
        plt.tight_layout()

        method_str = method_name.lower().replace('-', '')
        _save_current_figure(f'{method_str}_clustering_{split_name.lower()}')
        return accuracy

    def _plot_prediction_accuracy(coords, labels, preds, method_name, split_name):
        """Scatter coloured by true class, marking correct (o) vs. incorrect (X) predictions."""
        correct_mask = labels == preds

        fig, ax = plt.subplots(1, 1, figsize=(12, 9))
        for subset, edge, lw, marker, label in (
                (correct_mask, 'black', 0.5, 'o', 'Correct Predictions'),
                (~correct_mask, 'red', 2, 'X', 'Incorrect Predictions')):
            if np.any(subset):
                ax.scatter(coords[subset, 0], coords[subset, 1],
                           c=[colors[v] for v in labels[subset]],
                           alpha=0.8, s=60, edgecolors=edge, linewidth=lw,
                           marker=marker, label=label)

        ax.set_title(f'{method_name} Clustering: Prediction Accuracy Visualization ({split_name})',
                     fontweight='bold')
        ax.set_xlabel(f'{method_name} Component 1')
        ax.set_ylabel(f'{method_name} Component 2')
        ax.grid(True, alpha=0.3)

        legend_elements = _class_legend_handles()
        legend_elements.append(plt.Line2D([0], [0], marker='o', color='w',
                                          markerfacecolor='gray', markeredgecolor='black',
                                          markersize=10, label='Correct'))
        legend_elements.append(plt.Line2D([0], [0], marker='X', color='w',
                                          markerfacecolor='gray', markeredgecolor='red',
                                          markersize=10, label='Incorrect'))
        ax.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1.05, 0.5))

        accuracy = np.mean(correct_mask)
        plt.figtext(0.02, 0.02, f'Overall Accuracy: {accuracy:.3f} '
                                f'({np.sum(correct_mask)}/{len(labels)} correct)',
                    fontsize=12, fontweight='bold')
        plt.tight_layout()

        # Include the split in the file name; otherwise each split overwrites the previous one.
        method_str = method_name.lower().replace('-', '')
        _save_current_figure(f'{method_str}_prediction_accuracy_{split_name.lower()}')

    # Each (split, method) projection is computed once and shared by both plots,
    # so they show the same subsample and embedding.
    clustering_results = {}
    for split_name, (labels, preds, split_features) in splits_data.items():
        if len(split_features) < 10:  # Skip if too few samples
            continue

        clustering_results[split_name] = {}
        for enabled, method_name, key in ((use_tsne, 't-SNE', 'tsne'), (use_pca, 'PCA', 'pca')):
            if not enabled:
                continue
            print(f"Generating {method_name} visualization for {split_name} set...")
            coords, sub_labels, sub_preds = _reduce_to_2d(split_features, labels, preds, method_name)
            accuracy = _plot_true_vs_pred(coords, sub_labels, sub_preds, method_name, split_name)
            _plot_prediction_accuracy(coords, sub_labels, sub_preds, method_name, split_name)
            clustering_results[split_name][key] = (coords, accuracy)

    # 2. Pie charts of the true class distribution.
    n_splits = len(splits_data)
    fig, axes = plt.subplots(1, n_splits, figsize=(5 * n_splits, 5))
    if n_splits == 1:
        axes = [axes]

    fig.suptitle(f'Class Distribution Pie Charts - {dataset_name} ({split_type})',
                 fontsize=16, fontweight='bold')

    for ax, (split_name, (labels, _, _)) in zip(axes, splits_data.items()):
        label_counts = Counter(labels)
        counts = [label_counts.get(i, 0) for i in range(n_classes)]
        ax.set_title(f'{split_name} Set', fontweight='bold')

        if sum(counts) == 0:
            # Empty split: there are no wedges to draw.
            ax.axis('off')
            continue

        _, _, autotexts = ax.pie(counts, labels=class_names, autopct='%1.1f%%',
                                 colors=colors, startangle=90)
        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontweight('bold')

    plt.tight_layout()
    _save_current_figure('class_distribution_pie_charts')

    # 3. Confusion matrix and classification report (test split, or the single split).
    report_df = None
    if 'Test' in splits_data or len(splits_data) == 1:
        test_labels, test_preds, _ = splits_data.get('Test', next(iter(splits_data.values())))
        label_ids = list(range(n_classes))

        # Fixing `labels` keeps the matrix/report aligned with `class_names` even
        # when some classes are absent from this split.
        cm = confusion_matrix(test_labels, test_preds, labels=label_ids)

        plt.figure(figsize=(10, 8))
        sns.heatmap(cm,
                    annot=True,  # show values
                    fmt='d',  # integer format
                    cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names,
                    annot_kws={
                        "size": 12,  # font size
                        "weight": "bold",  # font weight
                        "color": "black"  # text color
                    })
        plt.title(f'Confusion Matrix - {dataset_name} ({split_type})',
                  fontweight='bold', pad=20)
        plt.xlabel('Predicted Class')
        plt.ylabel('True Class')
        plt.tight_layout()
        _save_current_figure('confusion_matrix')

        report = classification_report(test_labels, test_preds,
                                       labels=label_ids,
                                       target_names=class_names,
                                       output_dict=True,
                                       zero_division=0)
        report_df = pd.DataFrame(report).transpose()
        report_df.to_csv(f'{file_prefix}-classification_report.csv')

        # Per-class rows come first (in `labels` order); aggregate rows follow.
        class_metrics = report_df.iloc[:n_classes]

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        for ax, metric in zip(axes, ['precision', 'recall', 'f1-score']):
            bars = ax.bar(class_ids, class_metrics[metric], color=colors, alpha=0.8)
            _annotate_bars(ax, bars, lambda h: f'{h:.3f}')

            ax.set_title(f'{metric.capitalize()} by Class', fontweight='bold')
            ax.set_xlabel('Class')
            ax.set_ylabel(metric.capitalize())
            ax.set_xticks(class_ids)
            ax.set_xticklabels(class_names, rotation=45, ha='right')
            ax.set_ylim(0, 1.1)
            ax.grid(True, alpha=0.3)

        plt.suptitle(f'Classification Metrics by Class - {dataset_name} ({split_type})',
                     fontsize=16, fontweight='bold')
        plt.tight_layout()
        _save_current_figure('classification_metrics')

    # 4. Class imbalance analysis over all splits combined.
    plt.figure(figsize=(12, 6))

    all_labels = np.concatenate([labels for labels, _, _ in splits_data.values()])
    label_counts = Counter(all_labels)
    total_samples = max(len(all_labels), 1)  # avoid division by zero for empty label sets
    class_ratios = [label_counts.get(i, 0) / total_samples for i in range(n_classes)]

    bars = plt.bar(class_ids, class_ratios, color=colors, alpha=0.8)
    _annotate_bars(plt.gca(), bars, lambda h: f'{h * 100:.1f}%', fontsize=10)

    plt.title(f'Class Imbalance Analysis - {dataset_name} ({split_type})',
              fontweight='bold', pad=20)
    plt.xlabel('Class')
    plt.ylabel('Proportion of Total Samples')
    plt.xticks(class_ids, class_names, rotation=45, ha='right')

    # Reference line for a perfectly balanced distribution.
    balanced_ratio = 1.0 / n_classes
    plt.axhline(y=balanced_ratio, color='red', linestyle='--', alpha=0.7,
                label=f'Balanced ({balanced_ratio * 100:.1f}%)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    _save_current_figure('class_imbalance_analysis')

    print(f"\nClass distribution visualizations saved in '{save_dir}' directory:")
    for produced in ("class_distribution_comparison.png/pdf",
                     "class_distribution_pie_charts.png/pdf",
                     "confusion_matrix.png/pdf",
                     "classification_metrics.png/pdf",
                     "class_imbalance_analysis.png/pdf",
                     "classification_report.csv"):
        print(f"  - {produced}")
    if use_tsne or use_pca:
        print("\nClustering-style visualizations:")
        for split_name in clustering_results:
            for enabled, method_str in ((use_tsne, 'tsne'), (use_pca, 'pca')):
                if enabled:
                    print(f"  - {method_str}_clustering_{split_name.lower()}.png/pdf")
                    print(f"  - {method_str}_prediction_accuracy_{split_name.lower()}.png/pdf")

    return report_df, clustering_results


def train(model, optimizer, data, train_mask, preconditioner=None, lam=0.5, lambda_rbm=0.0,
          rbm_loss_type="cd"):
    """Perform a single optimisation step on the nodes selected by ``train_mask``.

    The objective is ``(1 - lambda_rbm) * NLL + lambda_rbm * rbm_loss``. The RBM term
    is only requested from the model when ``lambda_rbm`` is non-zero.

    Side effect: ``data.y`` is rewritten as a flat LongTensor. A 2-D label matrix with
    several columns is collapsed with ``argmax``; a single-column matrix is flattened.

    Args:
        model: callable as ``model(data, return_loss=...)``. With ``return_loss=True``
            it must return ``(log_probs, rbm_loss)``; otherwise just ``log_probs``
            (consumed by ``F.nll_loss``, so presumably log-probabilities).
        optimizer: torch optimizer stepped once.
        data: object exposing ``y`` labels; passed to ``model`` unchanged.
        train_mask: 1-D node mask, or 2-D ``[N, num_splits]`` mask (first split used).
        preconditioner: optional object whose ``step()`` runs after ``backward()``
            and before ``optimizer.step()``.
        lam: unused by this function; kept for call compatibility.
        lambda_rbm: weight of the RBM term in the combined loss.
        rbm_loss_type: forwarded to the model as ``loss_type``.

    Returns:
        dict with Python floats ``'total_loss'``, ``'classification_loss'`` and
        ``'rbm_loss'``.
    """
    model.train()
    optimizer.zero_grad()

    needs_rbm = lambda_rbm != 0.0
    if needs_rbm:
        log_probs, rbm_term = model(data, return_loss=True, loss_type=rbm_loss_type)
    else:
        log_probs = model(data, return_loss=False)
        rbm_term = 0.0

    # Collapse one-hot / column labels to class indices, then force a flat
    # LongTensor, which is what nll_loss expects.
    labels = data.y
    if labels.dim() == 2:
        labels = labels.argmax(dim=1) if labels.size(1) > 1 else labels.view(-1)
    data.y = labels.long().view(-1)

    # A 2-D mask stores several splits side by side; only the first is trained on.
    selected = train_mask if train_mask.dim() == 1 else train_mask[:, 0]
    classification_loss = F.nll_loss(log_probs[selected], data.y[selected])

    total_loss = (1 - lambda_rbm) * classification_loss + lambda_rbm * rbm_term
    total_loss.backward()

    if preconditioner is not None:
        preconditioner.step()
    optimizer.step()

    rbm_value = rbm_term.item() if isinstance(rbm_term, torch.Tensor) else rbm_term
    return {
        'total_loss': total_loss.item(),
        'classification_loss': classification_loss.item(),
        'rbm_loss': rbm_value,
    }


def evaluate(model, data, mask):
    """Compute accuracy and F1 score for every named node mask.

    Args:
        model: called as ``model(data, return_loss=False)``. Its output is indexed per
            node and arg-maxed over dim 1, so it is presumably ``[N, num_classes]``
            class scores (TODO confirm against the model implementations).
        data: object exposing ``y`` labels. ``y`` may be a 1-D class-index tensor, an
            ``[N, 1]`` column or an ``[N, C]`` one-hot matrix. It is normalised
            locally (same rule as ``train``) and ``data`` itself is NOT modified.
        mask: dict mapping names such as ``'val_mask'`` to node masks. A 2-D
            ``[N, num_splits]`` mask uses its first split only.

    Returns:
        dict with ``'<name>_acc'`` and ``'<name>_f1'`` for each mask, where ``<name>``
        is the key with ``"_mask"`` removed. F1 is binary (positive class 1) when the
        model has exactly two output classes, and macro-averaged otherwise.
    """
    model.eval()
    with torch.no_grad():
        # Don't compute RBM loss during evaluation for efficiency
        logits = model(data, return_loss=False)

        # Flatten one-hot / column labels to class indices. Without this a [N, 1]
        # label tensor would broadcast against the 1-D predictions in eq().
        labels = data.y
        if labels.dim() == 2:
            labels = labels.argmax(dim=1) if labels.size(1) > 1 else labels.view(-1)
        labels = labels.long().view(-1)

        # Binary vs. macro is a property of the task (model output width), not of
        # which labels happen to appear inside a given mask. Deciding per mask makes
        # sklearn raise when a multiclass subset has 2 true labels but >2 predicted,
        # or 2 labels other than {0, 1}.
        average = 'binary' if logits.size(1) == 2 else 'macro'

        results = {}
        for key, node_mask in mask.items():
            if node_mask.dim() > 1:
                # 2D mask [N, num_splits] - use first split
                node_mask = node_mask[:, 0]

            pred = logits[node_mask].argmax(dim=1)
            target = labels[node_mask]
            acc = pred.eq(target).float().mean().item()
            f1 = f1_score(target.cpu().numpy(), pred.cpu().numpy(), average=average)

            name = key.replace("_mask", "")
            results[f'{name}_acc'] = acc
            results[f'{name}_f1'] = f1

    return results

import pandas as pd
import os
from datetime import datetime


def save_results(method, dataset, split, val_loss, val_f1, val_f1_std,
                 test_accuracy, test_accuracy_std, test_f1, test_f1_std,
                 duration, optimizer, epochs, runs, lr, k, residual,
                 forward_sampling, backward_sampling, num_layers,
                 lambda_rbm, loss_type, md_file, csv_file):
    """
    Save node classification results to Markdown and CSV files, appending for multiple runs.

    The CSV file is rewritten as (existing rows + one new row), so earlier rows survive
    even if the column set changes. The Markdown file gets a header on first creation,
    then one table row per call. Missing parent directories of either file are created.
    Both files are written as UTF-8 because the content contains non-ASCII
    characters ("λ", "±").

    Args:
        method (str): Method name (e.g., 'RBMNet').
        dataset (str): Dataset name (e.g., 'Cora').
        split (str): Split type (e.g., 'public').
        val_loss (float): Validation loss.
        val_f1 (float): Validation F1-score mean (fraction; stored as %).
        val_f1_std (float): Validation F1-score standard deviation (fraction).
        test_accuracy (float): Test accuracy mean (fraction).
        test_accuracy_std (float): Test accuracy standard deviation (fraction).
        test_f1 (float): Test F1-score mean (fraction).
        test_f1_std (float): Test F1-score standard deviation (fraction).
        duration (float): Duration in seconds.
        optimizer (str): Optimizer name.
        epochs (int): Number of training epochs.
        runs (int): Number of runs.
        lr (float): Learning rate.
        k (int): Number of CD steps.
        residual (float): Residual connection weight.
        forward_sampling (str): Forward sampling method.
        backward_sampling (str): Backward sampling method.
        num_layers (int): Number of layers.
        lambda_rbm (float): RBM loss weight.
        loss_type (str): Type of loss function.
        md_file (str): Path to Markdown file.
        csv_file (str): Path to CSV file.
    """
    CYAN = '\033[96m'
    RESET = '\033[0m'

    # Make sure the destination directories exist (dirname is '' for bare names).
    for path in (csv_file, md_file):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    # Prepare data for saving
    data = {
        'Timestamp': [datetime.now().strftime('%Y-%m-%d %H:%M:%S')],
        'Method': [method],
        'Dataset': [dataset],
        'Split': [split],
        'Optimizer': [optimizer],
        'Epochs': [epochs],
        'Runs': [runs],
        'LR': [lr],
        'K-Steps': [k],
        'Residual': [residual],
        'Forward Sampling': [forward_sampling],
        'Backward Sampling': [backward_sampling],
        'Num Layers': [num_layers],
        'Lambda RBM': [lambda_rbm],
        'Loss Type': [loss_type],
        'Val Loss': [val_loss],
        'Val F1 (%)': [100 * val_f1],
        'Val F1 Std (%)': [100 * val_f1_std],
        'Test Accuracy (%)': [100 * test_accuracy],
        'Test Accuracy Std (%)': [100 * test_accuracy_std],
        'Test F1 (%)': [100 * test_f1],
        'Test F1 Std (%)': [100 * test_f1_std],
        'Duration (sec)': [duration]
    }
    df = pd.DataFrame(data)

    # Append to CSV file; concat (rather than mode='a') tolerates column changes.
    if os.path.exists(csv_file):
        existing_df = pd.read_csv(csv_file)
        df = pd.concat([existing_df, df], ignore_index=True)
    df.to_csv(csv_file, index=False)

    # Markdown row; the inline HTML spans colour the metrics in renderers that allow it.
    md_row = (
        f"| {method} | {dataset} | {split} | {optimizer} | {epochs} | {runs} | "
        f"{lr} | {k} | {residual} | {forward_sampling} | {backward_sampling} | "
        f"{num_layers} | {lambda_rbm} | {loss_type} | "
        f"<span style='color:blue'>{val_loss:.4f}</span> | "
        f"<span style='color:green'>{100 * val_f1:.2f} ± {100 * val_f1_std:.2f}</span> | "
        f"<span style='color:orange'>{100 * test_accuracy:.2f} ± {100 * test_accuracy_std:.2f}</span> | "
        f"<span style='color:red'>{100 * test_f1:.2f} ± {100 * test_f1_std:.2f}</span> | "
        f"<span style='color:purple'>{duration:.3f}</span> |\n"
    )

    md_header = """# GBN Node Classification Results  

| Method | Dataset | Split | Optimizer | Epochs | Runs | LR | K | Residual | Forward Sampling | Backward Sampling | Layers | λ_RBM | Loss Type | Val Loss | Val F1 (%) | Test Accuracy (%) | Test F1 (%) | Duration (sec) |
|--------|---------|-------|-----------|--------|------|----|----|----------|------------------|-------------------|--------|-------|-----------|----------|------------|-------------------|-------------|----------------|
"""

    # Explicit UTF-8: "λ" is not encodable in e.g. cp1252, the default on Windows.
    is_new_md = not os.path.exists(md_file)
    with open(md_file, 'a', encoding='utf-8') as f:
        if is_new_md:
            f.write(md_header)
        f.write(md_row)

    print(f"{CYAN}✅ Results saved to {csv_file} and {md_file}{RESET}")