import os
import random
import json
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import shutil
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from collections import defaultdict, Counter
from torch_geometric.utils import degree
from torch_geometric.transforms import OneHotDegree
import seaborn as sns
from sklearn.manifold import TSNE
from torch_geometric.utils import to_undirected

# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed every RNG this script draws from so dataset generation is repeatable.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (plus all
    CUDA devices when CUDA is available), then exports ``PYTHONHASHSEED``.
    Setting that environment variable only influences interpreters started
    afterwards (e.g. subprocesses); the running process' hash seed is fixed
    at startup.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


set_seed(42)

class SpectralGapDataset(InMemoryDataset):
    """
    A PyTorch Geometric dataset class for learning to predict
    the spectral gap of graphs generated by different models.

    Spectral gap = 2nd smallest eigenvalue of the normalized Laplacian matrix,
    which measures graph connectivity.

    Every graph is labelled (``data.y``) with its spectral gap. Per-graph
    generation parameters are written to ``processed/metadata_<mode>.json``
    (entry ``index`` = position of the graph in the dataset).

    Args:
        root: Dataset root directory (PyG manages ``raw/`` and ``processed/``).
        mode: Split name; selects the default sizes and the processed file name.
        transform, pre_transform, pre_filter: Standard PyG hooks.
        use_degree_features: If True, one-hot node degrees (0..150) are
            concatenated to the constant node feature during processing.
        min_nodes, max_nodes: Inclusive range of node counts drawn per graph.
        samples_per_class: Total number of graphs generated for the split.
            Any of these three left as ``None`` takes the per-mode default.

    Note:
        ``use_degree_features`` and the size settings only matter when the
        processed file does not exist yet; an existing file is loaded as-is.
    """

    # Per-mode defaults: (min_nodes, max_nodes, samples_per_class).
    # Unknown modes use the 'test' entry.
    _MODE_DEFAULTS = {
        'train': (20, 50, 3000),
        'val': (20, 50, 300),
        'test': (20, 50, 50),
    }
    # Largest degree encoded by the one-hot degree features. Fixed so every
    # split has the same feature dimension; a node with a higher degree makes
    # OneHotDegree fail, so keep max_nodes - 1 <= this value.
    _MAX_DEGREE = 150

    def __init__(self, root, mode='train', transform=None, pre_transform=None, pre_filter=None,
                 use_degree_features=True, min_nodes=None, max_nodes=None, samples_per_class=None):
        self.mode = mode
        self.use_degree_features = use_degree_features

        # Only arguments left as None are defaulted; explicitly passed values
        # are always honoured.
        default_min, default_max, default_samples = self._MODE_DEFAULTS.get(
            mode, self._MODE_DEFAULTS['test'])
        self.min_nodes = default_min if min_nodes is None else min_nodes
        self.max_nodes = default_max if max_nodes is None else max_nodes
        self.samples_per_class = default_samples if samples_per_class is None else samples_per_class

        # Types of underlying graph generators
        self.generator_types = [
            "sbm_dumbbell",     # Dumbbell SBM (2 communities)
            "sbm_multi",        # Multi-community SBM (3-5 communities)
            "geometric_random", # Random geometric graph
            "modified_config"   # Graphs modified via Configuration Model
        ]

        # The base constructor runs download()/process() when files are
        # missing; process() needs the attributes assigned above.
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = self._load_processed(self.processed_paths[0])

    @staticmethod
    def _load_processed(path):
        """Load the ``(data, slices)`` tuple written by ``process()``.

        PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
        refuses to unpickle PyG ``Data`` objects, so full unpickling is
        requested explicitly. The file is produced locally by ``process()``;
        do not point this at untrusted files.
        """
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch < 1.13 has no `weights_only` argument.
            return torch.load(path)

    @property
    def raw_file_names(self):
        # Placeholder raw file; the graphs are generated, not downloaded.
        return [f'dummy_{self.mode}.txt']

    @property
    def processed_file_names(self):
        return [f'spectral_gap_{self.mode}.pt']

    def download(self):
        """Create the dummy raw file so PyG considers the raw data present."""
        os.makedirs(self.raw_dir, exist_ok=True)
        with open(os.path.join(self.raw_dir, f'dummy_{self.mode}.txt'), 'w') as f:
            f.write('dummy')

    def compute_normalized_spectral_gap(self, G):
        """Return the 2nd smallest eigenvalue of the normalized Laplacian of ``G``.

        For a connected graph this is the normalized spectral gap (it is 0
        for a disconnected graph). The sparse ARPACK solver is tried first;
        a dense eigendecomposition is used when ARPACK fails or is not
        applicable (e.g. a 2-node graph, where k=2 is not < N).

        Raises:
            ValueError: if ``G`` has fewer than two nodes.
        """
        from scipy.sparse.linalg import ArpackError, eigsh

        if G.number_of_nodes() < 2:
            raise ValueError("Spectral gap is undefined for graphs with fewer than 2 nodes")

        L_norm = nx.normalized_laplacian_matrix(G)
        try:
            # 'SM' = smallest magnitude; the normalized Laplacian is positive
            # semi-definite, so these are also its smallest eigenvalues.
            eigenvalues, _ = eigsh(L_norm, k=2, which='SM')
        except (ArpackError, ValueError, TypeError):
            # ArpackError includes ArpackNoConvergence; scipy raises TypeError
            # for sparse input when k >= N.
            eigenvalues = np.linalg.eigvalsh(L_norm.toarray())
        return float(np.sort(eigenvalues)[1])

    def process(self):
        """Generate all graphs of this split and write the processed files.

        Generator mix (fractions of ``samples_per_class``): ~60% SBM, split
        evenly between dumbbell and multi-community; 20% random geometric;
        the remainder configuration-model rewirings of those base graphs, so
        the split always contains exactly ``samples_per_class`` graphs.
        Writes ``metadata_<mode>.json`` and the collated ``(data, slices)``.
        """
        data_list = []
        metadata = {
            'dataset_info': {
                'mode': self.mode,
                'min_nodes': self.min_nodes,
                'max_nodes': self.max_nodes,
                'samples_per_class': self.samples_per_class,
            },
            'graphs': []  # one generation-info entry per graph
        }

        # 1. SBM-based graphs (~60%), evenly split between the two structures.
        sbm_structures = ['dumbbell', 'multi']
        samples_per_structure = int(self.samples_per_class * 0.6) // len(sbm_structures)

        for structure in sbm_structures:
            for i in tqdm(range(samples_per_structure), desc=f"Generating SBM {structure}"):
                n_nodes = random.randint(self.min_nodes, self.max_nodes)

                # Stratify mix_factor (inter-block connectivity): first 40% low,
                # next 30% medium, the rest high.
                if i < samples_per_structure * 0.4:
                    mix_factor = random.uniform(0, 0.2)
                elif i < samples_per_structure * 0.7:
                    mix_factor = random.uniform(0.2, 0.5)
                else:
                    mix_factor = random.uniform(0.5, 0.8)

                G = self.generate_sbm_evolution(n_nodes, mix_factor, structure)
                gen_info = {
                    'generator': f'sbm_{structure}',
                    'nodes': n_nodes,
                    'mix_factor': mix_factor
                }
                self._add_graph_to_data_list(G, data_list, f'sbm_{structure}', gen_info, metadata)
        sbm_samples = samples_per_structure * len(sbm_structures)

        # 2. Geometric random graphs (20%).
        geo_samples = int(self.samples_per_class * 0.2)
        for _ in tqdm(range(geo_samples), desc="Generating Geometric"):
            n_nodes = random.randint(self.min_nodes, self.max_nodes)
            mix_factor = random.uniform(0, 1)
            G = self.generate_geometric_evolution(n_nodes, mix_factor)
            gen_info = {
                'generator': 'geometric_random',
                'nodes': n_nodes,
                'mix_factor': mix_factor
            }
            self._add_graph_to_data_list(G, data_list, 'geometric_random', gen_info, metadata)

        # 3. Config-model-based graphs: whatever remains. Computed from the SBM
        #    graphs actually generated so an odd SBM share loses no sample.
        config_samples = self.samples_per_class - sbm_samples - geo_samples
        for _ in tqdm(range(config_samples), desc="Generating Config Model"):
            n_nodes = random.randint(self.min_nodes, self.max_nodes)
            base_type = random.choice(['sbm_dumbbell', 'sbm_multi', 'geometric_random'])
            mix_factor = random.uniform(0, 1)

            if base_type.startswith('sbm'):
                structure = base_type.split('_')[1]
                G = self.generate_sbm_evolution(n_nodes, mix_factor, structure)
            else:
                G = self.generate_geometric_evolution(n_nodes, mix_factor)

            randomization_level = random.uniform(0.3, 0.8)
            G = self.apply_configuration_model(G, randomization_level)

            gen_info = {
                'generator': 'modified_config',
                'base_generator': base_type,
                'nodes': n_nodes,
                'base_mix_factor': mix_factor,
                'randomization_level': randomization_level
            }
            self._add_graph_to_data_list(G, data_list, 'modified_config', gen_info, metadata)

        # Optional filtering; metadata is filtered in lockstep and re-indexed
        # so each entry's 'index' still points at its stored graph.
        if self.pre_filter is not None:
            kept = [i for i, data in enumerate(data_list) if self.pre_filter(data)]
            data_list = [data_list[i] for i in kept]
            metadata['graphs'] = [
                {**metadata['graphs'][old_idx], 'index': new_idx}
                for new_idx, old_idx in enumerate(kept)
            ]

        # Save metadata for later grouping / visualisation.
        metadata['dataset_stats'] = {'total_graphs': len(data_list)}
        metadata_path = os.path.join(self.processed_dir, f'metadata_{self.mode}.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        if self.use_degree_features:
            degree_transform = OneHotDegree(self._MAX_DEGREE)
            data_list = [degree_transform(data) for data in data_list]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _add_graph_to_data_list(self, G, data_list, generator_type, generation_info, metadata):
        """Convert ``G`` to a PyG ``Data`` labelled with its spectral gap and append it.

        Only the largest connected component is kept, nodes are relabelled
        0..n-1 and self-loops are removed before the spectral gap is computed.
        A matching entry (index, generator type, spectral gap plus
        ``generation_info``) is appended to ``metadata['graphs']``.

        Raises:
            ValueError: if ``generator_type`` is not in ``self.generator_types``.
        """
        if generator_type not in self.generator_types:
            raise ValueError(f"Unknown generator type: {generator_type!r}")

        if not nx.is_connected(G):
            largest_cc = max(nx.connected_components(G), key=len)
            G = G.subgraph(largest_cc).copy()

        G = nx.convert_node_labels_to_integers(G)
        # Materialise the self-loop list before mutating the graph it reads.
        G.remove_edges_from(list(nx.selfloop_edges(G)))

        num_nodes = G.number_of_nodes()
        spectral_gap = self.compute_normalized_spectral_gap(G)

        # reshape(-1, 2) keeps a valid (2, 0) edge_index for an edgeless graph.
        edge_index = torch.tensor(list(G.edges()), dtype=torch.long).reshape(-1, 2).t().contiguous()
        edge_index = to_undirected(edge_index, num_nodes=num_nodes)

        data = Data(
            x=torch.ones(num_nodes, 1),  # constant feature; degree one-hots may be appended in process()
            edge_index=edge_index,
            y=torch.tensor([spectral_gap], dtype=torch.float),
            num_nodes=num_nodes,  # plain int, as documented for torch_geometric Data
        )

        metadata['graphs'].append({
            'index': len(data_list),
            'generator_type': generator_type,
            'spectral_gap': spectral_gap,
            **generation_info
        })
        data_list.append(data)

    def generate_sbm_evolution(self, n_nodes, mix_factor=0.0, structure_type='multi'):
        """Generate a connected stochastic block model graph with ``n_nodes`` nodes.

        ``structure_type='dumbbell'`` uses two near-equal blocks; anything else
        uses ``n_nodes // 15`` blocks clamped to [3, 5], the last block taking
        the remainder. The within-block probability is drawn from [0.6, 0.8];
        ``mix_factor`` picks the band of the between-block probability
        (thresholds 0.1 / 0.3 / 0.7). If the sample is disconnected,
        consecutive components are joined by one random edge each.
        """
        if structure_type == 'dumbbell':
            num_blocks = 2
            sizes = [n_nodes // 2, n_nodes - n_nodes // 2]
        else:  # 'multi'
            num_blocks = max(3, min(5, n_nodes // 15))
            sizes = [n_nodes // num_blocks] * (num_blocks - 1)
            sizes.append(n_nodes - sum(sizes))

        p_within = random.uniform(0.6, 0.8)

        if mix_factor < 0.1:
            p_between = random.uniform(0.001, 0.02)  # minimal inter-block connections
        elif mix_factor < 0.3:
            p_between = random.uniform(0.02, 0.1)    # low
        elif mix_factor < 0.7:
            p_between = random.uniform(0.1, 0.3)     # medium
        else:
            p_between = random.uniform(0.3, p_within)  # high

        p = [[p_within if i == j else p_between for j in range(num_blocks)]
             for i in range(num_blocks)]

        G = nx.stochastic_block_model(sizes, p, seed=random.randint(1, 10000))
        if not nx.is_connected(G):
            components = list(nx.connected_components(G))
            for i in range(len(components) - 1):
                u = random.choice(list(components[i]))
                v = random.choice(list(components[i + 1]))
                G.add_edge(u, v)
        return G

    def generate_geometric_evolution(self, n_nodes, mix_factor=0.0):
        """Generate a connected random geometric graph with extra random edges.

        Nodes are placed uniformly in the unit square and joined when their
        distance is at most ``sqrt(2 ln n / n)`` clamped to [0.1, 0.3]. If
        ``mix_factor > 0``, ``int(mix_factor * 0.1 * n * ln n)`` random
        non-edges are added (capped at the number available). Remaining
        components are then linked to the largest one by a random edge each.
        """
        base_radius = max(0.1, min(0.3, (2*np.log(n_nodes)/n_nodes)**0.5))
        pos = {i: (random.random(), random.random()) for i in range(n_nodes)}

        G = nx.Graph()
        G.add_nodes_from(range(n_nodes))

        # Phase 1: geometric edges (pairwise O(n^2) distance check).
        for i in range(n_nodes):
            x1, y1 = pos[i]
            for j in range(i + 1, n_nodes):
                x2, y2 = pos[j]
                dist = ((x1-x2)**2 + (y1-y2)**2)**0.5
                if dist <= base_radius:
                    G.add_edge(i, j)

        # Phase 2: random shortcut edges, more of them for larger mix_factor.
        if mix_factor > 0:
            extra_connections = int(mix_factor * 0.1 * n_nodes * np.log(n_nodes))
            non_edges = list(nx.non_edges(G))
            if non_edges and extra_connections > 0:
                additional_edges = random.sample(non_edges, min(extra_connections, len(non_edges)))
                G.add_edges_from(additional_edges)

        # Ensure connectivity: attach every other component to the largest.
        if not nx.is_connected(G):
            components = list(nx.connected_components(G))
            largest = max(components, key=len)
            for comp in components:
                if comp != largest:
                    G.add_edge(random.choice(list(comp)), random.choice(list(largest)))

        return G

    def apply_configuration_model(self, G, randomization_level=0.0):
        """Rewire a fraction of ``G``'s edges using a configuration-model graph.

        A configuration-model graph is drawn on ``G``'s degree sequence and
        collapsed to a simple ``nx.Graph`` (parallel edges merged, self-loops
        kept). Then ``int(randomization_level * |E|)`` random edges of ``G``
        are removed and as many edges sampled from that graph are added.
        Degrees are therefore only approximately preserved (added edges may
        coincide with existing ones or be self-loops). Any resulting extra
        component is linked to the largest one by a random edge.

        Note: ``G`` is modified in place; a copy (``nx.Graph(G)``) is returned.
        """
        degrees = [d for _, d in G.degree()]
        edges_to_rewire = int(randomization_level * G.number_of_edges())

        if edges_to_rewire > 0:
            # Configuration-model nodes are 0..len(degrees)-1 in G.degree()
            # order; this matches G for the generators above, whose nodes are
            # labelled 0..n-1 in insertion order.
            G_config = nx.Graph(nx.configuration_model(degrees, seed=random.randint(1, 10000)))
            config_edges = list(G_config.edges())

            # random.sample needs a population at least as large as the sample.
            if len(config_edges) >= edges_to_rewire:
                rewired_edges = random.sample(config_edges, edges_to_rewire)
                G_edges = list(G.edges())
                if len(G_edges) >= edges_to_rewire:
                    G.remove_edges_from(random.sample(G_edges, edges_to_rewire))
                    G.add_edges_from(rewired_edges)

        # Ensure graph is connected after rewiring.
        if not nx.is_connected(G):
            components = list(nx.connected_components(G))
            largest = max(components, key=len)
            for comp in components:
                if comp != largest:
                    G.add_edge(random.choice(list(comp)),
                               random.choice(list(largest)))

        return nx.Graph(G)

def _print_split_stats(label, dataset):
    """Print sample count, node-count range and spectral-gap summary of one split."""
    gaps = [data.y.item() for data in dataset]
    print(f"{label}: {len(dataset)} samples, Nodes: {dataset.min_nodes}-{dataset.max_nodes}")
    if gaps:
        print(f"  Spectral Gap: {min(gaps):.3f} - {max(gaps):.3f}, Median: {np.median(gaps):.3f}")
    else:
        print("  Spectral Gap: n/a (empty split)")


def generate_spectral_gap_dataset(root_dir, seed=42):
    """Generate (or load cached) train/val splits and three test splits of increasing shift.

    Test splits: 'ID' (20-50 nodes), 'Near-OOD' (40-100) and 'Far-OOD'
    (60-150), 300 graphs each, stored under ``<root_dir>/test_<difficulty>``.

    Args:
        root_dir: Output directory; created if missing.
        seed: Seed passed to ``set_seed`` before any generation.

    Returns:
        ``(train_dataset, val_dataset, test_datasets)`` where ``test_datasets``
        maps difficulty name -> ``SpectralGapDataset``.
    """
    set_seed(seed)
    os.makedirs(root_dir, exist_ok=True)

    # Define test set difficulty settings
    difficulty_configs = {
        'ID': {
            'min_nodes': 20,
            'max_nodes': 50,  # In-distribution
            'samples_per_class': 300  # total samples
        },
        'Near-OOD': {
            'min_nodes': 40,
            'max_nodes': 100,  # Mild distribution shift
            'samples_per_class': 300
        },
        'Far-OOD': {
            'min_nodes': 60,
            'max_nodes': 150,  # Significant distribution shift
            'samples_per_class': 300
        }
    }

    print("Generating training set...")
    train_dataset = SpectralGapDataset(os.path.join(root_dir, 'train'), mode='train')

    print("Generating validation set...")
    val_dataset = SpectralGapDataset(os.path.join(root_dir, 'val'), mode='val')

    test_datasets = {}
    for difficulty, config in difficulty_configs.items():
        test_dir = os.path.join(root_dir, f'test_{difficulty}')
        processed_file = os.path.join(test_dir, 'processed', 'spectral_gap_test.pt')

        if os.path.exists(processed_file):
            print(f"Found existing {difficulty} test set, loading...")
        else:
            print(f"Generating {difficulty} test set...")

        # Always pass the config: a cached file is loaded unchanged, but the
        # returned object's min_nodes / max_nodes / samples_per_class then
        # describe this split instead of the generic 'test' defaults.
        test_dataset = SpectralGapDataset(
            test_dir,
            mode='test',
            use_degree_features=True,
            **config
        )
        test_datasets[difficulty] = test_dataset
        _print_split_stats(f"{difficulty} Test Set", test_dataset)

    _print_split_stats("Training Set", train_dataset)
    _print_split_stats("Validation Set", val_dataset)

    return train_dataset, val_dataset, test_datasets

def _load_graph_generation_info(dataset):
    """Map graph index -> generation-info dict read from ``metadata_<mode>.json``.

    The JSON is written by ``SpectralGapDataset.process()`` into the dataset's
    processed directory. Returns ``{}`` when the file does not exist.
    """
    metadata_path = os.path.join(dataset.processed_dir, f'metadata_{dataset.mode}.json')
    if not os.path.exists(metadata_path):
        return {}
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    return {
        info['index']: info
        for info in metadata.get('graphs', [])
        if info.get('index') is not None
    }


def generate_spectral_gap_image_dataset(root_dir, layout="spring", image_size=224, seed=42):
    """Render every graph of the saved splits to a PNG and index them in ``dataset.csv``.

    Splits already fully rendered (same row count in the CSV and all image
    files present) are skipped. A partially rendered split is re-rendered and
    its old CSV rows are replaced, so no split is listed twice.

    Args:
        root_dir: Directory holding the ``train``/``val``/``test_*`` datasets;
            images go to ``<root_dir>/images`` and the index to
            ``<root_dir>/dataset.csv``.
        layout: Name of a networkx layout (``nx.<layout>_layout``).
        image_size: Currently unused. Figures are drawn at 8x8 inches and
            saved at 100 dpi with a tight bounding box.
        seed: Seed for ``set_seed``; the spring layout also gets
            ``seed + global image index``.

    Returns:
        ``(image_dir, dataset_csv)`` paths.

    Raises:
        ValueError: if networkx has no ``<layout>_layout`` function.
    """
    set_seed(seed)

    # Ensure the directory exists
    image_dir = os.path.join(root_dir, "images")
    os.makedirs(image_dir, exist_ok=True)

    dataset_csv = os.path.join(root_dir, "dataset.csv")

    # Load the previously generated PyG datasets
    train_dataset = SpectralGapDataset(os.path.join(root_dir, 'train'), mode='train')
    val_dataset = SpectralGapDataset(os.path.join(root_dir, 'val'), mode='val')

    # Load test sets of different difficulty levels
    test_datasets = {}
    for difficulty in ['ID', 'Near-OOD', 'Far-OOD']:
        test_dir = os.path.join(root_dir, f'test_{difficulty}')
        if os.path.exists(test_dir):
            test_datasets[f'test_{difficulty}'] = SpectralGapDataset(test_dir, mode='test')
        else:
            print(f"Warning: {difficulty} test set not found. Please ensure it has been generated.")

    all_datasets = {
        'train': train_dataset,
        'val': val_dataset,
        **test_datasets
    }

    # Check if dataset information already exists
    if os.path.exists(dataset_csv):
        print("Found existing dataset metadata.")
        df = pd.read_csv(dataset_csv)
    else:
        print("Dataset metadata not found. Creating new metadata.")
        df = pd.DataFrame(columns=["image_path", "label", "generator_type", "split", "mix_factor"])

    generator_types = train_dataset.generator_types

    try:
        layout_func = getattr(nx, f"{layout}_layout")
    except AttributeError as err:
        raise ValueError(f"networkx has no '{layout}_layout' function") from err
    dpi = 100

    new_rows = []
    df_changed = False

    print("Generating images...")
    global_idx = len(df)  # continue numbering after the images already indexed

    for split, dataset in all_datasets.items():
        # Skip if all images for this split already exist
        split_rows = df[df['split'] == split]
        if not split_rows.empty and len(split_rows) == len(dataset):
            all_exist = all(os.path.exists(os.path.join(image_dir, path))
                            for path in split_rows['image_path'])
            if all_exist:
                print(f"All images found for {split} split. Skipping generation.")
                continue

        if not split_rows.empty:
            # The split is re-rendered below: drop its stale rows so it is not
            # listed twice in the CSV.
            df = df[df['split'] != split]
            df_changed = True

        print(f"Generating images for {split} split...")
        # Generator type / mix factor are not stored on Data objects; read
        # them from the metadata JSON written during processing.
        graph_info = _load_graph_generation_info(dataset)

        for idx in tqdm(range(len(dataset)), desc=f"Generating {split} images"):
            data = dataset[idx]
            G = nx.Graph()
            G.add_nodes_from(range(int(data.num_nodes)))
            G.add_edges_from(data.edge_index.t().tolist())

            # Compute layout
            if layout == "spring":
                pos = layout_func(G, seed=seed + global_idx, k=0.3)
            else:
                pos = layout_func(G)

            img_path = f"{split}_graph_{global_idx:04d}.png"
            full_path = os.path.join(image_dir, img_path)

            # Draw graph with black background; always release the figure.
            fig = plt.figure(figsize=(8, 8), facecolor='black')
            try:
                nx.draw_networkx_nodes(G, pos, node_size=50, node_color='skyblue',
                                       edgecolors='white', linewidths=0.8, alpha=0.9)
                nx.draw_networkx_edges(G, pos, width=1.5, alpha=0.8, edge_color='white')
                plt.axis('off')
                plt.tight_layout()
                plt.savefig(full_path, dpi=dpi, facecolor='black', bbox_inches='tight')
            finally:
                plt.close(fig)

            info = graph_info.get(idx, {})
            if hasattr(data, 'generator_idx'):
                generator_idx = data.generator_idx.item()
                generator_type = (generator_types[generator_idx]
                                  if 0 <= generator_idx < len(generator_types) else "unknown")
            else:
                generator_type = info.get('generator_type', "unknown")

            if hasattr(data, 'mix_factor'):
                mix_factor = data.mix_factor.item()
            else:
                # Config-model graphs record their base graph's mix factor.
                mix_factor = info.get('mix_factor', info.get('base_mix_factor', -1))

            new_rows.append({
                "image_path": img_path,
                "label": data.y.item(),
                "generator_type": generator_type,
                "split": split,
                "graph_idx": idx,  # local index in that split
                "mix_factor": mix_factor
            })

            global_idx += 1

    # Update dataset metadata
    if new_rows:
        new_df = pd.DataFrame(new_rows)
        df = new_df if df.empty else pd.concat([df, new_df], ignore_index=True)
        df_changed = True
    if df_changed:
        df.to_csv(dataset_csv, index=False)

    print(f"Image generation completed! Images saved in {image_dir}")
    print(f"Dataset metadata saved in {dataset_csv}")

    # Print sample count per split
    split_counts = df['split'].value_counts()
    print("\nDataset split distribution:")
    for split, count in split_counts.items():
        print(f"{split}: {count} samples")

    return image_dir, dataset_csv

def visualize_spectral_gap_by_generator_single(dataset, title, save_path, break_point=4.0):
    """Save a broken-y-axis KDE plot of spectral gaps, one curve per generator type.

    Gaps are grouped via ``metadata_<mode>.json`` in the dataset's processed
    directory (loaded once and cached on ``dataset.metadata``). Without
    metadata all gaps are drawn as a single "unknown" curve. The bottom panel
    shows densities in ``[0, break_point]``, the top panel the range above.
    Font settings are applied only while this figure is built.

    Args:
        dataset: A ``SpectralGapDataset`` (uses ``processed_dir`` and ``mode``).
        title: Figure title; also used in the no-metadata warning.
        save_path: Output path; matplotlib infers the format from the suffix.
        break_point: Density value at which the y-axis is split.

    Raises:
        ValueError: if there are no spectral gaps to plot.
    """
    generator_colors = {
        'sbm_dumbbell': '#D85B72',
        'sbm_multi': '#795F9C',
        'geometric_random': '#6C8FA9',
        'modified_config': '#886441'
    }

    generator_display_names = {
        'sbm_dumbbell': 'Dumbbell SBM',
        'sbm_multi': 'Multi-community SBM',
        'geometric_random': 'Random Geometric',
        'modified_config': 'Modified Config Model'
    }

    # Load metadata if not already loaded
    if not hasattr(dataset, 'metadata'):
        metadata_path = os.path.join(dataset.processed_dir, f'metadata_{dataset.mode}.json')
        if os.path.exists(metadata_path):
            with open(metadata_path, 'r') as f:
                dataset.metadata = json.load(f)

    # Group spectral gaps by generator type.
    generator_gaps = {}
    all_gaps = []
    metadata = getattr(dataset, 'metadata', None)
    if metadata:
        n_graphs = len(dataset)
        for graph_info in metadata.get('graphs', []):
            gen_type = graph_info.get('generator_type')
            idx = graph_info.get('index')
            if not gen_type or idx is None or not 0 <= idx < n_graphs:
                continue
            try:
                gap = dataset[idx].y.item()
            except (IndexError, RuntimeError, ValueError):
                # Unreadable or non-scalar label: skip this graph.
                continue
            generator_gaps.setdefault(gen_type, []).append(gap)
            all_gaps.append(gap)

    # If no metadata is available, fall back to raw data
    if not generator_gaps:
        print(f"Warning: No metadata for {title}. Grouping skipped.")
        all_gaps = [data.y.item() for data in dataset]
        generator_gaps = {"unknown": all_gaps}

    if not all_gaps:
        raise ValueError(f"No spectral gaps to plot for {title!r}")

    min_gap = min(all_gaps)
    max_gap = max(all_gaps)

    # Scoped style: a global rcParams.update would leak into later figures.
    with plt.rc_context({'font.family': 'sans-serif', 'font.weight': 'bold', 'font.size': 14}):
        fig = plt.figure(figsize=(12, 10))
        try:
            gs = fig.add_gridspec(2, 1, height_ratios=[1, 4], hspace=0.05)
            ax_top = fig.add_subplot(gs[0])
            ax_bottom = fig.add_subplot(gs[1])

            max_density = 0.0
            for gen_type, gaps in generator_gaps.items():
                color = generator_colors.get(gen_type, '#333333')
                display_name = generator_display_names.get(gen_type, gen_type)

                for ax in (ax_top, ax_bottom):
                    n_lines = len(ax.lines)
                    sns.kdeplot(
                        data=gaps,
                        label=display_name if ax is ax_bottom else "",
                        color=color,
                        linewidth=3,
                        alpha=1,
                        ax=ax
                    )
                    # seaborn skips zero-variance data without drawing a line;
                    # only read a line this call actually added.
                    if len(ax.lines) > n_lines:
                        max_density = max(max_density, float(np.max(ax.lines[-1].get_ydata())))

                    sns.kdeplot(
                        data=gaps,
                        color=color,
                        alpha=0.2,
                        fill=True,
                        ax=ax
                    )

            upper = max_density * 1.1
            if upper <= break_point:
                # No curve rises above the break: keep the top range valid
                # instead of inverted/empty.
                upper = break_point + max(0.1 * abs(break_point), 1.0)
            ax_bottom.set_ylim(0, break_point)
            ax_top.set_ylim(break_point, upper)

            for ax in (ax_top, ax_bottom):
                ax.set_xlim(min_gap - 0.05, max_gap + 0.05)
                ax.tick_params(labelsize=16, colors='#2F2F2F')
                ax.grid(True, linestyle='--', alpha=0.2, color='#2F2F2F')
                ax.set_ylabel("Density", fontsize=20, fontweight='bold', color='#2F2F2F')

            ax_top.set_xticklabels([])

            # Diagonal marks indicating the broken y-axis.
            d = .015
            kwargs = dict(transform=ax_top.transAxes, color='k', clip_on=False)
            ax_top.plot((-d, +d), (-d, +d), **kwargs)
            kwargs.update(transform=ax_bottom.transAxes)
            ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)

            fig.suptitle(title, fontsize=26, fontweight='bold', y=0.95, color='#2F2F2F')
            ax_bottom.set_xlabel("Spectral Gap", fontsize=20, fontweight='bold', color='#2F2F2F')

            legend = ax_bottom.legend(loc='best', fontsize=20, frameon=True,
                                      facecolor='white', edgecolor='black',
                                      framealpha=0.9, title="Graph Types")
            legend.get_title().set_fontweight('bold')
            legend.get_title().set_fontsize(20)

            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

def test_spectral_gap_distribution(train_dataset, val_dataset, test_dataset):
    """Save one spectral-gap-by-generator density plot (PDF) per dataset split.

    Plots the training split and the 'ID', 'Near-OOD' and 'Far-OOD' entries
    of ``test_dataset`` to ``spectral_gap_<name>.pdf`` in the current working
    directory. ``val_dataset`` is accepted for signature symmetry but is not
    plotted.
    """
    splits = [
        ('Train', train_dataset),
        *((f'Test-{key}', test_dataset[key]) for key in ('ID', 'Near-OOD', 'Far-OOD')),
    ]

    for name, dataset in splits:
        visualize_spectral_gap_by_generator_single(
            dataset,
            f"Spectral Gap Distribution by Graph Type ({name})",
            f"spectral_gap_{name.lower()}.pdf",
        )
        print(f"Saved spectral gap distribution plot for {name} dataset")


if __name__ == "__main__":
    # Pipeline: build/load graph splits -> render images -> plot gap densities.
    root_dir = "./spectral_gap_regression"
    # NOTE(review): forceatlas2 relies on nx.forceatlas2_layout, which only
    # recent networkx releases provide (3.4+?) — confirm the installed version.
    image_options = dict(
        root_dir=root_dir,
        layout="forceatlas2",
        image_size=224,
        seed=42,
    )

    train_dataset, val_dataset, test_dataset = generate_spectral_gap_dataset(root_dir)
    image_dir, dataset_csv = generate_spectral_gap_image_dataset(**image_options)
    test_spectral_gap_distribution(train_dataset, val_dataset, test_dataset)
