import os
import random
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import shutil
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from collections import defaultdict, Counter
from torch_geometric.utils import degree
from torch_geometric.transforms import OneHotDegree
import seaborn as sns
from sklearn.manifold import TSNE
from torch_geometric.utils import to_undirected

def set_seed(seed: int = 42):
    """
    Seed every random number generator this module relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch
    (plus all CUDA devices when CUDA is available), then records the seed
    in the ``PYTHONHASHSEED`` environment variable.

    Note: the environment variable is written after the interpreter has
    started, so it cannot change hashing in the current process; it is only
    visible to code that reads it later (e.g. child processes).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed_all(seed)

    os.environ["PYTHONHASHSEED"] = f"{seed}"

class BridgeCountDataset(InMemoryDataset):
    """
    PyTorch Geometric dataset for generating and storing graphs
    with various topological structures, using the number of bridges
    in each graph as the regression target.

    Five topology classes are generated (see ``class_names``). Every sample
    is a ``Data`` object with:
      * ``x``          - a constant 1 per node, optionally extended with a
                         one-hot degree encoding (``use_degree_features``),
      * ``edge_index`` - the undirected edge list (both directions stored),
      * ``y``          - the bridge count as a float tensor of shape ``[1]``,
      * ``class_idx``  - the topology class index as a long tensor of shape ``[1]``.

    Graph generation draws from Python's global ``random`` module, so call a
    seeding routine beforehand for reproducible datasets.
    """

    # Per-mode defaults: (min_nodes, max_nodes, samples_per_class).
    # Any mode other than 'train' / 'val' is treated as 'test'.
    _MODE_DEFAULTS = {
        'train': (20, 50, 500),
        'val': (20, 50, 50),    # keep consistent with training set
        'test': (20, 50, 50),   # ID test uses same distribution
    }

    def __init__(self, root, mode='train', transform=None, pre_transform=None, pre_filter=None, 
                 use_degree_features=True, min_nodes=None, max_nodes=None, samples_per_class=None):
        """
        Args:
            root: Dataset root directory (raw/ and processed/ live below it).
            mode: 'train', 'val' or anything else (treated as test). Selects
                the default sizes and the raw/processed file names.
            transform, pre_transform, pre_filter: Forwarded to
                ``InMemoryDataset``.
            use_degree_features: Append a one-hot degree encoding to ``x``
                while processing.
            min_nodes, max_nodes: Inclusive range the node count of each
                graph is drawn from. Each falls back to the mode default
                independently when left as ``None``.
            samples_per_class: Graphs generated per topology class; falls
                back to the mode default when ``None``.

        These size arguments only matter when the processed file does not
        exist yet; an existing processed file is loaded as-is.
        """
        self.mode = mode
        self.use_degree_features = use_degree_features

        # Resolve each size setting on its own so that a partial override
        # (e.g. only ``max_nodes``) is honoured instead of silently dropped.
        default_min, default_max, default_samples = self._MODE_DEFAULTS.get(
            mode, self._MODE_DEFAULTS['test']
        )
        self.min_nodes = default_min if min_nodes is None else min_nodes
        self.max_nodes = default_max if max_nodes is None else max_nodes
        self.samples_per_class = (
            default_samples if samples_per_class is None else samples_per_class
        )

        # Define topology classes
        self._num_classes = 5
        self.class_names = [
            "Geometric",
            "Community",
            "Hierarchical",
            "Bottleneck",
            "Multi-core",
        ]

        # Target statistical features to keep inter-class stats aligned
        # (stored for reference; the generators below do not read them).
        self.target_avg_degree = 4.0
        self.target_density_range = (0.15, 0.25)

        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True; confirm this call still loads the collated
        # PyG objects on the PyTorch/PyG versions in use.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # A single placeholder file; the data is synthetic.
        return [f"dummy_{self.mode}.txt"]

    @property
    def processed_file_names(self):
        # Depends on the mode only, so differently-sized datasets of the
        # same mode must live under different roots.
        return [f"bridge_count_{self.mode}.pt"]

    def download(self):
        # Create a dummy raw file to meet PyG requirements
        os.makedirs(self.raw_dir, exist_ok=True)
        with open(os.path.join(self.raw_dir, f"dummy_{self.mode}.txt"), "w") as f:
            f.write("dummy")

    def process(self):
        """
        Generate ``samples_per_class`` graphs for each topology class,
        convert them to ``Data`` objects and save the collated result.

        Raises:
            ValueError: If ``min_nodes`` is larger than ``max_nodes``.
        """
        if self.min_nodes > self.max_nodes:
            raise ValueError(
                f"min_nodes ({self.min_nodes}) must not exceed max_nodes ({self.max_nodes})"
            )

        # Generator for each class index; order matches ``class_names``.
        generators = [
            self.generate_random_geometric,
            self.generate_community,
            self.generate_hierarchical_hub,
            self.generate_bottleneck,
            self.generate_multi_core,
        ]

        data_list = []

        for class_idx, generate in enumerate(generators):
            for _ in tqdm(
                range(self.samples_per_class), desc=f"Generating {self.mode} class {class_idx}"
            ):
                n_nodes = random.randint(self.min_nodes, self.max_nodes)

                # Generate graph based on the specific topology type
                G = generate(n_nodes)
                data_list.append(self._graph_to_data(G, class_idx))

        # Pre-filtering if specified
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        # Apply one-hot degree features
        if self.use_degree_features:
            max_degree = 100
            degree_transform = OneHotDegree(max_degree)
            data_list = [degree_transform(d) for d in data_list]

        # Pre-transformations if specified
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    @staticmethod
    def _graph_to_data(G, class_idx):
        """Convert a networkx graph to a PyG ``Data`` sample labelled with its bridge count."""
        # Keep only the largest connected component so the graph is connected
        if not nx.is_connected(G):
            largest_cc = max(nx.connected_components(G), key=len)
            G = G.subgraph(largest_cc).copy()

        # Normalize node indices
        G = nx.convert_node_labels_to_integers(G)

        # reshape(-1, 2) keeps the tensor two-dimensional even when the
        # graph has no edges (e.g. a single-node component), so the result
        # is always shaped [2, num_edges].
        edge_index = (
            torch.tensor(list(G.edges), dtype=torch.long).reshape(-1, 2).t().contiguous()
        )
        edge_index = to_undirected(edge_index, num_nodes=G.number_of_nodes())
        x = torch.ones(G.number_of_nodes(), 1)

        # Calculate the number of bridges
        num_bridges = sum(1 for _ in nx.bridges(G))
        y = torch.tensor([num_bridges], dtype=torch.float)

        return Data(
            x=x,
            edge_index=edge_index,
            y=y,
            class_idx=torch.tensor([class_idx], dtype=torch.long),
        )

    @staticmethod
    def _add_random_block(G, nodes, p_within):
        """
        Add ``nodes`` to ``G`` and connect each pair with probability ``p_within``.

        ``nodes`` must be sorted ascending; pairs are visited in the order
        (nodes[0], nodes[1]), (nodes[0], nodes[2]), ... so that the sequence
        of random draws and node insertions is stable.
        """
        for pos, u in enumerate(nodes):
            G.add_node(u)
            for v in nodes[pos + 1:]:
                if random.random() < p_within:
                    G.add_edge(u, v)

    def generate_argg_topology(self, n_nodes, r_inner, r_outer, r_connect):
        """
        Generate an annular random geometric graph.

        Nodes are placed uniformly (by area) in the ring between radii
        ``r_inner`` and ``r_outer``; two nodes are linked when their
        Euclidean distance is at most ``r_connect``. The result is not
        guaranteed to be connected.
        """
        G = nx.Graph()
        G.add_nodes_from(range(n_nodes))
 
        # Temporary positions used for topology generation (not for visualization)
        temp_positions = {}
 
        # Distribute nodes within the annular region
        for i in range(n_nodes):
            theta = 2 * np.pi * random.random()
            u = random.random()
            # Square-root transform of the radius gives uniform area density
            r = np.sqrt(r_inner**2 + u * (r_outer**2 - r_inner**2))

            temp_positions[i] = (r * np.cos(theta), r * np.sin(theta))
 
        # Connect nodes based on distance threshold
        for i in range(n_nodes):
            x1, y1 = temp_positions[i]
            for j in range(i + 1, n_nodes):
                x2, y2 = temp_positions[j]

                dist = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)

                if dist <= r_connect:
                    G.add_edge(i, j)
                    
        return G

    def generate_random_geometric(self, n_nodes):
        """
        Generate a random geometric graph with a radius drawn from
        [0.15, 0.25], then link every smaller component to the largest one
        with a single random edge so the graph is connected.
        """
        radius = random.uniform(0.15, 0.25)
        G = nx.random_geometric_graph(n_nodes, radius)
        
        # Ensure the graph is connected by linking disconnected components
        if not nx.is_connected(G):
            comps = list(nx.connected_components(G))
            largest_cc = max(comps, key=len)
            for comp in comps:
                if comp != largest_cc:
                    G.add_edge(random.choice(list(comp)), random.choice(list(largest_cc)))
        return G

    def generate_community(self, n_nodes):
        """
        Generate a graph with 3-5 dense communities and sparse links
        between them.

        For ``n_nodes`` below 15 the number of communities is reduced
        (down to a single one) instead of failing. ``n_nodes`` must be at
        least 1.
        """
        G = nx.Graph()
        # Normally between 3 and min(5, n_nodes // 5) communities; the
        # bounds are clamped so small graphs still yield a valid range.
        max_communities = max(1, min(5, n_nodes // 5))
        num_communities = random.randint(min(3, max_communities), max_communities)
        nodes_per_comm = n_nodes // num_communities
        communities = []
        node_cnt = 0
        
        # Create communities and intra-community edges
        for i in range(num_communities):
            # The last community absorbs the remainder of the division
            comm_size = n_nodes - node_cnt if i == num_communities - 1 else nodes_per_comm
            community = list(range(node_cnt, node_cnt + comm_size))
            communities.append(community)
            node_cnt += comm_size
            p_within = random.uniform(0.6, 0.8)
            self._add_random_block(G, community, p_within)
        
        # Add sparse inter-community edges
        p_between = random.uniform(0.01, 0.05)
        for i in range(num_communities):
            for j in range(i + 1, num_communities):
                for u in communities[i]:
                    for v in communities[j]:
                        if random.random() < p_between:
                            G.add_edge(u, v)
        
        # Chain consecutive communities together when the graph is still
        # disconnected. A community that is internally disconnected can
        # still leave stray components; process() keeps the largest one.
        if not nx.is_connected(G):
            for i in range(num_communities - 1):
                G.add_edge(random.choice(communities[i]), random.choice(communities[i + 1]))
        
        return G

    def generate_hierarchical_hub(self, n_nodes):
        """
        Generate a layered graph: nodes are split into 2-4 levels, nodes of
        every level except the last are randomly inter-connected, and each
        node links to 1-3 random nodes of the level above it.
        """
        G = nx.Graph()
        G.add_nodes_from(range(n_nodes))
        
        num_levels = min(4, max(2, n_nodes // 10))
        remaining_ratio = 1.0
        decay = 0.4
        level_ratios = []
        
        # Compute proportions of nodes per hierarchy level
        for level in range(num_levels - 1):
            ratio = remaining_ratio * decay if level < num_levels - 2 else remaining_ratio
            level_ratios.append(ratio)
            remaining_ratio -= ratio
        # Reserve at least half of the nodes for the bottom level
        if remaining_ratio < 0.5:
            level_ratios = [r * 0.5 for r in level_ratios]
            remaining_ratio = 0.5
        level_ratios.append(remaining_ratio)
        level_sizes = [max(1, int(r * n_nodes)) for r in level_ratios]
        # Rounding error is absorbed by the bottom level
        diff = n_nodes - sum(level_sizes)
        level_sizes[-1] += diff
        
        # Assign nodes to each level
        level_nodes = []
        idx = 0
        for sz in level_sizes:
            level_nodes.append(list(range(idx, idx + sz)))
            idx += sz
        
        # Add intra-level and inter-level connections
        for lvl in range(num_levels - 1):
            upper, lower = level_nodes[lvl], level_nodes[lvl + 1]
            # Higher levels (smaller lvl) are wired more densely
            intra_density = 0.7 * (num_levels - lvl) / num_levels
            for i in range(len(upper)):
                for j in range(i + 1, len(upper)):
                    if random.random() < intra_density:
                        G.add_edge(upper[i], upper[j])
            for ln in lower:
                for un in random.sample(upper, min(random.randint(1, 3), len(upper))):
                    G.add_edge(ln, un)
        return G

    def generate_bottleneck(self, n_nodes):
        """
        Generate 2-4 moderately dense communities chained together by a
        single edge between consecutive communities.

        For ``n_nodes`` below 14 the number of communities is reduced
        (down to a single one) instead of failing.
        """
        G = nx.Graph()
        # Normally between 2 and min(4, n_nodes // 7) communities; the
        # bounds are clamped so small graphs still yield a valid range.
        max_communities = max(1, min(4, n_nodes // 7))
        num_communities = random.randint(min(2, max_communities), max_communities)
        nodes_per_comm = n_nodes // num_communities
        communities, start_idx = [], 0
        
        # Create communities with dense internal connections
        for i in range(num_communities):
            # The last community absorbs the remainder of the division
            comm_size = n_nodes - start_idx if i == num_communities - 1 else nodes_per_comm
            community = list(range(start_idx, start_idx + comm_size))
            communities.append(community)
            start_idx += comm_size
            p_within = random.uniform(0.4, 0.6)
            self._add_random_block(G, community, p_within)
        
        # Add narrow connections (bottlenecks) between adjacent communities
        for i in range(len(communities) - 1):
            k = 1
            bs1 = random.sample(communities[i], k)
            bs2 = random.sample(communities[i + 1], k)
            for b1, b2 in zip(bs1, bs2):
                G.add_edge(b1, b2)
        return G

    def generate_multi_core(self, n_nodes):
        """
        Generate 2-3 densely connected cores holding roughly 50-60% of the
        nodes, a few edges between every pair of cores, and peripheral
        nodes that each attach to 1-2 nodes of one random core.

        Core sizes and sample sizes are clamped so that small ``n_nodes``
        neither creates more than ``n_nodes`` nodes nor samples more nodes
        than a core contains.
        """
        G = nx.Graph()
        n_cores = random.randint(2, 3)
        core_ratio = random.uniform(0.5, 0.6)
        # At least one core node (when the graph has any), never more than n_nodes
        core_total = min(n_nodes, max(1, int(n_nodes * core_ratio)))
        core_size = max(4, core_total // n_cores)
        cores, node_counter = [], 0
        
        # Create core regions
        for i in range(n_cores):
            if i == n_cores - 1:
                # The last core takes whatever is left of the core budget
                core = list(range(node_counter, core_total))
            else:
                core_end = min(node_counter + core_size, core_total)
                core = list(range(node_counter, core_end))
            node_counter += len(core)
            p_within = random.uniform(0.6, 0.8)
            if not core:
                # Budget exhausted by earlier cores; nothing to add
                continue
            cores.append(core)
            self._add_random_block(G, core, p_within)
                        
        # Add inter-core edges
        for i in range(len(cores)):
            for j in range(i + 1, len(cores)):
                k = random.randint(1, 2)
                ends_i = random.sample(cores[i], min(k, len(cores[i])))
                ends_j = random.sample(cores[j], min(k, len(cores[j])))
                for b1, b2 in zip(ends_i, ends_j):
                    G.add_edge(b1, b2)
        
        # Add peripheral nodes and connect them sparsely to core nodes
        periphery = list(range(node_counter, n_nodes))
        for p in periphery:
            G.add_node(p)
            core = random.choice(cores)
            k = random.randint(1, 2)
            for c in random.sample(core, min(k, len(core))):
                G.add_edge(p, c)
        return G


def generate_bridge_count_dataset(root_dir: str, seed: int = 42):
    """
    Build (or load, when already processed) the train, validation and test
    splits of the bridge count dataset below ``root_dir``.

    Three test sets are produced, one per distribution-shift level
    ('ID', 'Near-OOD', 'Far-OOD'), each in its own ``test_<level>`` directory.

    Args:
        root_dir: Directory that receives one sub-directory per split.
        seed: Seed passed to ``set_seed`` before any graph is generated.

    Returns:
        ``(train_dataset, val_dataset, test_datasets)`` where
        ``test_datasets`` maps the difficulty name to its dataset.
    """
    set_seed(seed)
    
    # Create root directory if it doesn't exist
    os.makedirs(root_dir, exist_ok=True)
    
    # Configuration for test sets with different levels of distribution shift
    difficulty_configs = {
        'ID': {
            'min_nodes': 20, 
            'max_nodes': 50,  # Same distribution as training
            'samples_per_class': 50
        },
        'Near-OOD': {
            'min_nodes': 40, 
            'max_nodes': 100,  # Moderate shift
            'samples_per_class': 50
        },
        'Far-OOD': {
            'min_nodes': 60, 
            'max_nodes': 150,  # Significant shift
            'samples_per_class': 50
        }
    }
    
    # Generate training and validation datasets
    train_dataset = BridgeCountDataset(os.path.join(root_dir, 'train'), mode='train')
    val_dataset = BridgeCountDataset(os.path.join(root_dir, 'val'), mode='val')
    
    # Generate test datasets for all difficulty levels
    test_datasets = {}
    
    for difficulty, config in difficulty_configs.items():
        test_dir = os.path.join(root_dir, f'test_{difficulty}')
        
        # Only used to tell the user whether the set is loaded or generated
        processed_file = os.path.join(test_dir, 'processed', 'bridge_count_test.pt')
        if os.path.exists(processed_file):
            print(f"Found existing {difficulty} test set. Loading...")
        else:
            print(f"Generating {difficulty} test set...")

        # Always pass the difficulty config: an existing processed file is
        # loaded without regeneration, and the dataset's min_nodes/max_nodes
        # attributes then describe the data it actually holds.
        test_dataset = BridgeCountDataset(
            test_dir, 
            mode='test',
            min_nodes=config['min_nodes'],
            max_nodes=config['max_nodes'],
            samples_per_class=config['samples_per_class'],
            use_degree_features=True
        )
        
        test_datasets[difficulty] = test_dataset
        print(f"{difficulty} test set: {len(test_dataset)} samples, node range: {config['min_nodes']}-{config['max_nodes']}")
    
    # Print summary of train and validation sets
    print(f"Training set: {len(train_dataset)} samples, node range: {train_dataset.min_nodes}-{train_dataset.max_nodes}")
    print(f"Validation set: {len(val_dataset)} samples, node range: {val_dataset.min_nodes}-{val_dataset.max_nodes}")
    
    return train_dataset, val_dataset, test_datasets


def collect_cyclic_graphs_for_inspection(dataset_csv, image_dir, target_dir=None):
    """
    Collect images of cyclic graphs for manual inspection.

    Rows of ``dataset_csv`` are selected when their ``label`` equals 0 or
    their ``class_name`` equals 'Cyclic'. The image of every selected row
    is copied from ``image_dir`` into ``target_dir`` under the name
    ``<split>_cyclic_<row index>.png``; rows whose image file does not
    exist are skipped and reported.

    Args:
        dataset_csv: CSV with at least the columns ``label``,
            ``class_name``, ``image_path`` and ``split``.
        image_dir: Directory the ``image_path`` values are relative to.
        target_dir: Destination directory. Defaults to a
            ``cyclic_inspection`` directory next to ``image_dir``.

    Returns:
        The destination directory.
    """
    
    df = pd.read_csv(dataset_csv)
    cyclic_graphs = df[(df['label'] == 0) | (df['class_name'] == 'Cyclic')]
    
    if target_dir is None:
        target_dir = os.path.join(os.path.dirname(image_dir), "cyclic_inspection")
    
    os.makedirs(target_dir, exist_ok=True)
    
    print(f"Found {len(cyclic_graphs)} cyclic graphs. Copying images...")
    
    # Copy relevant images to target directory, tracking what really got copied
    copied = 0
    for i, row in cyclic_graphs.iterrows():
        src_path = os.path.join(image_dir, row['image_path'])
        dst_path = os.path.join(target_dir, f"{row['split']}_cyclic_{i:04d}.png")
        
        if os.path.exists(src_path):
            shutil.copy(src_path, dst_path)
            copied += 1
    
    missing = len(cyclic_graphs) - copied
    if missing:
        print(f"Warning: {missing} source images were not found in {image_dir}")
    print(f"Copied {copied} cyclic graph images to {target_dir}")
    return target_dir


def generate_bridge_count_image_dataset(
    root_dir: str,
    layout: str = "spring",
    image_size: int = 224,
    seed: int = 42,
):
    """
    Generate visual graph representations (images) for the bridge count dataset.

    Every graph of the train, val and available test splits is drawn with
    networkx on a black background and saved below ``<root_dir>/images``.
    One metadata row per image (path, label, class name, split, local
    index) is written to ``<root_dir>/dataset.csv``.

    A split is skipped when the CSV already lists exactly as many rows as
    the split has graphs and all listed image files exist. Otherwise the
    split is regenerated and its previous CSV rows are replaced.

    Args:
        root_dir: Dataset root that contains the split directories.
        layout: Name of a networkx layout; ``nx.<layout>_layout`` is used.
        image_size: Nominal image edge length in pixels. The figure is
            8 inches wide and saved with ``dpi = image_size // 8``; because
            the figure is saved with a tight bounding box the final size
            can deviate from this value.
        seed: Seed for ``set_seed`` and base seed of the spring layout.

    Returns:
        ``(image_dir, dataset_csv)``.
    """
    set_seed(seed)
    
    # Ensure image directory exists
    image_dir = os.path.join(root_dir, "images")
    os.makedirs(image_dir, exist_ok=True)
    
    dataset_csv = os.path.join(root_dir, "dataset.csv")
    
    # Load existing PyG datasets
    train_dataset = BridgeCountDataset(os.path.join(root_dir, 'train'), mode='train')
    val_dataset = BridgeCountDataset(os.path.join(root_dir, 'val'), mode='val')
    # Load all test datasets
    difficulties = ['ID', 'Near-OOD', 'Far-OOD']
    test_datasets = {}
    
    for difficulty in difficulties:
        test_dir = os.path.join(root_dir, f'test_{difficulty}')
        if os.path.exists(test_dir):
            test_datasets[f'test_{difficulty}'] = BridgeCountDataset(test_dir, mode='test')
        else:
            print(f"Warning: Test set for {difficulty} not found. Please ensure it's generated.")

    # Combine all datasets
    all_datasets = {
        'train': train_dataset,
        'val': val_dataset,
        **test_datasets  # Include all test sets
    }
    
    # Check if data csv already exists
    if os.path.exists(dataset_csv):
        print("Found existing dataset CSV.")
        df = pd.read_csv(dataset_csv)
    else:
        print("No dataset CSV found. Creating a new one.")
        df = pd.DataFrame(columns=["image_path", "label", "class_name", "split"])
    
    # Prepare to generate graph images
    new_rows = []
    layout_func = getattr(nx, f"{layout}_layout")
    dpi = image_size // 8  # figure is 8 inches wide
    
    print("Generating graph images...")
    # Continue numbering after the rows already recorded so new file names
    # do not collide with existing ones.
    global_idx = len(df)
    csv_changed = False
    
    for split, dataset in all_datasets.items():
        # Skip generation if all images already exist
        split_images = df[df['split'] == split]
        if not split_images.empty:
            complete = len(split_images) == len(dataset) and all(
                os.path.exists(os.path.join(image_dir, path))
                for path in split_images['image_path']
            )
            if complete:
                print(f"All images for {split} split found. Skipping generation.")
                continue
            # The recorded rows are stale or incomplete. Drop them so the
            # regenerated images are not listed twice in the CSV.
            df = df[df['split'] != split]
            csv_changed = True
        
        print(f"Generating images for {split} split...")
        
        for idx in tqdm(range(len(dataset)), desc=f"Generating {split} images"):
            data = dataset[idx]
            G = nx.Graph()
            G.add_nodes_from(range(data.num_nodes))
            G.add_edges_from(data.edge_index.t().tolist())
            
            # Compute node layout for visualization
            if layout == "spring":
                pos = layout_func(G, seed=seed + global_idx, k=0.3)
            else:
                pos = layout_func(G)
            
            # Draw graph on black background
            fig = plt.figure(figsize=(8, 8), facecolor='black')
            nx.draw_networkx_nodes(G, pos, node_size=50, node_color='skyblue', 
                                  edgecolors='white', linewidths=0.8, alpha=0.9)
            nx.draw_networkx_edges(G, pos, width=1.5, alpha=0.8, edge_color='white')
            
            plt.axis('off')
            plt.tight_layout()
            
            # Save image with a unique filename
            img_path = f"{split}_graph_{global_idx:04d}.png"
            full_path = os.path.join(image_dir, img_path)
            plt.savefig(full_path, dpi=dpi, facecolor='black', bbox_inches='tight')
            plt.close(fig)

            new_rows.append({
                "image_path": img_path, 
                "label": data.y.item(),
                "class_name": dataset.class_names[data.class_idx.item()],
                "split": split,
                "graph_idx": idx # local index in that split
            })
            
            global_idx += 1
    
    # Update dataset CSV metadata
    if new_rows:
        new_df = pd.DataFrame(new_rows)
        if df.empty:
            df = new_df
        else:
            df = pd.concat([df, new_df], ignore_index=True)
        csv_changed = True

    if csv_changed:
        df.to_csv(dataset_csv, index=False)
    
    print(f"Graph image generation complete! Images saved to {image_dir}")
    print(f"Dataset metadata saved to {dataset_csv}")
    
    # Print sample count per split
    split_counts = df['split'].value_counts()
    print("\nSample distribution by split:")
    for split, count in split_counts.items():
        print(f"{split}: {count} samples")

    return image_dir, dataset_csv


def analyze_degree_distribution(train_dataset, val_dataset, test_dataset, split_name='train'):
    """
    Analyze degree distribution for each graph class in the specified dataset split.
    Focuses on key degree values, such as the proportion of nodes with degree 2.

    Two figures are saved to the working directory:
    ``<split_name>_degree_distribution_detail.png`` (one bar chart per class)
    and ``<split_name>_degree2_comparison.png`` (share of degree-2 nodes).

    Args:
        train_dataset, val_dataset, test_dataset: Datasets exposing
            ``class_names`` and yielding samples with ``edge_index``,
            ``num_nodes`` and ``class_idx``.
        split_name: 'train', 'val', or anything else to select the test set.

    Returns:
        A DataFrame with one row per class: total nodes, average degree and
        the percentage of nodes with degree 1, 2, 3 and 4 or more.
    """
    
    # Select dataset split to analyze
    if split_name == 'train':
        dataset = train_dataset
    elif split_name == 'val':
        dataset = val_dataset
    else:
        dataset = test_dataset
    
    class_names = dataset.class_names
    
    # Initialize structures to collect degree statistics
    class_degree_counts = {name: defaultdict(int) for name in class_names}
    class_total_nodes = {name: 0 for name in class_names}
    
    for data in dataset:
        # The class lives in ``class_idx``; ``y`` holds the bridge count
        class_idx = data.class_idx.item()
        class_name = class_names[class_idx]
        
        # Compute degrees of nodes
        deg = degree(data.edge_index[0], data.num_nodes, dtype=torch.long).numpy()
        
        # Record node count per degree
        for d in deg:
            class_degree_counts[class_name][int(d)] += 1
            class_total_nodes[class_name] += 1
    
    # Plot degree percentage distribution per class
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()
    
    degree_summary = []
    
    for i, class_name in enumerate(sorted(class_names)):
        total_nodes = class_total_nodes[class_name]
        if total_nodes == 0:
            continue

        degree_counts = class_degree_counts[class_name]
        degrees = sorted(degree_counts)
        percentages = [degree_counts[d] / total_nodes * 100 for d in degrees]
        
        ax = axes[i]
        ax.bar(degrees, percentages)
        ax.set_title(f"{class_name} - Degree distribution")
        ax.set_xlabel("Node degree")
        ax.set_ylabel("Percentage (%)")
        
        # Annotate the proportion of nodes with degree = 2
        for d, p in zip(degrees, percentages):
            if d == 2:
                ax.text(d, p+1, f"{p:.1f}%", ha='center', fontweight='bold', color='red')
            else:
                ax.text(d, p+0.5, f"{p:.1f}%", ha='center')
        
        # Save summary stats
        degree_props = {
            "class": class_name,
            "total_nodes": total_nodes,
            "avg_degree": sum(d * c for d, c in degree_counts.items()) / total_nodes,
            "degree_1%": degree_counts.get(1, 0) / total_nodes * 100,
            "degree_2%": degree_counts.get(2, 0) / total_nodes * 100,
            "degree_3%": degree_counts.get(3, 0) / total_nodes * 100,
            # Every degree >= 4, without an arbitrary upper bound
            "degree_4+%": sum(c for d, c in degree_counts.items() if d >= 4) / total_nodes * 100
        }
        degree_summary.append(degree_props)
        
    plt.tight_layout()
    plt.savefig(f"{split_name}_degree_distribution_detail.png", dpi=300)
    
    # Print degree distribution summary
    summary_df = pd.DataFrame(degree_summary)
    print(f"\n{split_name} Set - Degree Distribution Summary by Class")
    print("-" * 80)
    print(summary_df.to_string(index=False, float_format=lambda x: f"{x:.2f}"))
    
    # Create comparison plot for proportion of nodes with degree == 2
    plt.figure(figsize=(10, 6))
    plt.bar(summary_df['class'], summary_df['degree_2%'])
    plt.title(f"{split_name} Set - Proportion of Nodes with Degree = 2")
    plt.ylabel("Percentage (%)")
    plt.ylim(0, 100)
    for i, v in enumerate(summary_df['degree_2%']):
        plt.text(i, v+2, f"{v:.1f}%", ha='center')
    plt.tight_layout()
    plt.savefig(f"{split_name}_degree2_comparison.png", dpi=300)

    return summary_df


def visualize_features(train_ds, val_ds, test_ds, method: str = "tsne"):
    """
    Visualize hand-crafted graph features (e.g., degree, density, clustering)
    using dimensionality reduction (t-SNE or PCA) across train/val/test splits.
    Each graph is represented by a 7-dimensional feature vector:
    node count, edge count, density, mean degree, degree standard deviation,
    mean closeness centrality and average clustering coefficient.

    Features are standardized per split before the 2-D embedding is
    computed. The figure is saved as ``feature_distribution.png`` and shown.

    Args:
        train_ds, val_ds, test_ds: Datasets exposing ``class_names`` and
            yielding samples with ``edge_index``, ``num_nodes``, ``class_idx``.
        method: "tsne" for t-SNE; any other value selects PCA.
    """
    ds_map = {"train": train_ds, "val": val_ds, "test": test_ds}
    plt.figure(figsize=(18, 6))
    for idx, (split, ds) in enumerate(ds_map.items()):
        feats, labels = [], []
        
        for data in ds:
            # Reconstruct graph for feature extraction
            G = nx.Graph()
            G.add_nodes_from(range(data.num_nodes))
            G.add_edges_from(data.edge_index.t().tolist())
            has_edges = G.number_of_edges() > 0
            
            # Extract graph-level features
            deg = degree(data.edge_index[0], data.num_nodes, dtype=torch.float)
            feat_vec = [
                data.num_nodes,
                data.edge_index.shape[1] // 2,  # each edge is stored in both directions
                nx.density(G),
                float(deg.mean()),
                float(deg.std()),
                np.mean(list(nx.closeness_centrality(G).values())) if has_edges else 0,
                nx.average_clustering(G) if has_edges else 0,
            ]
            feats.append(feat_vec)
            labels.append(data.class_idx.item())
            
        # Normalize features (epsilon guards constant columns)
        feats = np.array(feats)
        feats = (feats - feats.mean(0)) / (feats.std(0) + 1e-8)
        labels = np.array(labels)
        
        # Apply dimensionality reduction
        if method == "tsne":
            emb = TSNE(n_components=2, random_state=42).fit_transform(feats)
        else:
            from sklearn.decomposition import PCA
            emb = PCA(n_components=2, random_state=42).fit_transform(feats)
        
        # Plot embedding. ``class_idx`` indexes ``class_names`` in its
        # original order, so the names must not be sorted here or the
        # legend would attach the wrong name to each cluster.
        plt.subplot(1, 3, idx + 1)
        for i, cls in enumerate(ds.class_names):
            mask = labels == i
            plt.scatter(emb[mask, 0], emb[mask, 1], label=cls, alpha=0.7)
        
        plt.title(f"{split} Feature Distribution")
        if idx == 0:
            plt.legend()
    plt.tight_layout()
    plt.savefig("feature_distribution.png", dpi=300)
    plt.show()


def test_degree_distribution(train_ds, val_ds, test_ds):
    """
    Draw one KDE curve of node degrees per graph class, with a separate
    panel for each of the train, validation and test datasets.

    The figure is written to ``degree_distributions.png`` and then shown.
    Only the first panel carries a legend.
    """
    plt.figure(figsize=(15, 10))

    named_splits = (("train", train_ds), ("val", val_ds), ("test", test_ds))
    for panel, (split, ds) in enumerate(named_splits, start=1):
        # Gather the degree of every node, grouped by the graph's class name
        degrees_by_class = {}
        for name in ds.class_names:
            degrees_by_class[name] = []
        for graph in ds:
            name = ds.class_names[graph.class_idx.item()]
            node_degrees = degree(graph.edge_index[0], graph.num_nodes, dtype=torch.long)
            degrees_by_class[name].extend(node_degrees.numpy())

        # One density curve per class on this split's panel
        plt.subplot(1, 3, panel)
        for name in degrees_by_class:
            sns.kdeplot(degrees_by_class[name], label=name)
        plt.title(f"{split} degree distribution")
        plt.xlabel("Degree")
        plt.ylabel("Density")
        if panel == 1:
            plt.legend()

    plt.tight_layout()
    plt.savefig("degree_distributions.png", dpi=300)
    plt.show()


def visualize_all_bridge_types(dataset, title="Bridge Distribution by Graph Type", 
                             save_path="bridge_distribution.png",
                             break_point=0.4): 
    """
    Visualize the distribution of bridge counts for each graph class.
    Uses a broken y-axis to highlight both high-density and low-density regions.

    The bottom panel shows densities in ``[0, break_point]`` and the top
    panel everything above ``break_point``. Classes whose bridge count is
    constant are drawn as a histogram, all others as a KDE curve.
    Per-class summary statistics are printed afterwards.

    Note: this function updates matplotlib's global ``rcParams`` (font
    family, weight and size) and leaves them changed.

    Args:
        dataset: Dataset exposing ``class_names`` and yielding samples with
            ``class_idx`` and ``y`` (the bridge count).
        title: Figure title.
        save_path: File the figure is saved to.
        break_point: Density value at which the y-axis is split.

    Raises:
        ValueError: If ``dataset`` contains no samples.
    """
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.weight': 'bold',
        'font.size': 14
    })
    
    # Collect bridge count data per class
    class_bridge_counts = {}
    all_counts = []
    
    for data in dataset:
        class_name = dataset.class_names[data.class_idx.item()]
        count = data.y.item()
        class_bridge_counts.setdefault(class_name, []).append(count)
        all_counts.append(count)

    if not all_counts:
        raise ValueError("Cannot visualize bridge distribution: dataset is empty")

    # Create a figure with two stacked subplots: one for high-density zoom
    fig = plt.figure(figsize=(12, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 4], hspace=0.05)
    ax_top = fig.add_subplot(gs[0])
    ax_bottom = fig.add_subplot(gs[1])
    both_axes = [ax_top, ax_bottom]
    
    # Define custom colors for each graph class
    colors = {
        'Geometric':    '#D85B72',
        'Community':    '#795F9C',
        'Hierarchical': '#6C8FA9',
        'Bottleneck':   '#886441',
        'Multi-core':   '#518463'
    }
    
    min_count = int(min(all_counts))
    max_count = int(max(all_counts))
    
    # Determine tick interval based on range
    range_count = max_count - min_count
    if range_count > 20:
        step = range_count // 10
    else:
        step = max(1, range_count // 5)
        
    # Unit-width bins centred on the integer bridge counts
    bins = np.arange(min_count, max_count + 2) - 0.5
    
    # Plot histograms and KDEs for each class on both subplots
    max_density = 0
    for class_name, bridge_counts in class_bridge_counts.items():
        color = colors.get(class_name, '#333333')
        
        if len(set(bridge_counts)) == 1:
            # If all values are the same, use histogram
            hist, _ = np.histogram(bridge_counts, bins=bins, density=True)
            for ax in both_axes:
                # Label only the bottom panel so the legend has one entry per class
                ax.hist(bridge_counts, bins=bins, alpha=0.4, color=color, 
                       label=class_name if ax is ax_bottom else "", density=True)
            max_density = max(max_density, max(hist))
        else: 
            # Use KDE plot
            for ax in both_axes:
                kde = sns.kdeplot(
                    data=bridge_counts,
                    label=class_name if ax is ax_bottom else "",
                    color=color,
                    linewidth=3,
                    alpha=1,
                    ax=ax
                )
                
                # The curve just drawn is the last line on the axes
                line = kde.lines[-1]
                max_density = max(max_density, max(line.get_ydata()))
                
                # Second pass adds the translucent fill below the curve
                sns.kdeplot(
                    data=bridge_counts,
                    color=color,
                    alpha=0.2,
                    fill=True,
                    ax=ax
                )
    
    # Set y-limits for each subplot. When no density exceeds the break
    # point, give the top panel a small span above it so its axis is not
    # inverted.
    top_limit = max_density * 1.1
    if top_limit <= break_point:
        top_limit = break_point + 0.1 * max(abs(break_point), 1.0)
    ax_bottom.set_ylim(0, break_point)
    ax_top.set_ylim(break_point, top_limit)
    
    # Set x-axis limits and ticks
    for ax in both_axes:
        ax.set_xlim(min_count-0.5, max_count+0.5)
        ax.set_xticks(np.arange(min_count, max_count + 1, step))
        ax.tick_params(labelsize=16, colors='#2F2F2F')
        ax.grid(True, linestyle='--', alpha=0.2, color='#2F2F2F')
    
    # Hide x-axis labels on top plot
    ax_top.set_xticklabels([])
    
    # Add diagonal cut marks to indicate axis break
    d = .015 # size of the cut
    kwargs = dict(transform=ax_top.transAxes, color='k', clip_on=False)
    ax_top.plot((-d, +d), (-d, +d), **kwargs)
    
    kwargs.update(transform=ax_bottom.transAxes)
    ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
    
    # Set overall title and axis labels
    fig.suptitle(title, fontsize=26, fontweight='bold', y=0.95, color='#2F2F2F')
    ax_bottom.set_xlabel("Number of Bridges", fontsize=20, fontweight='bold', color='#2F2F2F')
    
    for ax in both_axes:
        ax.set_ylabel("Density", fontsize=20, fontweight='bold', color='#2F2F2F')
   
    # Add legend to bottom subplot
    legend = ax_bottom.legend(loc='best', fontsize=20, frameon=True,
                              facecolor='white', edgecolor='black', 
                              framealpha=0.9, title='Graph Types')
    legend.get_title().set_fontweight('bold')
    legend.get_title().set_fontsize(20)
    
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print descriptive statistics
    print("\nDistribution Statistics:")
    print("-" * 50)
    for class_name, counts in class_bridge_counts.items():
        print(f"\n{class_name}:")
        print(f"  Count: {len(counts)}")
        print(f"  Range: [{min(counts)}, {max(counts)}]")
        print(f"  Mean ± Std: {np.mean(counts):.2f} ± {np.std(counts):.2f}")
        print(f"  Median: {np.median(counts):.2f}")


if __name__ == "__main__":
    # Directory that holds every split, the rendered images and the CSV
    root_dir = "./bridge_count_regression"

    # Build (or load) the PyG graph datasets; the third item maps each
    # test difficulty name to its dataset
    train_dataset, val_dataset, test_datasets = generate_bridge_count_dataset(root_dir)

    # Render every graph to an image and write the metadata CSV
    image_dir, dataset_csv = generate_bridge_count_image_dataset(
        root_dir=root_dir, layout="spring", image_size=224, seed=42
    )

    # Bridge count distribution of the training set
    visualize_all_bridge_types(
        train_dataset,
        title="Bridge Distribution across Graph Types (Train)",
        save_path="train_bridge_distribution.pdf",
    )

    # Bridge count distribution of each test difficulty setting
    for difficulty in ('ID', 'Near-OOD', 'Far-OOD'):
        difficulty_dataset = test_datasets[difficulty]
        visualize_all_bridge_types(
            difficulty_dataset,
            title=f"Bridge Distribution ({difficulty})",
            save_path=f"test_{difficulty}_bridge_distribution.pdf",
        )
