import torch
from torch.utils.data import Dataset, ConcatDataset
import networkx as nx
from torch_geometric.utils import from_networkx, to_dense_adj
from torch_geometric.data import Data
import random
import math
from typing import List, Tuple, Optional, Dict, Any, Union
from tqdm import tqdm
from abc import ABC, abstractmethod


class BaseGraphDataset(ABC):
    """Base class for all graph datasets.

    Subclasses implement ``generate_graph``; ``compute_matrices`` turns a
    NetworkX graph into dense adjacency and connectivity matrices.

    Args:
        num_nodes: Side length of the dense matrices that are produced.
        add_self_loops: If True, the identity is added to the adjacency matrix.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_nodes: int, add_self_loops: bool = True, **kwargs):
        self.num_nodes = num_nodes
        self.add_self_loops = add_self_loops
        # Exponent applied to (I + A) when computing connectivity.
        self.mat_power = num_nodes

    @abstractmethod
    def generate_graph(self, **kwargs) -> nx.Graph:
        """Generate a single graph"""
        pass

    @staticmethod
    def _boolean_matrix_power(matrix: torch.Tensor, power: int) -> torch.Tensor:
        """Return the 0/1 support pattern of ``matrix ** power``.

        ``matrix`` must be square with non-negative entries and ``power`` a
        non-negative integer. Exponentiation by squaring is used, and every
        intermediate product is clamped back to {0, 1}. For non-negative
        matrices this yields exactly the non-zero pattern of the true power,
        but the entries can never overflow to ``inf`` (which would turn into
        ``nan`` once multiplied by zero and then compare as "not > 0").
        """
        dtype = matrix.dtype
        result = torch.eye(matrix.shape[-1], dtype=dtype)
        base = (matrix > 0).to(dtype)
        while power > 0:
            if power & 1:
                result = ((result @ base) > 0).to(dtype)
            power >>= 1
            if power:
                base = ((base @ base) > 0).to(dtype)
        return result

    def compute_matrices(self, graph: nx.Graph) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute adjacency and connectivity matrices from a NetworkX graph.

        Args:
            graph: Graph whose edges are converted via ``from_networkx``.

        Returns:
            ``(adj_matrix, connectivity_matrix)``, both of shape
            ``(num_nodes, num_nodes)``. ``connectivity_matrix[i, j]`` is 1.0
            when ``(I + A) ** mat_power`` has a positive entry at ``(i, j)``,
            i.e. when ``j`` can be reached from ``i`` in at most ``mat_power``
            steps, and 0.0 otherwise.
        """
        # Convert NetworkX graph to PyTorch Geometric data format
        data = from_networkx(graph)

        # Dense adjacency, padded to num_nodes; drop the leading batch dimension
        adj_matrix = to_dense_adj(
            data.edge_index, max_num_nodes=self.num_nodes
        ).squeeze(0)
        if self.add_self_loops:
            adj_matrix = adj_matrix + torch.eye(self.num_nodes)

        # Reachability through an overflow-safe boolean matrix power
        identity = torch.eye(self.num_nodes)
        reachability_matrix = self._boolean_matrix_power(
            identity + adj_matrix, self.mat_power
        )
        connectivity_matrix = (reachability_matrix > 0).float()

        return adj_matrix, connectivity_matrix


class ErdosRenyiGenerator(BaseGraphDataset):
    """Generator for Erdős-Rényi graphs.

    Args:
        num_nodes: Number of nodes in each generated graph.
        p: Edge probability used when ``sample_p`` is False.
        sample_p: If True, the edge probability is drawn uniformly from
            ``p_range`` once per call to ``generate_graph``.
        p_range: ``(low, high)`` bounds for the sampled edge probability.
        restrict_diam: If given, graphs are re-drawn until the largest
            diameter among their connected components does not exceed it.

    Raises:
        ValueError: If ``restrict_diam`` is negative (no graph could satisfy it).
    """

    def __init__(
        self,
        num_nodes: int,
        p: float = 0.1,
        sample_p: bool = False,
        p_range: Tuple[float, float] = (0.0, 1.0),
        restrict_diam: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(num_nodes, **kwargs)
        if restrict_diam is not None and restrict_diam < 0:
            # A diameter is never negative, so the rejection loop could not end.
            raise ValueError("restrict_diam must be non-negative")
        self.p = p
        self.sample_p = sample_p
        self.p_range = p_range
        self.restrict_diam = restrict_diam

    @staticmethod
    def _max_component_diameter(graph: nx.Graph) -> int:
        """Return the largest diameter over the connected components of ``graph``.

        A connected graph has a single component, so this equals its diameter.
        A graph without nodes yields 0.
        """
        return max(
            (
                nx.diameter(graph.subgraph(component))
                for component in nx.connected_components(graph)
            ),
            default=0,
        )

    def generate_graph(self, **kwargs) -> nx.Graph:
        """Draw one Erdős-Rényi graph, honoring ``restrict_diam`` if it is set."""
        p_val = (
            random.uniform(self.p_range[0], self.p_range[1])
            if self.sample_p
            else self.p
        )
        G = nx.erdos_renyi_graph(n=self.num_nodes, p=p_val)

        if self.restrict_diam is not None:
            # Rejection sampling with the same p_val until the bound is met.
            while self._max_component_diameter(G) > self.restrict_diam:
                G = nx.erdos_renyi_graph(n=self.num_nodes, p=p_val)

        return G


class ErdosRenyiTwoGraphsGenerator(BaseGraphDataset):
    """Generator for two separate Erdős-Rényi subgraphs with optional connection.

    Each subgraph has ``num_nodes // 2`` nodes and is re-drawn until its
    largest connected component has at least ``min_component_size`` nodes.

    Args:
        num_nodes: Total number of nodes in the combined graph.
        p: Edge probability used when ``sample_p`` is False.
        sample_p: If True, the edge probability is drawn uniformly from
            ``p_range`` once per call to ``generate_graph``.
        p_range: ``(low, high)`` bounds for the sampled edge probability.
        connect_prob: Probability of adding one edge between the largest
            components of the two subgraphs.
        min_component_size_frac: Fraction of ``num_nodes`` that the largest
            component of each subgraph must reach.

    Raises:
        ValueError: If the required component size is larger than a subgraph,
            which would make generation loop forever.
    """

    def __init__(
        self,
        num_nodes: int,
        p: float = 0.1,
        sample_p: bool = False,
        p_range: Tuple[float, float] = (0.0, 1.0),
        connect_prob: float = 0.5,
        min_component_size_frac: float = 0.25,
        **kwargs,
    ):
        super().__init__(num_nodes, **kwargs)
        self.p = p
        self.sample_p = sample_p
        self.p_range = p_range
        self.connect_prob = connect_prob
        self.min_component_size = int(num_nodes * min_component_size_frac)
        if self.min_component_size > num_nodes // 2:
            raise ValueError(
                "min_component_size_frac requires components larger than "
                "the num_nodes // 2 nodes available in each subgraph"
            )

    def _sample_half(self, p_val: float):
        """Draw one subgraph and return it together with its largest component."""
        half = self.num_nodes // 2
        while True:
            graph = nx.erdos_renyi_graph(n=half, p=p_val)
            # default=set() covers the node-less subgraph (num_nodes < 2)
            largest = max(nx.connected_components(graph), key=len, default=set())
            if len(largest) >= self.min_component_size:
                return graph, largest

    def generate_graph(self, **kwargs) -> nx.Graph:
        """Build the combined graph from two independently drawn subgraphs."""
        p_val = (
            random.uniform(self.p_range[0], self.p_range[1])
            if self.sample_p
            else self.p
        )
        half = self.num_nodes // 2

        # Generate two subgraphs with minimum component size requirement
        G1, largest_comp1 = self._sample_half(p_val)
        G2, largest_comp2 = self._sample_half(p_val)

        # Combine graphs; the second subgraph's indices are shifted by `half`
        G = nx.Graph()
        G.add_nodes_from(range(self.num_nodes))
        G.add_edges_from(G1.edges())
        G.add_edges_from((u + half, v + half) for u, v in G2.edges())

        # Optionally connect the two subgraphs (skipped if a half has no nodes)
        if random.random() < self.connect_prob and largest_comp1 and largest_comp2:
            node1 = random.choice(list(largest_comp1))
            node2 = random.choice(list(largest_comp2)) + half
            G.add_edge(node1, node2)

        return G


class ErdosRenyiMediumVariantGenerator(BaseGraphDataset):
    """Generator for Erdős-Rényi graphs with at least two large components.

    Graphs are re-drawn (re-sampling the edge probability each time when
    ``sample_p`` is True) until at least two connected components have
    ``min_component_size`` nodes or more.

    Args:
        num_nodes: Number of nodes in each generated graph.
        p: Edge probability used when ``sample_p`` is False.
        sample_p: If True, the edge probability is drawn uniformly from
            ``p_range`` for every attempt.
        p_range: ``(low, high)`` bounds for the sampled edge probability.
        connect_prob: Probability of adding one edge between the first two
            large components found.
        min_component_size_frac: Fraction of ``num_nodes`` a component must
            reach to count as large.

    Raises:
        ValueError: If two components of the required size cannot fit into
            ``num_nodes`` nodes, which would make generation loop forever.
    """

    def __init__(
        self,
        num_nodes: int,
        p: float = 0.1,
        sample_p: bool = False,
        p_range: Tuple[float, float] = (0.0, 1.0),
        connect_prob: float = 0.0,
        min_component_size_frac: float = 0.33,
        **kwargs,
    ):
        super().__init__(num_nodes, **kwargs)
        self.p = p
        self.sample_p = sample_p
        self.p_range = p_range
        self.connect_prob = connect_prob
        self.min_component_size = int(num_nodes * min_component_size_frac)
        if 2 * self.min_component_size > num_nodes:
            raise ValueError(
                "min_component_size_frac is too large: two components of "
                f"{self.min_component_size} nodes do not fit in {num_nodes} nodes"
            )

    def generate_graph(self, **kwargs) -> nx.Graph:
        """Draw graphs until one has two large components, then return it."""
        while True:
            p_val = (
                random.uniform(self.p_range[0], self.p_range[1])
                if self.sample_p
                else self.p
            )
            G = nx.erdos_renyi_graph(n=self.num_nodes, p=p_val)
            large_components = [
                comp
                for comp in nx.connected_components(G)
                if len(comp) >= self.min_component_size
            ]
            if len(large_components) < 2:
                continue

            # Optionally connect the first two large components
            if random.random() < self.connect_prob:
                node1 = random.choice(list(large_components[0]))
                node2 = random.choice(list(large_components[1]))
                G.add_edge(node1, node2)
            return G


class ErdosRenyiHardVariantGenerator(BaseGraphDataset):
    """Generator for two separate Erdős-Rényi components with optional single connection.

    Args:
        num_nodes: Total number of nodes; each component gets ``num_nodes // 2``
            and a leftover node (odd ``num_nodes``) stays isolated.
        p: Base edge probability; it is doubled before being stored in
            ``self.p`` and used when ``sample_p`` is False.
        sample_p: If True, the edge probability is drawn uniformly from
            ``p_range`` (not doubled) once per call to ``generate_graph``.
        p_range: ``(low, high)`` bounds for the sampled edge probability.
        connect_prob: Probability of joining the two components by one edge.
    """

    def __init__(
        self,
        num_nodes: int,
        p: float = 0.1,
        sample_p: bool = False,
        p_range: Tuple[float, float] = (0.0, 1.0),
        connect_prob: float = 0.5,
        **kwargs,
    ):
        super().__init__(num_nodes, **kwargs)
        self.p = p * 2  # Increase p to ensure connectivity in smaller components
        self.sample_p = sample_p
        self.p_range = p_range
        self.connect_prob = connect_prob

    def generate_graph(self, **kwargs) -> nx.Graph:
        """Build two Erdős-Rényi components and optionally bridge them."""
        p_val = (
            random.uniform(self.p_range[0], self.p_range[1])
            if self.sample_p
            else self.p
        )

        # Two separate Erdős-Rényi components of size n // 2 each
        component_size = self.num_nodes // 2
        G1 = nx.erdos_renyi_graph(n=component_size, p=p_val)
        G2 = nx.erdos_renyi_graph(n=component_size, p=p_val)

        # Combine graphs; a leftover node (odd num_nodes) stays isolated
        G = nx.Graph()
        G.add_nodes_from(range(self.num_nodes))
        G.add_edges_from(G1.edges())
        G.add_edges_from(
            (u + component_size, v + component_size) for u, v in G2.edges()
        )

        # With probability connect_prob, connect the two components with a
        # single edge. Empty components (num_nodes < 2) have nothing to connect.
        if random.random() < self.connect_prob and component_size > 0:
            node1 = random.randint(0, component_size - 1)
            node2 = random.randint(component_size, 2 * component_size - 1)
            G.add_edge(node1, node2)

        return G


class TreeForestGenerator(BaseGraphDataset):
    """Generator for random forest graphs (collection of trees).

    Args:
        num_nodes: Total number of nodes in the graph.
        min_tree_size: Smallest number of nodes a tree may have (at least 1).
        max_tree_size: Largest number of nodes a tree may have; defaults to
            ``num_nodes // 2``.

    Raises:
        ValueError: If ``min_tree_size`` is smaller than 1; a size of 0 would
            make the generation loop never terminate.
    """

    def __init__(
        self,
        num_nodes: int,
        min_tree_size: int = 3,
        max_tree_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(num_nodes, **kwargs)
        if min_tree_size < 1:
            raise ValueError("min_tree_size must be at least 1")
        self.min_tree_size = min_tree_size
        self.max_tree_size = (
            max_tree_size if max_tree_size is not None else num_nodes // 2
        )

    def generate_graph(self, **kwargs) -> nx.Graph:
        """Partition nodes into random trees; nodes left over stay isolated."""
        G = nx.Graph()
        G.add_nodes_from(range(self.num_nodes))

        remaining_nodes = list(range(self.num_nodes))

        while len(remaining_nodes) >= self.min_tree_size:
            # Determine tree size
            max_possible_size = min(self.max_tree_size, len(remaining_nodes))
            if max_possible_size < self.min_tree_size:
                break

            tree_size = random.randint(self.min_tree_size, max_possible_size)

            # Select nodes for this tree; a set keeps the filtering linear
            tree_nodes = random.sample(remaining_nodes, tree_size)
            chosen = set(tree_nodes)
            remaining_nodes = [n for n in remaining_nodes if n not in chosen]

            # Grow a random tree: each new node attaches to a uniformly
            # chosen node that is already part of the tree.
            tree = nx.Graph()
            tree.add_node(tree_nodes[0])
            for i in range(1, tree_size):
                existing_node = random.choice(tree_nodes[:i])
                tree.add_edge(tree_nodes[i], existing_node)

            # Add edges to main graph
            G.add_edges_from(tree.edges())

        return G


class StarForestGenerator(BaseGraphDataset):
    """Generator for forest of star graphs (multiple stars of varying sizes).

    Args:
        num_nodes: Total number of nodes in the graph.
        min_star_size: Smallest number of nodes a star may have (at least 1).
        max_star_size: Largest number of nodes a star may have; defaults to
            ``num_nodes // 2``.

    Raises:
        ValueError: If ``min_star_size`` is smaller than 1; a size of 0 would
            make the generation loop never terminate.
    """

    def __init__(
        self,
        num_nodes: int,
        min_star_size: int = 3,
        max_star_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(num_nodes, **kwargs)
        if min_star_size < 1:
            raise ValueError("min_star_size must be at least 1")
        self.min_star_size = min_star_size
        self.max_star_size = (
            max_star_size if max_star_size is not None else num_nodes // 2
        )

    def generate_graph(self, **kwargs) -> nx.Graph:
        """Partition nodes into random stars; nodes left over stay isolated."""
        G = nx.Graph()
        G.add_nodes_from(range(self.num_nodes))

        remaining_nodes = list(range(self.num_nodes))

        while len(remaining_nodes) >= self.min_star_size:
            # Determine star size
            max_possible_size = min(self.max_star_size, len(remaining_nodes))
            if max_possible_size < self.min_star_size:
                break

            star_size = random.randint(self.min_star_size, max_possible_size)

            # Select nodes for this star; a set keeps the filtering linear
            star_nodes = random.sample(remaining_nodes, star_size)
            chosen = set(star_nodes)
            remaining_nodes = [n for n in remaining_nodes if n not in chosen]

            # First sampled node is the center; all others become its leaves
            center_node, *leaf_nodes = star_nodes
            G.add_edges_from((center_node, leaf) for leaf in leaf_nodes)

        return G

    def compute_matrices(self, graph: nx.Graph) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute adjacency and connectivity matrices with random node permutation.

        The same permutation is applied to rows and columns of both matrices.
        """
        # Get matrices from parent class
        adj_matrix, connectivity_matrix = super().compute_matrices(graph)

        # Apply random permutation to node indices
        perm = torch.randperm(self.num_nodes)
        adj_matrix = adj_matrix[perm][:, perm]
        connectivity_matrix = connectivity_matrix[perm][:, perm]

        return adj_matrix, connectivity_matrix


class TwoCliquesGenerator(BaseGraphDataset):
    """Generator for two separate cliques of similar sizes.

    Args:
        num_nodes: Total number of nodes, split between the two cliques.
        connect_prob: Probability of adding a single edge between the cliques.
        size_variation: Fraction of ``num_nodes // 2`` by which the first
            clique's size may deviate from an even split.
    """

    def __init__(
        self,
        num_nodes: int,
        connect_prob: float = 0.0,
        size_variation: float = 0.1,
        **kwargs,
    ):
        super().__init__(num_nodes, **kwargs)
        self.connect_prob = connect_prob
        self.size_variation = size_variation

    def generate_graph(self, **kwargs) -> nx.Graph:
        """Return two cliques covering all nodes, optionally joined by one edge.

        With fewer than 4 nodes an edgeless graph is returned instead.
        """
        total = self.num_nodes
        if total < 4:
            # Two cliques of at least 2 nodes each do not fit
            edgeless = nx.Graph()
            edgeless.add_nodes_from(range(total))
            return edgeless

        # Even split, perturbed by a random offset
        half = total // 2
        spread = int(half * self.size_variation)
        first_size = half + random.randint(-spread, spread)
        second_size = total - first_size

        # Keep at least 2 nodes in each clique
        if first_size < 2:
            first_size, second_size = 2, total - 2
        elif second_size < 2:
            first_size, second_size = total - 2, 2

        # Second clique is relabeled so its nodes follow the first clique's
        first_clique = nx.complete_graph(first_size)
        shifted_labels = {node: node + first_size for node in range(second_size)}
        second_clique = nx.relabel_nodes(
            nx.complete_graph(second_size), shifted_labels
        )

        G = nx.union(first_clique, second_clique)
        G.add_nodes_from(range(total))

        # Optional single bridge between the cliques
        if random.random() < self.connect_prob:
            left = random.randint(0, first_size - 1)
            right = random.randint(first_size, total - 1)
            G.add_edge(left, right)

        return G

    def compute_matrices(self, graph: nx.Graph) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the parent-class matrices and shuffle the node order.

        One random permutation is applied to the rows and columns of both the
        adjacency and the connectivity matrix.
        """
        base_adj, base_conn = super().compute_matrices(graph)
        order = torch.randperm(self.num_nodes)
        shuffled_adj = base_adj[order][:, order]
        shuffled_conn = base_conn[order][:, order]
        return shuffled_adj, shuffled_conn


class SBMGenerator(BaseGraphDataset):
    """Generator for Stochastic Block Model graphs.

    Args:
        num_nodes: Target number of nodes; every community receives
            ``num_nodes // num_communities`` nodes, so the generated graph
            has fewer nodes when the division leaves a remainder.
        p_intra: Edge probability between nodes of the same community.
        p_inter: Edge probability between nodes of different communities.
        num_communities: Number of equally sized communities.
    """

    def __init__(
        self,
        num_nodes: int,
        p_intra: float = 0.1,
        p_inter: float = 0.01,
        num_communities: int = 2,
        **kwargs,
    ):
        super().__init__(num_nodes, **kwargs)
        self.p_intra = p_intra
        self.p_inter = p_inter
        self.num_communities = num_communities

    def generate_graph(self, **kwargs) -> nx.Graph:
        """Draw one stochastic block model graph with equal community sizes."""
        count = self.num_communities
        per_community = self.num_nodes // count
        block_sizes = [per_community] * count

        # p_intra on the diagonal, p_inter everywhere else
        edge_probs = []
        for row in range(count):
            edge_probs.append(
                [self.p_intra if row == col else self.p_inter for col in range(count)]
            )

        return nx.stochastic_block_model(block_sizes, edge_probs)


class CavemanGraphGenerator(BaseGraphDataset):
    """Generator for caveman graphs (collection of cliques with optional connections).

    Args:
        num_nodes: Total number of nodes in the returned graph.
        k: Requested clique size; values below 5 are raised to 5, and the size
            is reduced when three cliques of that size would not fit.
        connect_prob: Probability of using ``nx.connected_caveman_graph``
            instead of ``nx.caveman_graph``.
    """

    def __init__(
        self,
        num_nodes: int,
        k: int = 5,
        connect_prob: float = 0.5,
        **kwargs,
    ):
        super().__init__(num_nodes, **kwargs)
        self.k = max(k, 5)  # Ensure clique size is at least 5
        self.connect_prob = connect_prob

        # Calculate number of cliques l = num_nodes // k, but at least 3
        self.l = max(num_nodes // self.k, 3)

        # If we need at least 3 cliques but don't have enough nodes, reduce k
        if self.l * self.k > num_nodes:
            self.k = max(num_nodes // 3, 3)  # Ensure we can fit at least 3 cliques
            self.l = max(num_nodes // self.k, 3)

        # Nodes not covered by any clique. For num_nodes < 9 the minimum layout
        # (3 cliques of 3) needs more nodes than exist, so clamp at zero.
        self.isolated_nodes = max(num_nodes - (self.l * self.k), 0)

    def generate_graph(self, **kwargs) -> nx.Graph:
        """Build a caveman graph on exactly ``num_nodes`` nodes.

        Caveman nodes beyond ``num_nodes`` (only possible for
        ``num_nodes < 9``) and their edges are dropped, so the result never
        has more nodes than the matrices produced by ``compute_matrices``.
        """
        # With probability connect_prob, use connected_caveman_graph
        # Otherwise, use regular caveman_graph (disconnected cliques)
        if random.random() < self.connect_prob:
            G = nx.connected_caveman_graph(self.l, self.k)
        else:
            G = nx.caveman_graph(self.l, self.k)

        # Create a new graph with exactly our desired number of nodes
        result_graph = nx.Graph()
        result_graph.add_nodes_from(range(self.num_nodes))

        # Map the caveman graph nodes to our numbering scheme, never going
        # beyond the nodes that actually exist in the result graph
        caveman_nodes = list(G.nodes())
        used_nodes = min(len(caveman_nodes), self.l * self.k, self.num_nodes)
        node_mapping = {node: index for index, node in enumerate(caveman_nodes[:used_nodes])}

        # Add edges whose endpoints are both mapped
        result_graph.add_edges_from(
            (node_mapping[u], node_mapping[v])
            for u, v in G.edges()
            if u in node_mapping and v in node_mapping
        )

        # The remaining nodes (self.isolated_nodes) remain isolated
        return result_graph

    def compute_matrices(self, graph: nx.Graph) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute adjacency and connectivity matrices with random node permutation.

        The same permutation is applied to rows and columns of both matrices.
        """
        # Get matrices from parent class
        adj_matrix, connectivity_matrix = super().compute_matrices(graph)

        # Apply random permutation to node indices
        perm = torch.randperm(self.num_nodes)
        adj_matrix = adj_matrix[perm][:, perm]
        connectivity_matrix = connectivity_matrix[perm][:, perm]

        return adj_matrix, connectivity_matrix


class OneCircleGenerator(BaseGraphDataset):
    """
    Generator for graphs with all nodes arranged in a single circle.

    Args:
        num_nodes: Total number of nodes in the graph
        max_num_nodes: Maximum number of nodes for padding (default: same as num_nodes)

    Raises:
        ValueError: If ``num_nodes`` is greater than ``max_num_nodes``.

    Example:
        For 8 nodes: 0-1-2-3-4-5-6-7-0 (forms a complete circle)
    """

    def __init__(
        self,
        num_nodes: int,
        max_num_nodes: Optional[int] = None,
        **kwargs,
    ):
        self.max_num_nodes = max_num_nodes if max_num_nodes is not None else num_nodes
        if num_nodes > self.max_num_nodes:
            raise ValueError("num_nodes cannot be greater than max_num_nodes")

        # The base class is sized by the padded node count
        super().__init__(self.max_num_nodes, **kwargs)
        self.active_nodes = num_nodes

        # Precompute base matrices
        self._precompute_matrices()

    def _precompute_matrices(self):
        """Precompute the unpadded adjacency and connectivity matrices."""
        size = self.active_nodes
        adj = torch.zeros((size, size))

        # Create circle: each node connects to next node, last connects to first.
        # A single node has no distinct neighbor; linking it to "the next node"
        # would write a self-loop regardless of add_self_loops, so skip it.
        if size >= 2:
            for i in range(size):
                next_node = (i + 1) % size
                adj[i, next_node] = adj[next_node, i] = 1

        if self.add_self_loops:
            adj += torch.eye(size)

        # All nodes in a circle are connected to each other (diagonal included)
        conn = torch.ones((size, size))

        self.base_adj = adj
        self.base_conn = conn

    def generate_graph(self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return randomly permuted, padded ``(adjacency, connectivity)`` matrices.

        The circle itself is fixed; only the node permutation is random.
        Padded nodes are connected only to themselves.
        """
        size = self.max_num_nodes
        active = self.active_nodes

        # Create padded matrices with the base matrices in the top-left corner
        adj_padded = torch.zeros((size, size))
        conn_padded = torch.zeros((size, size))
        adj_padded[:active, :active] = self.base_adj
        conn_padded[:active, :active] = self.base_conn

        # Padded nodes: self-loop only if requested, always self-connected.
        # torch.diagonal returns a view, so the assignment writes in place.
        if self.add_self_loops:
            torch.diagonal(adj_padded)[active:] = 1.0
        torch.diagonal(conn_padded)[active:] = 1.0

        # Random permutation
        perm = torch.randperm(size)
        adj = adj_padded[perm][:, perm]
        conn = conn_padded[perm][:, perm]

        return adj, conn


class TwoDegree3ChainsGenerator(BaseGraphDataset):
    """
    Generator for graphs based on TwoChains but with additional connections.
    Each node connects to its immediate neighbor AND its neighbor's neighbor
    within the same chain.

    Args:
        num_nodes: Total number of nodes in the graph
        k: Length of each chain. If None, calculated automatically based on add_isolated_nodes:
           - If add_isolated_nodes=False: k = num_nodes // 2
           - If add_isolated_nodes=True: k = (num_nodes - 2) // 2 (reserves 2 nodes as isolated)
        max_num_nodes: Maximum number of nodes for padding (default: same as num_nodes)
        add_isolated_nodes: If True, reserves 2 nodes as explicitly isolated (default: False).
            Any node not covered by the two chains is isolated in either case.

    Raises:
        ValueError: If ``num_nodes`` exceeds ``max_num_nodes``, if the chain
            length is negative, or if the chains (plus reserved isolated
            nodes) do not fit into ``num_nodes``.

    Example:
        For 8 nodes with k=4 each chain:
        Chain 1: 0-1-2-3 becomes 0-(1,2), 1-(0,2,3), 2-(0,1,3), 3-(1,2)
        Chain 2: 4-5-6-7 becomes 4-(5,6), 5-(4,6,7), 6-(4,5,7), 7-(5,6)
    """

    def __init__(
        self,
        num_nodes: int,
        k: Optional[int] = None,
        max_num_nodes: Optional[int] = None,
        add_isolated_nodes: bool = False,
        **kwargs,
    ):
        self.max_num_nodes = max_num_nodes if max_num_nodes is not None else num_nodes
        if num_nodes > self.max_num_nodes:
            raise ValueError("num_nodes cannot be greater than max_num_nodes")

        self.add_isolated_nodes = add_isolated_nodes

        # Determine chain length based on add_isolated_nodes setting
        if k is None:
            reserved = 2 if add_isolated_nodes else 0
            self.k = (num_nodes - reserved) // 2
        else:
            self.k = k

        # A negative length would pass the size check below and then corrupt
        # the matrices through negative slicing, so reject it explicitly.
        if self.k < 0:
            raise ValueError(f"Chain length must be non-negative, got {self.k}")

        # Validate chain length
        required_nodes = 2 * self.k + (2 if add_isolated_nodes else 0)
        if required_nodes > num_nodes:
            raise ValueError(
                f"Cannot fit 2 chains of length {self.k} with {'2 isolated nodes' if add_isolated_nodes else 'no isolated nodes'} "
                f"in {num_nodes} nodes. Required: {required_nodes}, Available: {num_nodes}"
            )

        # The base class is sized by the padded node count
        super().__init__(self.max_num_nodes, **kwargs)
        self.active_nodes = num_nodes

        # Precompute base matrices
        self._precompute_matrices()

    def _precompute_matrices(self):
        """Precompute the unpadded adjacency and connectivity matrices."""
        size = self.active_nodes
        k = self.k
        adj = torch.zeros((size, size))

        # Chains occupy [0, k) and [k, 2k); inside a chain every node links to
        # the nodes one and two positions ahead, when those exist.
        for start in (0, k):
            stop = start + k
            for i in range(start, stop):
                for hop in (1, 2):
                    j = i + hop
                    if j < stop:
                        adj[i, j] = adj[j, i] = 1

        if self.add_self_loops:
            adj += torch.eye(size)

        # Connectivity matrix: block-diagonal for each chain
        conn = torch.zeros((size, size))
        conn[0:k, 0:k] = 1  # First chain
        conn[k : 2 * k, k : 2 * k] = 1  # Second chain

        # Nodes outside both chains are isolated: connected only to themselves.
        # Together with the blocks this fills the whole diagonal, so no extra
        # self-loop handling is needed for the connectivity matrix.
        for i in range(2 * k, size):
            conn[i, i] = 1

        self.base_adj = adj
        self.base_conn = conn

    def generate_graph(self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return randomly permuted, padded ``(adjacency, connectivity)`` matrices.

        The structure is fixed; only the node permutation is random. Padded
        nodes are connected only to themselves.
        """
        size = self.max_num_nodes
        active = self.active_nodes

        # Create padded matrices with the base matrices in the top-left corner
        adj_padded = torch.zeros((size, size))
        conn_padded = torch.zeros((size, size))
        adj_padded[:active, :active] = self.base_adj
        conn_padded[:active, :active] = self.base_conn

        # Handle padded nodes
        for i in range(active, size):
            if self.add_self_loops:
                adj_padded[i, i] = 1.0
            conn_padded[i, i] = 1.0

        # Random permutation
        perm = torch.randperm(size)
        adj = adj_padded[perm][:, perm]
        conn = conn_padded[perm][:, perm]

        return adj, conn


class TwoChainsGenerator(BaseGraphDataset):
    """
    Generator for graphs with two disjoint k-chains plus optional isolated nodes.

    Args:
        num_nodes: Total number of nodes in the graph
        k: Length of each chain. If None, calculated automatically based on add_isolated_nodes:
           - If add_isolated_nodes=False: k = num_nodes // 2
           - If add_isolated_nodes=True: k = (num_nodes - 2) // 2 (reserves 2 nodes as isolated)
        max_num_nodes: Maximum number of nodes for padding (default: same as num_nodes)
        add_isolated_nodes: If True, reserves 2 nodes as explicitly isolated (default: False).
            Any node not covered by the two chains is isolated in either case.

    Raises:
        ValueError: If ``num_nodes`` exceeds ``max_num_nodes``, if the chain
            length is negative, or if the chains (plus reserved isolated
            nodes) do not fit into ``num_nodes``.

    Examples:
        # 56 nodes, add_isolated_nodes=False: 2 chains of length 28 each
        # 56 nodes, add_isolated_nodes=True: 2 chains of length 27 each + 2 isolated nodes
    """

    def __init__(
        self,
        num_nodes: int,
        k: Optional[int] = None,
        max_num_nodes: Optional[int] = None,
        add_isolated_nodes: bool = False,
        **kwargs,
    ):
        self.max_num_nodes = max_num_nodes if max_num_nodes is not None else num_nodes
        if num_nodes > self.max_num_nodes:
            raise ValueError("num_nodes cannot be greater than max_num_nodes")

        self.add_isolated_nodes = add_isolated_nodes

        # Determine chain length based on add_isolated_nodes setting
        if k is None:
            reserved = 2 if add_isolated_nodes else 0
            self.k = (num_nodes - reserved) // 2
        else:
            self.k = k

        # A negative length would pass the size check below and then corrupt
        # the matrices through negative slicing, so reject it explicitly.
        if self.k < 0:
            raise ValueError(f"Chain length must be non-negative, got {self.k}")

        # Validate chain length
        required_nodes = 2 * self.k + (2 if add_isolated_nodes else 0)
        if required_nodes > num_nodes:
            raise ValueError(
                f"Cannot fit 2 chains of length {self.k} with {'2 isolated nodes' if add_isolated_nodes else 'no isolated nodes'} "
                f"in {num_nodes} nodes. Required: {required_nodes}, Available: {num_nodes}"
            )

        # The base class is sized by the padded node count
        super().__init__(self.max_num_nodes, **kwargs)
        self.active_nodes = num_nodes

        # Precompute base matrices
        self._precompute_matrices()

    def _precompute_matrices(self):
        """Precompute the unpadded adjacency and connectivity matrices."""
        size = self.active_nodes
        k = self.k
        adj = torch.zeros((size, size))

        # Chains occupy [0, k) and [k, 2k); link each node to its successor
        for start in (0, k):
            for i in range(start, start + k - 1):
                adj[i, i + 1] = adj[i + 1, i] = 1

        if self.add_self_loops:
            adj += torch.eye(size)

        # Connectivity matrix: block-diagonal for each chain
        conn = torch.zeros((size, size))
        conn[0:k, 0:k] = 1  # First chain
        conn[k : 2 * k, k : 2 * k] = 1  # Second chain

        # Nodes outside both chains are isolated: connected only to themselves.
        # Together with the blocks this fills the whole diagonal, so no extra
        # self-loop handling is needed for the connectivity matrix.
        for i in range(2 * k, size):
            conn[i, i] = 1

        self.base_adj = adj
        self.base_conn = conn

    def generate_graph(self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return randomly permuted, padded ``(adjacency, connectivity)`` matrices.

        The structure is fixed; only the node permutation is random. Padded
        nodes are connected only to themselves.
        """
        size = self.max_num_nodes
        active = self.active_nodes

        # Create padded matrices with the base matrices in the top-left corner
        adj_padded = torch.zeros((size, size))
        conn_padded = torch.zeros((size, size))
        adj_padded[:active, :active] = self.base_adj
        conn_padded[:active, :active] = self.base_conn

        # Handle padded nodes
        for i in range(active, size):
            if self.add_self_loops:
                adj_padded[i, i] = 1.0
            conn_padded[i, i] = 1.0

        # Random permutation
        perm = torch.randperm(size)
        adj = adj_padded[perm][:, perm]
        conn = conn_padded[perm][:, perm]

        return adj, conn


class TwoVariableChainsGenerator(BaseGraphDataset):
    """
    Two disjoint path graphs whose lengths are re-drawn for every sample.

    Each ``generate_graph`` call splits the ``num_nodes`` active nodes into a
    first chain of ``k`` nodes and a second chain of ``num_nodes - k`` nodes,
    with ``k`` drawn uniformly from ``[1, num_nodes - 1]``. The two chains
    together always use every active node.

    Args:
        num_nodes: Number of active nodes shared by the two chains (at least 2)
        max_num_nodes: Size the returned matrices are padded to
            (default: same as num_nodes)
        **kwargs: Forwarded to ``BaseGraphDataset.__init__`` (e.g. ``add_self_loops``)

    Raises:
        ValueError: If ``num_nodes`` exceeds ``max_num_nodes`` or is below 2.

    Examples:
        # 8 nodes: chains of length 3 and 5, or 2 and 6, or 1 and 7, etc.
    """

    def __init__(
        self,
        num_nodes: int,
        max_num_nodes: Optional[int] = None,
        **kwargs,
    ):
        padded_size = num_nodes if max_num_nodes is None else max_num_nodes
        self.max_num_nodes = padded_size

        if num_nodes > padded_size:
            raise ValueError("num_nodes cannot be greater than max_num_nodes")
        if num_nodes < 2:
            raise ValueError("num_nodes must be at least 2 for two chains")

        # The base class is sized by the padded dimension, not the active one.
        super().__init__(padded_size, **kwargs)
        self.active_nodes = num_nodes
        # Nothing is precomputed: the chain lengths change on every call.

    def generate_graph(self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw a split point and return the padded, permuted chain matrices.

        Args:
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(adjacency, connectivity)``, each of shape
            ``(max_num_nodes, max_num_nodes)``, after one shared random
            permutation of the node labels.
        """
        n_active, n_total = self.active_nodes, self.max_num_nodes

        # Nodes [0, split) form the first chain, [split, n_active) the second.
        split = random.randint(1, n_active - 1)

        # Every consecutive pair (i, i + 1) is an edge, except the one pair
        # that would bridge the two chains.
        left = torch.tensor(
            [i for i in range(n_active - 1) if i != split - 1],
            dtype=torch.long,
        )
        chain_adj = torch.zeros((n_active, n_active))
        chain_adj[left, left + 1] = 1
        chain_adj[left + 1, left] = 1
        if self.add_self_loops:
            chain_adj += torch.eye(n_active)

        # Two nodes are connected exactly when they sit in the same chain,
        # which yields the block-diagonal matrix (diagonal included).
        in_second = torch.arange(n_active) >= split
        chain_conn = (in_second.unsqueeze(0) == in_second.unsqueeze(1)).float()

        # Paste the active block into the top-left corner of the padded canvas.
        adj_canvas = torch.zeros((n_total, n_total))
        conn_canvas = torch.zeros((n_total, n_total))
        adj_canvas[:n_active, :n_active] = chain_adj
        conn_canvas[:n_active, :n_active] = chain_conn

        # Padding nodes are isolated: always self-connected, and given a
        # self-loop in the adjacency only when self-loops are enabled.
        conn_canvas.diagonal()[n_active:] = 1.0
        if self.add_self_loops:
            adj_canvas.diagonal()[n_active:] = 1.0

        # One shared permutation keeps adjacency and connectivity aligned.
        order = torch.randperm(n_total)
        return adj_canvas[order][:, order], conn_canvas[order][:, order]


class TwoTreesGenerator(BaseGraphDataset):
    """
    Generator for graphs with two disjoint k-node trees plus optional isolated nodes.

    Nodes ``0..k-1`` form the first tree and nodes ``k..2k-1`` the second; every
    other active node is isolated. Fresh random trees are drawn on every
    ``generate_graph`` call and the node labels are randomly permuted.

    Args:
        num_nodes: Number of active (non-padding) nodes in the graph
        k: Size of each tree. If None, calculated automatically based on add_isolated_nodes:
           - If add_isolated_nodes=False: k = num_nodes // 2 (uses all nodes for trees;
             one node is left isolated when num_nodes is odd)
           - If add_isolated_nodes=True: k = (num_nodes - 2) // 2 (reserves 2 nodes as isolated)
        max_num_nodes: Maximum number of nodes for padding (default: same as num_nodes)
        add_isolated_nodes: If True, at least 2 nodes are reserved as isolated. If False,
                           no nodes are reserved (default: False)
        **kwargs: Forwarded to ``BaseGraphDataset.__init__`` (e.g. ``add_self_loops``)

    Raises:
        ValueError: If ``num_nodes`` exceeds ``max_num_nodes``, if the tree size is
            negative, or if the trees plus reserved nodes do not fit in ``num_nodes``.

    Examples:
        # 56 nodes, add_isolated_nodes=False: 2 trees of size 28 each
        # 56 nodes, add_isolated_nodes=True: 2 trees of size 27 each + 2 isolated nodes
    """

    def __init__(
        self,
        num_nodes: int,
        k: Optional[int] = None,
        max_num_nodes: Optional[int] = None,
        add_isolated_nodes: bool = False,
        **kwargs,
    ):
        self.max_num_nodes = max_num_nodes if max_num_nodes is not None else num_nodes
        if num_nodes > self.max_num_nodes:
            raise ValueError("num_nodes cannot be greater than max_num_nodes")

        self.add_isolated_nodes = add_isolated_nodes
        reserved = 2 if add_isolated_nodes else 0

        # Default tree size: split whatever is left after the reserved nodes.
        self.k = (num_nodes - reserved) // 2 if k is None else k

        # A negative size would slip through the capacity check below and then
        # act as a from-the-end slice bound when the connectivity blocks are
        # written, producing wrong labels (or an IndexError for tiny graphs).
        if self.k < 0:
            raise ValueError(
                f"Tree size k must be non-negative, got {self.k} "
                f"(num_nodes={num_nodes}, add_isolated_nodes={add_isolated_nodes})"
            )

        # Validate tree size
        required_nodes = 2 * self.k + reserved
        if required_nodes > num_nodes:
            raise ValueError(
                f"Cannot fit 2 trees of size {self.k} with {'2 isolated nodes' if add_isolated_nodes else 'no isolated nodes'} "
                f"in {num_nodes} nodes. Required: {required_nodes}, Available: {num_nodes}"
            )

        super().__init__(self.max_num_nodes, **kwargs)
        self.active_nodes = num_nodes

        # Note: matrices are not precomputed; trees are re-drawn in generate_graph.

    def _generate_random_tree(self, nodes):
        """Generate a random tree on the given nodes by random attachment.

        Starting from ``nodes[0]``, a randomly chosen not-yet-attached node is
        repeatedly connected to a randomly chosen node already in the tree.

        Args:
            nodes: Sequence of node ids to span.

        Returns:
            List of ``(new_node, parent_node)`` edges; empty when ``nodes`` has
            fewer than two entries.
        """
        if len(nodes) <= 1:
            return []

        # Start with the first node
        tree_edges = []
        in_tree = {nodes[0]}
        remaining = set(nodes[1:])

        # Add nodes one by one, connecting each to a random node already in the tree
        while remaining:
            new_node = random.choice(list(remaining))
            parent_node = random.choice(list(in_tree))
            tree_edges.append((new_node, parent_node))
            in_tree.add(new_node)
            remaining.remove(new_node)

        return tree_edges

    def _pad_and_permute(
        self, adj: torch.Tensor, conn: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pad both matrices to max_num_nodes and apply one shared node permutation."""
        total, active = self.max_num_nodes, self.active_nodes

        adj_padded = torch.zeros((total, total))
        conn_padded = torch.zeros((total, total))
        adj_padded[:active, :active] = adj
        conn_padded[:active, :active] = conn

        # Padding nodes are isolated: always self-connected, with an adjacency
        # self-loop only when self-loops are enabled. diagonal() is a view.
        conn_padded.diagonal()[active:] = 1.0
        if self.add_self_loops:
            adj_padded.diagonal()[active:] = 1.0

        # Random permutation shared by rows and columns of both matrices
        perm = torch.randperm(total)
        return adj_padded[perm][:, perm], conn_padded[perm][:, perm]

    def generate_graph(self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate two new random trees and return padded, permuted matrices.

        Args:
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(adjacency, connectivity)``, each of shape
            ``(max_num_nodes, max_num_nodes)``.
        """
        n, k = self.active_nodes, self.k
        adj = torch.zeros((n, n))

        # One independent random tree per node range: [0, k) and [k, 2k).
        # Ranges with fewer than two nodes yield no edges.
        for start in (0, k):
            tree_nodes = list(range(start, start + k))
            for u, v in self._generate_random_tree(tree_nodes):
                adj[u, v] = adj[v, u] = 1

        if self.add_self_loops:
            adj += torch.eye(n)

        # Connectivity matrix: one all-ones block per tree
        conn = torch.zeros((n, n))
        conn[0:k, 0:k] = 1  # First tree
        conn[k : 2 * k, k : 2 * k] = 1  # Second tree

        # Nodes outside both trees are isolated regardless of
        # add_isolated_nodes: each is connected only to itself. Together with
        # the blocks above this already puts 1 on the whole diagonal, so no
        # extra self-loop pass is needed for the connectivity matrix.
        conn.diagonal()[2 * k :] = 1

        return self._pad_and_permute(adj, conn)


class TwoStarsGenerator(BaseGraphDataset):
    """
    Generator for graphs with two disjoint k-node stars plus optional isolated nodes.

    Node 0 is the centre of the first star (leaves ``1..k-1``) and node ``k`` the
    centre of the second (leaves ``k+1..2k-1``); every other active node is
    isolated. The structure is fixed at construction time; each
    ``generate_graph`` call only pads and randomly relabels it.

    Args:
        num_nodes: Number of active (non-padding) nodes in the graph
        k: Size of each star. If None, calculated automatically based on add_isolated_nodes:
           - If add_isolated_nodes=False: k = num_nodes // 2 (uses all nodes for stars;
             one node is left isolated when num_nodes is odd)
           - If add_isolated_nodes=True: k = (num_nodes - 2) // 2 (reserves 2 nodes as isolated)
        max_num_nodes: Maximum number of nodes for padding (default: same as num_nodes)
        add_isolated_nodes: If True, at least 2 nodes are reserved as isolated. If False,
                           no nodes are reserved (default: False)
        **kwargs: Forwarded to ``BaseGraphDataset.__init__`` (e.g. ``add_self_loops``)

    Raises:
        ValueError: If ``num_nodes`` exceeds ``max_num_nodes``, if the star size is
            negative, or if the stars plus reserved nodes do not fit in ``num_nodes``.

    Examples:
        # 56 nodes, add_isolated_nodes=False: 2 stars of size 28 each
        # 56 nodes, add_isolated_nodes=True: 2 stars of size 27 each + 2 isolated nodes
    """

    def __init__(
        self,
        num_nodes: int,
        k: Optional[int] = None,
        max_num_nodes: Optional[int] = None,
        add_isolated_nodes: bool = False,
        **kwargs,
    ):
        self.max_num_nodes = max_num_nodes if max_num_nodes is not None else num_nodes
        if num_nodes > self.max_num_nodes:
            raise ValueError("num_nodes cannot be greater than max_num_nodes")

        self.add_isolated_nodes = add_isolated_nodes
        reserved = 2 if add_isolated_nodes else 0

        # Default star size: split whatever is left after the reserved nodes.
        self.k = (num_nodes - reserved) // 2 if k is None else k

        # A negative size would slip through the capacity check below and then
        # act as a from-the-end slice bound when the connectivity blocks are
        # written, producing wrong labels (or an IndexError for tiny graphs).
        if self.k < 0:
            raise ValueError(
                f"Star size k must be non-negative, got {self.k} "
                f"(num_nodes={num_nodes}, add_isolated_nodes={add_isolated_nodes})"
            )

        # Validate star size
        required_nodes = 2 * self.k + reserved
        if required_nodes > num_nodes:
            raise ValueError(
                f"Cannot fit 2 stars of size {self.k} with {'2 isolated nodes' if add_isolated_nodes else 'no isolated nodes'} "
                f"in {num_nodes} nodes. Required: {required_nodes}, Available: {num_nodes}"
            )

        super().__init__(self.max_num_nodes, **kwargs)
        self.active_nodes = num_nodes

        # Precompute base matrices
        self._precompute_matrices()

    def _precompute_matrices(self):
        """Build ``self.base_adj`` and ``self.base_conn`` for the active nodes."""
        n, k = self.active_nodes, self.k
        adj = torch.zeros((n, n))

        if k > 1:
            # First star: node 0 is the centre, nodes 1..k-1 are its leaves.
            adj[0, 1:k] = 1
            adj[1:k, 0] = 1
            # Second star: node k is the centre, nodes k+1..2k-1 are its leaves.
            adj[k, k + 1 : 2 * k] = 1
            adj[k + 1 : 2 * k, k] = 1

        if self.add_self_loops:
            adj += torch.eye(n)

        # Connectivity matrix: one all-ones block per star
        conn = torch.zeros((n, n))
        conn[0:k, 0:k] = 1  # First star
        conn[k : 2 * k, k : 2 * k] = 1  # Second star

        # Nodes outside both stars are isolated regardless of
        # add_isolated_nodes: each is connected only to itself. Together with
        # the blocks above this already puts 1 on the whole diagonal, so no
        # extra self-loop pass is needed for the connectivity matrix.
        conn.diagonal()[2 * k :] = 1

        self.base_adj = adj
        self.base_conn = conn

    def generate_graph(self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a padded, randomly relabelled copy of the precomputed matrices.

        Args:
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(adjacency, connectivity)``, each of shape
            ``(max_num_nodes, max_num_nodes)``.
        """
        total, active = self.max_num_nodes, self.active_nodes

        # Copy base matrices to the top-left corner of the padded canvas
        adj_padded = torch.zeros((total, total))
        conn_padded = torch.zeros((total, total))
        adj_padded[:active, :active] = self.base_adj
        conn_padded[:active, :active] = self.base_conn

        # Padding nodes are isolated: always self-connected, with an adjacency
        # self-loop only when self-loops are enabled. diagonal() is a view.
        conn_padded.diagonal()[active:] = 1.0
        if self.add_self_loops:
            adj_padded.diagonal()[active:] = 1.0

        # Random permutation shared by rows and columns of both matrices
        perm = torch.randperm(total)
        adj = adj_padded[perm][:, perm]
        conn = conn_padded[perm][:, perm]

        return adj, conn


class UnifiedGraphDataset(Dataset):
    """
    Unified dataset wrapper that can generate different types of graphs based on configuration.

    Every sample is a tuple ``(adjacency_matrix, connectivity_matrix)``.

    Args:
        dataset_type (str): Type of dataset to generate. Options:
            - 'erdos_renyi': Standard Erdős-Rényi graphs
            - 'erdos_renyi_two_graphs': Two separate ER subgraphs
            - 'erdos_renyi_medium': ER graphs with at least two large components
            - 'erdos_renyi_hard': Two ER components of size n/2 with optional single connection
            - 'tree_forest': Collection of random trees
            - 'star_forest': Forest of star graphs of varying sizes
            - 'two_cliques': Two separate cliques of similar sizes
            - 'caveman': Caveman graphs (collection of cliques with optional connections)
            - 'sbm': Stochastic Block Model graphs
            - 'two_chains': Two disjoint k-chains with isolated nodes
            - 'two_variable_chains': Two disjoint chains with variable lengths (k and n-k)
            - 'two_trees': Two disjoint k-node trees with isolated nodes
            - 'two_stars': Two disjoint k-node stars with isolated nodes
            - 'one_circle': All nodes arranged in a single circle
            - 'two_degree_3_chains': Extension of two_chains with degree-3 connections
        num_samples (int): Number of graphs to generate
        num_nodes (int): Number of nodes per graph
        on_the_fly (bool): If True, generate graphs on-the-fly during __getitem__.
                          If False, pre-generate all graphs and store in memory.
                          Default: False (pre-generate for backward compatibility)
        verbose (bool): If True, print the configuration and show a progress bar.
        **kwargs: Additional parameters specific to each dataset type. They are
                  passed both to the generator constructor and to every
                  ``generate_graph`` call.

    Raises:
        ValueError: If ``dataset_type`` is not a key of ``GENERATOR_MAP``.
    """

    GENERATOR_MAP = {
        "erdos_renyi": ErdosRenyiGenerator,
        "erdos_renyi_two_graphs": ErdosRenyiTwoGraphsGenerator,
        "erdos_renyi_medium": ErdosRenyiMediumVariantGenerator,
        "erdos_renyi_hard": ErdosRenyiHardVariantGenerator,
        "tree_forest": TreeForestGenerator,
        "star_forest": StarForestGenerator,
        "two_cliques": TwoCliquesGenerator,
        "caveman": CavemanGraphGenerator,
        "sbm": SBMGenerator,
        "two_chains": TwoChainsGenerator,
        "two_variable_chains": TwoVariableChainsGenerator,
        "two_trees": TwoTreesGenerator,
        "two_stars": TwoStarsGenerator,
        "one_circle": OneCircleGenerator,
        "two_degree_3_chains": TwoDegree3ChainsGenerator,
    }

    # Dataset types whose generator returns (adjacency, connectivity) tensors
    # directly; all other generators return a NetworkX graph that still has to
    # go through compute_matrices. Kept in one place so the pre-generation and
    # on-the-fly paths cannot drift apart.
    MATRIX_GENERATOR_TYPES = frozenset(
        {
            "two_chains",
            "two_variable_chains",
            "two_trees",
            "two_stars",
            "one_circle",
            "two_degree_3_chains",
        }
    )

    def __init__(
        self,
        dataset_type: str,
        num_samples: int,
        num_nodes: int,
        on_the_fly: bool = False,
        verbose: bool = True,
        **kwargs,
    ):

        if dataset_type not in self.GENERATOR_MAP:
            raise ValueError(
                f"Unknown dataset_type: {dataset_type}. "
                f"Available types: {list(self.GENERATOR_MAP.keys())}"
            )

        self.dataset_type = dataset_type
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.on_the_fly = on_the_fly
        self.verbose = verbose
        self.generator_kwargs = kwargs

        # Initialize the appropriate generator
        self.generator = self.GENERATOR_MAP[dataset_type](num_nodes=num_nodes, **kwargs)

        # Generate all graphs if not on_the_fly
        if not self.on_the_fly:
            self.graphs = []
            self._generate_all_graphs(**kwargs)
        else:
            self.graphs = None
            if self.verbose:
                print(
                    f"\nInitialized {self.dataset_type} dataset for on-the-fly generation:"
                )
                print(f">>> Number of samples: {self.num_samples}")
                print(f">>> Number of nodes: {self.num_nodes}")
                print(">>> Mode: On-the-fly generation")
                for key, value in kwargs.items():
                    print(f">>> {key}: {value}")

    def _generate_sample(self, **kwargs):
        """Generate one ``(adjacency, connectivity)`` pair with the configured generator."""
        if self.dataset_type in self.MATRIX_GENERATOR_TYPES:
            # These generators return matrices directly
            adj_matrix, connectivity_matrix = self.generator.generate_graph(**kwargs)
        else:
            # Other generators return NetworkX graphs
            graph = self.generator.generate_graph(**kwargs)
            adj_matrix, connectivity_matrix = self.generator.compute_matrices(graph)
        return adj_matrix, connectivity_matrix

    def _generate_all_graphs(self, **kwargs):
        """Generate all graphs, store them in ``self.graphs`` and report statistics.

        The reported connectivity ratio is the number of connected ordered
        pairs of distinct nodes divided by ``num_nodes * (num_nodes - 1)``,
        accumulated over all samples.
        """
        if self.verbose:
            print(f"\nGenerating {self.dataset_type} graphs. Parameters:")
            print(f">>> Number of samples: {self.num_samples}")
            print(f">>> Number of nodes: {self.num_nodes}")
            for key, value in kwargs.items():
                print(f">>> {key}: {value}")

        connectivity_sum = 0.0
        total_possible_connections = 0

        desc = f"Generating {self.dataset_type} graphs"
        for _ in tqdm(range(self.num_samples), desc=desc, disable=not self.verbose):
            adj_matrix, connectivity_matrix = self._generate_sample(**kwargs)
            self.graphs.append((adj_matrix, connectivity_matrix))

            # Track connectivity statistics. The diagonal is measured on the
            # matrix itself rather than assumed to hold num_nodes ones:
            # generators that pad to max_num_nodes return larger matrices, and
            # the self-connections of padding nodes must not be counted as
            # connections between distinct nodes.
            off_diagonal = (
                connectivity_matrix.sum() - connectivity_matrix.diagonal().sum()
            )
            connectivity_sum += off_diagonal.item()
            total_possible_connections += (
                self.num_nodes * self.num_nodes - self.num_nodes
            )

        if self.verbose and total_possible_connections > 0:
            connectivity_ratio = connectivity_sum / total_possible_connections
            print(f">>> Connectivity ratio: {connectivity_ratio:.3f}")

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Get a sample from the dataset.

        Args:
            idx: Sample index in ``[-num_samples, num_samples)``. In on-the-fly
                mode the index only serves as a bounds check; a new graph is
                generated on every call.

        Returns:
            Tuple ``(adjacency_matrix, connectivity_matrix)``.

        Raises:
            IndexError: If ``idx`` is outside the valid range.
        """
        # Check both bounds so that both modes reject the same indices; the
        # on-the-fly path would otherwise accept any negative index.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"Index {idx} out of range for dataset of size {self.num_samples}"
            )

        if self.on_the_fly:
            # Generate graph on-the-fly
            return self._generate_sample(**self.generator_kwargs)

        # Return pre-generated graph
        return self.graphs[idx]

    def get_info(self) -> Dict[str, Any]:
        """Get information about the dataset.

        Returns:
            Dictionary with the dataset configuration plus whichever of the
            optional generator attributes (``p``, ``sample_p``, ``p_range``,
            ``k``, ``add_isolated_nodes``, ``num_communities``) the generator
            defines.
        """
        info = {
            "dataset_type": self.dataset_type,
            "num_samples": self.num_samples,
            "num_nodes": self.num_nodes,
            "on_the_fly": self.on_the_fly,
            "generator_class": self.generator.__class__.__name__,
        }

        # Add generator-specific parameters
        optional_attrs = (
            "p",
            "sample_p",
            "p_range",
            "k",
            "add_isolated_nodes",
            "num_communities",
        )
        for attr in optional_attrs:
            if hasattr(self.generator, attr):
                info[attr] = getattr(self.generator, attr)

        return info


def create_dataset(dataset_type: str, **kwargs) -> UnifiedGraphDataset:
    """
    Convenience function to create a dataset after checking the required parameters.

    Args:
        dataset_type: Dataset type name understood by ``UnifiedGraphDataset``.
        **kwargs: Must contain ``num_samples`` and ``num_nodes``; everything is
            forwarded to ``UnifiedGraphDataset``.

    Returns:
        The constructed ``UnifiedGraphDataset``.

    Raises:
        ValueError: If ``num_samples`` or ``num_nodes`` is missing (the first
            missing one, in that order, is named in the message).

    Example usage:
        # Erdős-Rényi dataset (pre-generated)
        dataset = create_dataset('erdos_renyi', num_samples=1000, num_nodes=50, p=0.1)

        # Erdős-Rényi dataset (on-the-fly generation for memory efficiency)
        dataset = create_dataset('erdos_renyi', num_samples=1000, num_nodes=50, p=0.1, on_the_fly=True)

        # Erdős-Rényi hard variant (two components with optional connection)
        dataset = create_dataset('erdos_renyi_hard', num_samples=500, num_nodes=32, p=0.15, connect_prob=0.5)

        # Tree forest dataset
        dataset = create_dataset('tree_forest', num_samples=500, num_nodes=32, min_tree_size=3, max_tree_size=10)

        # Star forest dataset
        dataset = create_dataset('star_forest', num_samples=500, num_nodes=32, min_star_size=3, max_star_size=8)

        # Two cliques dataset
        dataset = create_dataset('two_cliques', num_samples=500, num_nodes=32, connect_prob=0.2, size_variation=0.1)

        # Caveman dataset
        dataset = create_dataset('caveman', num_samples=500, num_nodes=32, k=5, connect_prob=0.5)

        # Two chains dataset with an explicit chain size k
        dataset = create_dataset('two_chains', num_samples=500, num_nodes=32, k=8)

        # Two chains dataset with isolated nodes
        dataset = create_dataset('two_chains', num_samples=500, num_nodes=56, add_isolated_nodes=True)
        # This creates 2 chains of length 27 each + 2 isolated nodes

        # Two variable chains dataset (chain lengths are randomly sampled)
        dataset = create_dataset('two_variable_chains', num_samples=500, num_nodes=32)
        # This creates 2 chains where first chain has k nodes (sampled from [1, 31]) and second has (32-k) nodes

        # Two trees dataset with an explicit tree size k
        dataset = create_dataset('two_trees', num_samples=500, num_nodes=32, k=8)
        # This creates 2 trees of size 8 each; the other 16 nodes are isolated

        # Two trees dataset with isolated nodes
        dataset = create_dataset('two_trees', num_samples=500, num_nodes=56, add_isolated_nodes=True)
        # This creates 2 trees of size 27 each + 2 isolated nodes

        # Two stars dataset with an explicit star size k
        dataset = create_dataset('two_stars', num_samples=500, num_nodes=32, k=8)
        # This creates 2 stars of size 8 each; the other 16 nodes are isolated

        # Two stars dataset with isolated nodes
        dataset = create_dataset('two_stars', num_samples=500, num_nodes=56, add_isolated_nodes=True)
        # This creates 2 stars of size 27 each + 2 isolated nodes

        # SBM dataset
        dataset = create_dataset('sbm', num_samples=800, num_nodes=64,
                                p_intra=0.3, p_inter=0.05, num_communities=4)

        # One circle dataset (all nodes in a circle)
        dataset = create_dataset('one_circle', num_samples=500, num_nodes=32)

        # Two degree-3 chains dataset (extension of two_chains with higher connectivity)
        dataset = create_dataset('two_degree_3_chains', num_samples=500, num_nodes=32, k=8)

        # Two degree-3 chains dataset with isolated nodes
        dataset = create_dataset('two_degree_3_chains', num_samples=500, num_nodes=56, add_isolated_nodes=True)
    """
    # Report the first missing required keyword, checked in this fixed order.
    missing = [name for name in ("num_samples", "num_nodes") if name not in kwargs]
    if missing:
        raise ValueError(f"Required parameter '{missing[0]}' not provided")

    return UnifiedGraphDataset(dataset_type=dataset_type, **kwargs)


# Legacy class aliases for backward compatibility.
# Each alias is a keyword-only factory that forwards to create_dataset with a
# fixed dataset type.
def ErdosRenyiGraphDataset(**kwargs):
    """Create an 'erdos_renyi' dataset from keyword arguments."""
    return create_dataset("erdos_renyi", **kwargs)


def ErdosRenyiTwoGraphsDataset(**kwargs):
    """Create an 'erdos_renyi_two_graphs' dataset from keyword arguments."""
    return create_dataset("erdos_renyi_two_graphs", **kwargs)


def ErdosRenyiVariant2(**kwargs):
    """Create an 'erdos_renyi_medium' dataset (the old "hard" is now "medium")."""
    return create_dataset("erdos_renyi_medium", **kwargs)


def SBMGraphDataset(**kwargs):
    """Create an 'sbm' dataset from keyword arguments."""
    return create_dataset("sbm", **kwargs)


def TwoChainsDataset(**kwargs):
    """Create a 'two_chains' dataset from keyword arguments."""
    return create_dataset("two_chains", **kwargs)


# New dataset aliases
def TreeForestDataset(**kwargs):
    """Create a 'tree_forest' dataset from keyword arguments."""
    return create_dataset("tree_forest", **kwargs)


def StarForestDataset(**kwargs):
    """Create a 'star_forest' dataset from keyword arguments."""
    return create_dataset("star_forest", **kwargs)


def TwoCliquesDataset(**kwargs):
    """Create a 'two_cliques' dataset from keyword arguments."""
    return create_dataset("two_cliques", **kwargs)


def CavemanDataset(**kwargs):
    """Create a 'caveman' dataset from keyword arguments."""
    return create_dataset("caveman", **kwargs)


def TwoTreesDataset(**kwargs):
    """Create a 'two_trees' dataset from keyword arguments."""
    return create_dataset("two_trees", **kwargs)


def TwoStarsDataset(**kwargs):
    """Create a 'two_stars' dataset from keyword arguments."""
    return create_dataset("two_stars", **kwargs)


def TwoVariableChainsDataset(**kwargs):
    """Create a 'two_variable_chains' dataset from keyword arguments."""
    return create_dataset("two_variable_chains", **kwargs)


def ErdosRenyiHardDataset(**kwargs):
    """Create an 'erdos_renyi_hard' dataset from keyword arguments."""
    return create_dataset("erdos_renyi_hard", **kwargs)


def ErdosRenyiMediumDataset(**kwargs):
    """Create an 'erdos_renyi_medium' dataset from keyword arguments."""
    return create_dataset("erdos_renyi_medium", **kwargs)


# Legacy aliases for backward compatibility with old names
def RandomForestDataset(**kwargs):
    """Create a 'tree_forest' dataset (old name kept for backward compatibility)."""
    return create_dataset("tree_forest", **kwargs)


def StarGraphDataset(**kwargs):
    """Create a 'star_forest' dataset (old name kept for backward compatibility)."""
    return create_dataset("star_forest", **kwargs)


def create_mixed_dataset(
    dataset_configs: List[Tuple[str, Dict[str, Any]]], shuffle: bool = True
) -> Dataset:
    """
    Create a mixed dataset by concatenating multiple dataset types.

    Args:
        dataset_configs: Non-empty list of (dataset_type, config) tuples; each
            config dict is forwarded to ``create_dataset`` as keyword arguments.
        shuffle: Whether to shuffle the combined dataset. The permutation is
            drawn once, when this function is called, via ``torch.randperm``.

    Returns:
        With ``shuffle=False``: the ``ConcatDataset`` containing all specified
        datasets. With ``shuffle=True``: a ``torch.utils.data.Subset`` over that
        ConcatDataset whose ``indices`` are a random permutation (the
        ConcatDataset itself stays reachable as ``.dataset``).

    Raises:
        ValueError: If ``dataset_configs`` yields no entries.

    Example:
        mixed_dataset = create_mixed_dataset([
            ("two_chains", {"num_samples": 1000, "num_nodes": 32, "k": 10}),
            ("erdos_renyi", {"num_samples": 1000, "num_nodes": 32, "p": 0.15}),
            ("sbm", {"num_samples": 500, "num_nodes": 32, "p_intra": 0.4, "p_inter": 0.05})
        ])
    """
    # Imported here because the module header only pulls in Dataset/ConcatDataset.
    from torch.utils.data import Subset

    datasets = []
    for dataset_type, config in dataset_configs:
        dataset = create_dataset(dataset_type, **config)
        datasets.append(dataset)
        print(f"Created {dataset_type} dataset: {len(dataset)} samples")

    # Fail with a clear error instead of relying on ConcatDataset's internal
    # assertion (which disappears under ``python -O``).
    if not datasets:
        raise ValueError("dataset_configs must contain at least one dataset")

    mixed_dataset = ConcatDataset(datasets)
    print(f"Total mixed dataset size: {len(mixed_dataset)} samples")

    if not shuffle:
        return mixed_dataset

    # Subset maps idx -> dataset[indices[idx]], which is what the previous
    # function-local wrapper class did; unlike a local class it is defined at
    # module level in torch and can therefore be pickled.
    permutation = torch.randperm(len(mixed_dataset)).tolist()
    return Subset(mixed_dataset, permutation)


if __name__ == "__main__":
    # Demo script: builds one small dataset per supported type, prints the
    # first sample's shape and connectivity sum, then shows get_info(), a
    # legacy alias and a mixed dataset.
    print("=== Unified Graph Dataset Examples ===\n")

    # One row per example, in execution order:
    #   (heading, use 4-space detail indent?, dataset type, config, extra note)
    demo_table = [
        # Example 1: Erdős-Rényi dataset (equivalent to original ErdosRenyiGraphDataset)
        (
            "1. Creating Erdős-Rényi dataset (pre-generated)...",
            False,
            "erdos_renyi",
            dict(num_samples=100, num_nodes=32, p=0.1, sample_p=False),
            None,
        ),
        # Example 1b: same dataset but with on-the-fly generation
        (
            "\n1b. Creating Erdős-Rényi dataset (on-the-fly)...",
            True,
            "erdos_renyi",
            dict(
                num_samples=100,
                num_nodes=32,
                p=0.1,
                sample_p=False,
                on_the_fly=True,
            ),
            f"    Memory usage: Much lower for large datasets!",
        ),
        # Example 2: Two chains dataset
        (
            "\n2. Creating Two Chains dataset...",
            False,
            "two_chains",
            dict(num_samples=50, num_nodes=16, k=5),
            None,
        ),
        # Example 2b: Two variable chains dataset
        (
            "\n2b. Creating Two Variable Chains dataset...",
            True,
            "two_variable_chains",
            dict(num_samples=50, num_nodes=16),
            f"    Note: Chain lengths are randomly sampled from [1, n-1] and [1, n-1] such that they sum to n",
        ),
        # Example 3: SBM dataset
        (
            "\n3. Creating SBM dataset...",
            False,
            "sbm",
            dict(
                num_samples=75,
                num_nodes=24,
                p_intra=0.3,
                p_inter=0.05,
                num_communities=3,
            ),
            None,
        ),
        # Example 4: Hard variant dataset (new implementation: two ER components)
        (
            "\n4. Creating Hard Variant dataset...",
            False,
            "erdos_renyi_hard",
            dict(
                num_samples=60,
                num_nodes=40,
                p=0.1,
                sample_p=True,
                p_range=(0.08, 0.15),
            ),
            None,
        ),
        # Example 5: Two graphs dataset (equivalent to original ErdosRenyiTwoGraphsDataset)
        (
            "\n5. Creating Two Graphs dataset...",
            False,
            "erdos_renyi_two_graphs",
            dict(num_samples=40, num_nodes=32, p=0.15, connect_prob=0.5),
            None,
        ),
        # Example 6: Tree Forest dataset
        (
            "\n6. Creating Tree Forest dataset...",
            False,
            "tree_forest",
            dict(num_samples=50, num_nodes=32, min_tree_size=3, max_tree_size=8),
            None,
        ),
        # Example 7: Star Forest dataset
        (
            "\n7. Creating Star Forest dataset...",
            False,
            "star_forest",
            dict(num_samples=50, num_nodes=32, min_star_size=3, max_star_size=8),
            None,
        ),
        # Example 8: Two Cliques dataset
        (
            "\n8. Creating Two Cliques dataset...",
            False,
            "two_cliques",
            dict(
                num_samples=50,
                num_nodes=32,
                connect_prob=0.3,
                size_variation=0.1,
            ),
            None,
        ),
        # Example 9: Caveman dataset
        (
            "\n9. Creating Caveman dataset...",
            False,
            "caveman",
            dict(num_samples=50, num_nodes=32, k=5, connect_prob=0.5),
            None,
        ),
        # Example 10: Two Trees dataset
        (
            "\n10. Creating Two Trees dataset...",
            False,
            "two_trees",
            dict(num_samples=50, num_nodes=32, k=8),
            None,
        ),
        # Example 11: Two Stars dataset
        (
            "\n11. Creating Two Stars dataset...",
            False,
            "two_stars",
            dict(num_samples=50, num_nodes=32, k=8),
            None,
        ),
    ]

    # Datasets are kept (in creation order) for the info summary further down.
    built_datasets = []
    for heading, wide_indent, kind, config, extra_note in demo_table:
        print(heading)
        demo = create_dataset(kind, **config)
        adj, conn = demo[0]
        # The "b" examples indent their detail lines by one extra space.
        if wide_indent:
            print(f"    Shape: {adj.shape}, Connectivity: {conn.sum().item()}")
        else:
            print(f"   Shape: {adj.shape}, Connectivity: {conn.sum().item()}")
        if extra_note is not None:
            print(extra_note)
        built_datasets.append(demo)

    # Summary of every dataset built above, numbered from 1
    print(f"\n=== Dataset Info ===")
    for position, demo in enumerate(built_datasets, start=1):
        info = demo.get_info()
        mode = "Pre-generated"
        if info.get("on_the_fly", False):
            mode = "On-the-fly"
        print(f"{position}. {info['dataset_type']} ({mode}): {info}")

    # Legacy class names are still callable (backward compatibility)
    print(f"\n=== Legacy Compatibility ===")
    legacy_dataset = ErdosRenyiGraphDataset(num_samples=10, num_nodes=16, p=0.1)
    print(f"Legacy ErdosRenyiGraphDataset works: {len(legacy_dataset)} samples")

    # Merged dataset mixing an on-the-fly part with a pre-generated part
    print(f"\n=== Merged Datasets ===")
    on_the_fly_part = (
        "two_chains",
        {"num_samples": 500, "num_nodes": 16, "k": 6, "on_the_fly": True},
    )
    pregenerated_part = (
        "erdos_renyi",
        {"num_samples": 500, "num_nodes": 16, "p": 0.15, "on_the_fly": False},
    )
    mixed_dataset = create_mixed_dataset([on_the_fly_part, pregenerated_part])
    print(f"Mixed dataset created: {len(mixed_dataset)} total samples")
    print(
        f"Note: Mixed datasets can combine both pre-generated and on-the-fly datasets!"
    )
