import os
import json
import math
from pathlib import Path
import pickle
import shutil
import random
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from collections import defaultdict, Counter
from torch_geometric.utils import degree
from torch_geometric.transforms import OneHotDegree
from torch_geometric.utils import to_undirected
import pynauty  
from torch_geometric.datasets import TUDataset, Planetoid  

# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (plus all CUDA devices when CUDA is available), and exports the value as
    the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Seed value applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Seed once at import time so everything below starts from a known state.
set_seed(42)

def select_optimal_factors(n_nodes, min_size=5):
    """Pick two factor sizes whose product approaches ``n_nodes`` from below.

    Every ordered pair ``(n1, n2)`` with both values in
    ``[min_size, min(50, n_nodes)]`` and ``n1 * n2 <= n_nodes`` is a candidate.
    Candidates are ranked by how far their product falls short of ``n_nodes``
    and one of the best tenth (at least one candidate) is drawn at random.

    Args:
        n_nodes: Upper bound for the product of the two factors.
        min_size: Smallest value allowed for either factor.

    Returns:
        A tuple ``(n1, n2)``; ``(min_size, min_size)`` when no pair fits.
    """
    upper = min(50, n_nodes)
    sizes = range(min_size, upper + 1)
    candidates = [
        (a, b, n_nodes - a * b)
        for a in sizes
        for b in sizes
        if a * b <= n_nodes
    ]
    if not candidates:
        return min_size, min_size

    # Stable sort: ties keep their enumeration order.
    ranked = sorted(candidates, key=lambda item: item[2])
    shortlist = ranked[:max(1, len(ranked) // 10)]

    chosen = random.choice(shortlist)
    return chosen[0], chosen[1]

class SymmetryDataset(InMemoryDataset):
    """Symmetry classification dataset generator for graph data."""
    
    def __init__(self, root, mode='train', transform=None, pre_transform=None, pre_filter=None,
                 use_degree_features=True, min_nodes=None, max_nodes=None, samples_per_class=None):
        """Configure the generator, load the real base graphs and load the processed data.

        Args:
            root: Dataset root directory, forwarded to ``InMemoryDataset``.
            mode: Split name; used in the raw/processed file names and to pick
                the default ``samples_per_class`` ('train' -> 1000,
                'val' -> 100, anything else -> 50).
            transform: Forwarded to ``InMemoryDataset``.
            pre_transform: Forwarded to ``InMemoryDataset``.
            pre_filter: Forwarded to ``InMemoryDataset``.
            use_degree_features: Stored on the instance; read by ``process``.
            min_nodes: Lower bound for the requested graph size (default 30).
            max_nodes: Upper bound for the requested graph size (default 60).
            samples_per_class: Graphs to generate per class (default depends
                on ``mode``).

        Each of ``min_nodes``, ``max_nodes`` and ``samples_per_class`` falls
        back to its default independently, so overriding only one of them is
        honoured instead of being silently discarded.
        """
        self.mode = mode
        self.use_degree_features = use_degree_features
        self.real_graph_dir = Path(root) / "real_base_graphs"
        self.real_graph_dir.mkdir(parents=True, exist_ok=True)
        # Assigned before super().__init__ because _load_real_data_graphs reads self.root.
        self.root = root
        self.real_graph_usage = {
            'symmetric': {},
            'asymmetric': {}
        }

        default_samples = {'train': 1000, 'val': 100}.get(mode, 50)  # 50 covers 'test'
        self.min_nodes = 30 if min_nodes is None else min_nodes
        self.max_nodes = 60 if max_nodes is None else max_nodes
        self.samples_per_class = default_samples if samples_per_class is None else samples_per_class

        self._num_classes = 2
        self.class_names = ["Asymmetric", "Symmetric"]

        # Generators used for the symmetric class
        self.sym_graph_types = [
            'cayley_cyclic',
            'bipartite_cover',
            'cartesian_product',
            'cartesian_product_with_real_graph',
            'real_data_cover'
        ]

        # Generators used for the asymmetric class
        self.asym_graph_types = [
            'perturbed_asymmetric',
            'cartesian_product_with_real_graph'
        ]

        self.real_data_graphs = self._load_real_data_graphs()

        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])
    
    @property
    def raw_file_names(self):
        return [f'dummy_{self.mode}.txt']
    
    @property
    def processed_file_names(self):
        return [f'symmetry_{self.mode}.pt']
    
    def download(self):
        """Create the raw directory if it is missing and write the placeholder raw file.

        Nothing is fetched; the file written here is the one named by
        ``raw_file_names``.
        """
        raw_dir = self.raw_dir
        if not os.path.exists(raw_dir):
            os.makedirs(raw_dir)

        placeholder_path = os.path.join(raw_dir, f'dummy_{self.mode}.txt')
        with open(placeholder_path, 'w') as handle:
            handle.write('dummy')
    
    def _load_real_data_graphs(self):
        """Build, or load from cache, the library of small real-world base graphs.

        The cache file ``cached_graphs.pkl`` lives in ``real_base_graphs``
        next to (not inside) ``self.root``. When it cannot be used, graphs with
        5-50 nodes are collected from:

        1. MUTAG molecules (used as-is when their size is in range);
        2. connected Cora subgraphs sampled by BFS, for sizes that still have
           fewer than 10 graphs;
        3. connected Cora subgraphs made of the top personalized-PageRank
           nodes around a random start node (up to 10 per size);
        4. connected Cora subgraphs grown by GraphSAGE-like neighbour
           sampling, at most 5 neighbours per node (up to 10 per size).

        Every accepted graph is also exported through ``_save_real_graph``.
        Failures while loading either dataset are reported with ``print`` and
        do not abort the method.

        Returns:
            list of ``networkx.Graph`` objects (possibly empty).
        """
        from collections import deque  # BFS queue with O(1) pops from the left

        cache_dir = Path(self.root).parent / "real_base_graphs"
        cache_dir.mkdir(parents=True, exist_ok=True)
        cache_file = cache_dir / "cached_graphs.pkl"

        # Check if cache exists
        if cache_file.exists():
            try:
                # SECURITY NOTE: pickle.load executes arbitrary code from the
                # file; only use cache files written by this program.
                with open(cache_file, 'rb') as f:
                    base_graphs = pickle.load(f)
                print(f"Loaded {len(base_graphs)} base graphs from cache")
                return base_graphs
            except Exception as e:
                print(f"Failed to read cache: {e}, will reload data")

        base_graphs = []
        size_counts = {size: 0 for size in range(5, 51)}

        try:
            mutag_dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
            print(f"Successfully loaded MUTAG dataset with {len(mutag_dataset)} graphs")

            for data in mutag_dataset:
                # tolist() yields plain Python ints as node labels.
                sources, targets = data.edge_index.tolist()
                G = nx.Graph()
                G.add_edges_from(zip(sources, targets))

                size = G.number_of_nodes()
                if 5 <= size <= 50:
                    self._save_real_graph(G, f"mutag_{len(base_graphs)}")
                    base_graphs.append(G)
                    size_counts[size] = size_counts.get(size, 0) + 1
        except Exception as e:
            print(f"Failed to load MUTAG dataset: {e}")

        try:
            cora_dataset = Planetoid(root='/tmp/Cora', name='Cora')
            cora_data = cora_dataset[0]
            print(f"Successfully loaded Cora dataset with {cora_data.num_nodes} nodes")

            sources, targets = cora_data.edge_index.tolist()
            cora_graph = nx.Graph()
            cora_graph.add_edges_from(zip(sources, targets))
            # Built once: the graph is not modified below, so every random
            # start-node draw can reuse the same list.
            cora_nodes = list(cora_graph.nodes())

            target_per_size = 10
            target_ppr_per_size = 10
            target_sage_per_size = 10

            sizes_needed = [size for size in range(5, 51) if size_counts.get(size, 0) < target_per_size]
            print(f"Graph sizes needed to supplement: {len(sizes_needed)}")

            # 1. BFS
            for target_size in sizes_needed:
                attempts = 0
                while size_counts.get(target_size, 0) < target_per_size and attempts < 50:
                    start_node = random.choice(cora_nodes)
                    visited = {start_node}
                    queue = deque([start_node])

                    while len(visited) < target_size and queue:
                        current = queue.popleft()
                        for neighbor in cora_graph.neighbors(current):
                            if neighbor not in visited:
                                visited.add(neighbor)
                                queue.append(neighbor)
                                if len(visited) >= target_size:
                                    break

                    if len(visited) == target_size:
                        self._add_connected_subgraph(
                            cora_graph, visited,
                            f"cora_bfs_size{target_size}_{size_counts.get(target_size, 0)}",
                            base_graphs, size_counts,
                        )

                    attempts += 1

            # 2. PPR (Personalized PageRank)
            print("Using PPR sampling to generate additional subgraphs...")
            ppr_failures = 0
            for target_size in range(5, 51):
                ppr_count = 0
                attempts = 0
                while ppr_count < target_ppr_per_size and attempts < 50:
                    start_node = random.choice(cora_nodes)

                    try:
                        personalization = dict.fromkeys(cora_nodes, 0.0)
                        personalization[start_node] = 1.0

                        ppr_scores = nx.pagerank(cora_graph, alpha=0.2, personalization=personalization, max_iter=100)

                        top_nodes = sorted(ppr_scores.items(), key=lambda x: x[1], reverse=True)[:target_size]
                        visited = {node for node, _ in top_nodes}

                        if len(visited) == target_size and self._add_connected_subgraph(
                            cora_graph, visited,
                            f"cora_ppr_size{target_size}_{ppr_count}",
                            base_graphs, size_counts,
                        ):
                            ppr_count += 1
                    except Exception:
                        # Best-effort sampling: a failed attempt (e.g. PageRank
                        # not converging) only costs this attempt. Unlike a bare
                        # ``except``, this lets KeyboardInterrupt through.
                        ppr_failures += 1

                    attempts += 1
            if ppr_failures:
                print(f"PPR sampling: {ppr_failures} attempts raised an error and were skipped")

            # 3. Using GraphSAGE-like sampling method
            print("Using GraphSAGE sampling method to generate additional subgraphs...")
            for target_size in range(5, 51):
                sage_count = 0
                attempts = 0
                while sage_count < target_sage_per_size and attempts < 50:
                    start_node = random.choice(cora_nodes)
                    visited = {start_node}
                    frontier = [start_node]

                    while len(visited) < target_size and frontier:
                        next_frontier = []

                        for node in frontier:
                            unvisited = [n for n in cora_graph.neighbors(node) if n not in visited]

                            if unvisited:
                                # At most 5 neighbours per node, never more than still needed.
                                sample_size = min(5, len(unvisited), target_size - len(visited))
                                if sample_size > 0:
                                    for s in random.sample(unvisited, sample_size):
                                        visited.add(s)
                                        next_frontier.append(s)

                                        if len(visited) >= target_size:
                                            break

                            if len(visited) >= target_size:
                                break

                        frontier = next_frontier

                    if len(visited) == target_size and self._add_connected_subgraph(
                        cora_graph, visited,
                        f"cora_sage_size{target_size}_{sage_count}",
                        base_graphs, size_counts,
                    ):
                        sage_count += 1

                    attempts += 1
        except Exception as e:
            print(f"Failed to load Cora dataset: {e}")

        # Print size distribution statistics
        print("\nReal graph library size distribution:")
        for range_start in range(5, 51, 5):
            range_end = range_start + 4
            count = sum(size_counts.get(size, 0) for size in range(range_start, range_end + 1))
            print(f"  {range_start}-{range_end} nodes: {count} graphs")

        try:
            with open(cache_file, 'wb') as f:
                pickle.dump(base_graphs, f)
            print(f"Saved {len(base_graphs)} base graphs to cache")
        except Exception as e:
            print(f"Failed to save cache: {e}")

        # Return loaded graphs
        print(f"Loaded a total of {len(base_graphs)} base graphs")
        return base_graphs

    def _add_connected_subgraph(self, graph, nodes, name, base_graphs, size_counts):
        """Register the subgraph of ``graph`` induced by ``nodes`` if it is connected.

        On success the subgraph is relabelled to consecutive integers, exported
        via ``_save_real_graph`` under ``name``, appended to ``base_graphs`` and
        counted in ``size_counts`` under its node count.

        Returns:
            True when the subgraph was connected and registered, else False.
        """
        subgraph = graph.subgraph(nodes).copy()
        if not nx.is_connected(subgraph):
            return False

        subgraph = nx.convert_node_labels_to_integers(subgraph)
        self._save_real_graph(subgraph, name)
        base_graphs.append(subgraph)
        size = subgraph.number_of_nodes()
        size_counts[size] = size_counts.get(size, 0) + 1
        return True
    
    def _save_real_graph(self, G, base_name):
        """Export ``G`` as GraphML and as a picture under ``self.real_graph_dir``.

        Best effort: any exception raised while writing or drawing is caught
        and reported with a printed message.

        Args:
            G: Graph to export.
            base_name: File name stem shared by the ``.graphml`` file and the image.
        """
        try:
            destination = self.real_graph_dir / f"{base_name}.graphml"
            nx.write_graphml(G, destination)
            # The picture is only attempted once the GraphML file was written.
            self._visualize_with_kk(G, base_name)
        except Exception as err:
            print(f"Failed to save graph {base_name}: {str(err)}")

    def _visualize_with_kk(self, G, base_name):
        """Draw ``G`` and save it as ``<base_name>.png`` in ``self.real_graph_dir``.

        Node positions come from the Kamada-Kawai layout; if that raises a
        ``networkx.NetworkXException`` a spring layout with a fixed seed is used.

        The figure is always closed, even when layout, drawing or saving
        fails, so repeated failures cannot accumulate open figures.

        Args:
            G: Graph to draw.
            base_name: File name stem of the image; also shown in the title.
        """
        fig = plt.figure(figsize=(8, 8))
        try:
            try:
                pos = nx.kamada_kawai_layout(G)
            except nx.NetworkXException:
                # Fallback layout with a fixed seed for reproducible pictures
                pos = nx.spring_layout(G, seed=42)

            nx.draw_networkx_nodes(G, pos, node_size=100, alpha=0.8)
            nx.draw_networkx_edges(G, pos, width=1.0, alpha=0.5)
            plt.title(f"{base_name} ({G.number_of_nodes()} nodes)")

            # Save image
            img_path = self.real_graph_dir / f"{base_name}.png"
            plt.savefig(img_path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)
    
    def verify_symmetry(self, G):
        """Report whether ``G`` has a non-trivial automorphism group.

        The graph is relabelled to ``0..n-1``, self-loops are ignored, and the
        automorphism group is computed with ``pynauty.autgrp``.

        Args:
            G: networkx graph to test.

        Returns:
            True when the group size reported by pynauty exceeds 1.
        """
        relabeled = nx.convert_node_labels_to_integers(G, first_label=0)
        order = relabeled.number_of_nodes()

        # Adjacency lists, filled in both directions for every non-loop edge.
        neighbours = {vertex: [] for vertex in range(order)}
        for a, b in relabeled.edges():
            if a == b:
                continue
            neighbours[a].append(b)
            neighbours[b].append(a)

        nauty_graph = pynauty.Graph(order)
        for vertex, adjacent in neighbours.items():
            if adjacent:
                nauty_graph.connect_vertex(vertex, adjacent)

        # Elements 1 and 2 of the result describe the group size; presumably
        # size = mantissa * 10**exponent -- confirm against the pynauty docs.
        _, mantissa, exponent = pynauty.autgrp(nauty_graph)[:3]

        trivial = exponent <= 0 and mantissa <= 1
        return not trivial
    
    def process(self):
        """Generate the dataset for ``self.mode`` and write it to disk.

        Steps:
        1. Generate ``samples_per_class`` verified symmetric graphs (label 1),
           split evenly over ``self.sym_graph_types``; the first type also
           receives the remainder.
        2. Generate ``samples_per_class`` verified asymmetric graphs (label 0):
           60% from 'perturbed_asymmetric', the rest split over the remaining
           entries of ``self.asym_graph_types`` (all of it goes to
           'perturbed_asymmetric' when there are none).
        3. Write ``metadata_<mode>.json`` (before any pre-filtering) into
           ``self.processed_dir``.
        4. Apply ``pre_filter``, optional one-hot degree features and
           ``pre_transform``, then collate and save to ``processed_paths[0]``.

        A candidate graph is kept only when ``verify_symmetry`` agrees with
        the class it was generated for; generation retries until each quota is
        met (there is no attempt limit).
        """
        data_list = []
        metadata = {
            'dataset_info': {
                'mode': self.mode,
                'min_nodes': self.min_nodes,
                'max_nodes': self.max_nodes,
                'samples_per_class': self.samples_per_class,
            },
            'generation_types': {
                'symmetric': self.sym_graph_types,
                'asymmetric': self.asym_graph_types
            },
            'graphs': []  # Generation info for each accepted graph
        }

        # Generate symmetric graphs
        print(f"Generating {self.mode} set - Symmetric graphs:")
        samples_per_sym_type, sym_remainder = divmod(self.samples_per_class, len(self.sym_graph_types))

        for idx, sym_type in enumerate(self.sym_graph_types):
            target_samples = samples_per_sym_type + (sym_remainder if idx == 0 else 0)
            self._collect_verified_graphs(
                self.generate_symmetric_graph, sym_type, target_samples,
                want_symmetric=True,
                desc=f"Generating {sym_type} symmetric graphs",
                data_list=data_list,
                graph_records=metadata['graphs'],
            )

        print(f"\nGenerating {self.mode} set - Asymmetric graphs:")
        total_asymmetric_samples = self.samples_per_class

        # Allocate 60% of samples to perturbed_asymmetric
        perturbed_samples = int(total_asymmetric_samples * 0.6)

        # Other methods share the remaining 40%
        other_methods = [m for m in self.asym_graph_types if m != 'perturbed_asymmetric']
        if other_methods:
            samples_per_other, remainder = divmod(
                total_asymmetric_samples - perturbed_samples, len(other_methods))
            method_samples = {m: samples_per_other for m in other_methods}
            # Remainder goes to the first other method
            method_samples[other_methods[0]] += remainder
            method_samples['perturbed_asymmetric'] = perturbed_samples
        else:
            # Nothing to share with: avoid dividing by zero and keep the class balanced.
            method_samples = {'perturbed_asymmetric': total_asymmetric_samples}

        # Generate samples for each method (dict order: other methods first)
        for asym_type, target_samples in method_samples.items():
            self._collect_verified_graphs(
                self.generate_asymmetric_graph, asym_type, target_samples,
                want_symmetric=False,
                desc=f"Generating {asym_type} asymmetric graphs",
                data_list=data_list,
                graph_records=metadata['graphs'],
            )

        # Add dataset statistics
        metadata['dataset_stats'] = {
            'total_graphs': len(data_list),
            'symmetric_graphs': sum(1 for d in data_list if d.y.item() == 1),
            'asymmetric_graphs': sum(1 for d in data_list if d.y.item() == 0),
        }

        # Save metadata
        metadata_path = os.path.join(self.processed_dir, f'metadata_{self.mode}.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.use_degree_features:
            max_degree = 100
            degree_transform = OneHotDegree(max_degree)
            data_list = [degree_transform(d) for d in data_list]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _collect_verified_graphs(self, generate, method, target_samples, want_symmetric,
                                 desc, data_list, graph_records):
        """Generate graphs with ``generate`` until ``target_samples`` pass verification.

        A graph is accepted when ``verify_symmetry`` equals ``want_symmetric``.
        Accepted graphs are appended to ``data_list`` (label 1 for symmetric,
        0 for asymmetric) and their generation info to ``graph_records``.
        """
        label = 1 if want_symmetric else 0
        accepted = 0

        # The context manager closes the progress bar even if generation raises.
        with tqdm(total=target_samples, desc=desc) as pbar:
            while accepted < target_samples:
                n_nodes = random.randint(self.min_nodes, self.max_nodes)
                G, generation_info = generate(n_nodes, method=method)

                is_sym = self.verify_symmetry(G)
                generation_info['verified_symmetric'] = is_sym
                # Index the graph will have in the dataset if it is accepted.
                generation_info['graph_idx'] = len(data_list)

                if is_sym != want_symmetric:
                    continue

                data_list.append(self._graph_to_data(G, label))
                graph_records.append(generation_info)
                accepted += 1
                pbar.update(1)

    @staticmethod
    def _graph_to_data(G, label):
        """Convert a networkx graph into a PyG ``Data`` object with constant node features."""
        # reshape(-1, 2) keeps the (2, E) layout valid for a graph without
        # edges, where the bare tensor would be 1-D.
        edge_index = torch.tensor(list(G.edges), dtype=torch.long).reshape(-1, 2).t().contiguous()
        edge_index = to_undirected(edge_index)
        x = torch.ones(G.number_of_nodes(), 1)
        y = torch.tensor([label], dtype=torch.long)
        return Data(x=x, edge_index=edge_index, y=y)

    def generate_symmetric_graph(self, n_nodes, method='bipartite_cover'):
        """Generate symmetric graph using specified method"""
        generation_info = {
            'type': method,
            'target_nodes': n_nodes,
        }
        
        if method == 'cayley_cyclic':
            G = self.generate_cyclic_cayley_graph(n_nodes)
            generation_info.update({
                'group_size': max(8, min(n_nodes // 2, 20)),
                'method_details': 'cyclic_group_cayley'
            })
        elif method == 'bipartite_cover':
            G = self.generate_bipartite_double_cover(n_nodes)
            generation_info['method_details'] = 'bipartite_double_cover'

        elif method == 'cartesian_product':
            G = self.generate_cartesian_product(n_nodes)
            generation_info['method_details'] = 'cartesian_product'

        elif method == 'cartesian_product_with_real_graph':
            G = self.generate_cartesian_product(n_nodes,real_graph=True)
            generation_info['method_details'] = 'cartesian_product_with_real_graph'

        elif method == 'real_data_cover':
            G = self.generate_real_data_cover(n_nodes)
            generation_info['method_details'] = 'real_data_cover'
        else:
            G = self.generate_bipartite_double_cover(n_nodes)
            generation_info['method_details'] = 'default_bipartite_cover'
        
        return G, generation_info
    

    def generate_cyclic_cayley_graph(self,n_nodes):
        """
        Build an undirected circulant graph on the residues modulo n.

        Between one and three step sizes are drawn at random from the residues
        coprime with n; vertex i is then joined to (i + step) mod n and
        (i - step) mod n for every chosen step.

        Parameters:
        n_nodes -- requested number of nodes (values below 2 are raised to 2)

        Returns:
        G -- networkx.Graph with nodes 0..n-1
        """
        order = max(2, n_nodes)

        # Step sizes coprime with the order
        coprime_steps = [r for r in range(1, order) if math.gcd(r, order) == 1]
        if not coprime_steps:
            coprime_steps = [1]  # Degenerate case

        # Use 1-3 steps, but never more than are available
        how_many = random.randint(1, min(3, len(coprime_steps)))
        steps = random.sample(coprime_steps, how_many)

        G = nx.Graph()
        G.add_nodes_from(range(order))
        # Forward edge first, then its inverse, for each vertex and step
        G.add_edges_from(
            (vertex, (vertex + sign * step) % order)
            for vertex in range(order)
            for step in steps
            for sign in (1, -1)
        )
        return G
    
    def generate_bipartite_double_cover(self, n_nodes):
        """Generate the bipartite double cover of a randomly chosen base graph.

        The base graph has about ``n_nodes // 2`` nodes (at least 5) and is one
        of: an Erdos-Renyi graph, a graph with 2-4 dense communities, two dense
        halves joined by 1-3 bottleneck edges, or a real-data graph whose size
        is within 30% of the target (falling back to an Erdos-Renyi graph when
        none is available). Only its largest connected component is used.

        In the cover every base node ``v`` becomes ``(v, 0)`` and ``(v, 1)``,
        and every base edge ``u-v`` becomes ``(u, 0)-(v, 1)`` and
        ``(u, 1)-(v, 0)``.

        Small targets are handled safely: the bottleneck width never exceeds
        the size of either half, community sizes always sum to the base size,
        and a base graph that ended up without nodes is replaced by one node.

        Args:
            n_nodes: Requested size of the resulting graph.

        Returns:
            networkx.Graph with integer node labels.
        """
        from itertools import combinations, product  # pair enumeration helpers

        # Create base graph with roughly half the target nodes
        base_size = max(n_nodes // 2, 5)

        def random_base():
            # Erdos-Renyi random graph with controlled density
            p = random.uniform(0.15, 0.3)
            return nx.gnp_random_graph(base_size, p)

        base_type = random.choice(['random', 'community', 'bottleneck', 'real_data'])

        if base_type == 'random':
            base_graph = random_base()

        elif base_type == 'community':
            # Never ask for more communities than can hold two nodes each;
            # for base_size >= 8 this is the plain 2..4 range.
            num_communities = random.randint(2, max(2, min(4, base_size // 2)))

            # Community sizes summing to base_size
            community_sizes = []
            remaining = base_size
            for i in range(num_communities - 1):
                size = max(2, remaining // (num_communities - i))
                community_sizes.append(size)
                remaining -= size
            community_sizes.append(remaining)

            base_graph = nx.Graph()
            community_nodes = []
            node_id = 0
            for size in community_sizes:
                community = list(range(node_id, node_id + size))
                community_nodes.append(community)

                # Within-community connections - higher probability
                p_within = random.uniform(0.3, 0.7)
                for u, v in combinations(community, 2):
                    if random.random() < p_within:
                        base_graph.add_edge(u, v)

                node_id += size

            # Between-community connections - lower probability
            p_between = random.uniform(0.05, 0.15)
            for left, right in combinations(community_nodes, 2):
                for u, v in product(left, right):
                    if random.random() < p_between:
                        base_graph.add_edge(u, v)

        elif base_type == 'bottleneck':
            left_size = base_size // 2
            right_size = base_size - left_size

            base_graph = nx.Graph()
            base_graph.add_nodes_from(range(base_size))

            # Dense edges inside each half
            p_within = random.uniform(0.3, 0.6)
            for i, j in combinations(range(left_size), 2):
                if random.random() < p_within:
                    base_graph.add_edge(i, j)
            for i, j in combinations(range(left_size, base_size), 2):
                if random.random() < p_within:
                    base_graph.add_edge(i, j)

            # A few edges across the halves; the width is capped so that
            # random.sample never asks for more nodes than a half contains.
            bottleneck_width = random.randint(1, min(3, left_size, right_size))
            left_connectors = random.sample(range(left_size), bottleneck_width)
            right_connectors = random.sample(range(left_size, base_size), bottleneck_width)
            for left_node, right_node in zip(left_connectors, right_connectors):
                base_graph.add_edge(left_node, right_node)

        else:  # real_data
            candidates = [g for g in (self.real_data_graphs or [])
                          if abs(g.number_of_nodes() - base_size) <= base_size * 0.3]
            if candidates:
                base_graph = random.choice(candidates).copy()
            else:
                # No real graph of a suitable size: use a random graph
                base_graph = random_base()

        # nx.is_connected raises on a graph without nodes (possible when the
        # community construction added no edge at all).
        if base_graph.number_of_nodes() == 0:
            base_graph.add_node(0)

        # Keep only the largest connected component
        if not nx.is_connected(base_graph):
            largest_cc = max(nx.connected_components(base_graph), key=len)
            base_graph = base_graph.subgraph(largest_cc).copy()

        # Build the double cover
        G = nx.Graph()
        for v in base_graph.nodes():
            G.add_node((v, 0))
            G.add_node((v, 1))
        for u, v in base_graph.edges():
            G.add_edge((u, 0), (v, 1))
            G.add_edge((u, 1), (v, 0))

        # Convert node labels to integers
        return nx.convert_node_labels_to_integers(G)

   
    
    def generate_cartesian_product(self, n_nodes,real_graph=False):
        """Build the Cartesian product of two connected base graphs.

        Args:
            n_nodes: Target size; the factor sizes come from
                ``select_optimal_factors(n_nodes)``.
            real_graph: When True both factors are fetched with
                ``self._get_real_data_graph``; otherwise the factor types are
                drawn from (cycle, path), (cycle, cycle) and (path, star).

        Returns:
            networkx.Graph with integer node labels.
        """
        if real_graph:
            pairings = [('real_data', 'real_data')]
        else:
            pairings = [
                ('cycle', 'path'),
                ('cycle', 'cycle'),
                ('path', 'star'),
            ]

        def build(kind, size):
            # Create one factor graph of the requested kind
            if kind == 'cycle':
                return nx.cycle_graph(size)
            if kind == 'path':
                return nx.path_graph(size)
            if kind == 'star':
                return nx.star_graph(size - 1)
            if kind == 'real_data':
                return self._get_real_data_graph(size)
            raise KeyError(kind)

        def connected_part(graph):
            # A disconnected factor is reduced to its largest component
            if nx.is_connected(graph):
                return graph
            biggest = max(nx.connected_components(graph), key=len)
            return graph.subgraph(biggest).copy()

        # Drawn even when there is a single pairing
        first_kind, second_kind = random.choice(pairings)

        # Factor sizes whose product does not exceed the target size
        size_a, size_b = select_optimal_factors(n_nodes)

        factor_a = build(first_kind, size_a)
        factor_b = build(second_kind, size_b)
        factor_a = connected_part(factor_a)
        factor_b = connected_part(factor_b)

        product_graph = nx.cartesian_product(factor_a, factor_b)
        return nx.convert_node_labels_to_integers(product_graph)
    
    def generate_real_data_cover(self, n_nodes):
        """
        Build a layered cover of a real-data base graph.

        The base graph is requested from ``self._get_real_data_graph`` with
        about a third of ``n_nodes`` (at least 5) and reduced to its largest
        connected component. It is copied into k layers (2 <= k <= 5), and
        every base edge u-v links (u, layer) to (v, layer + 1) and (v, layer)
        to (u, layer + 1), with the layer index wrapping around modulo k.

        Returns a networkx.Graph with integer node labels.
        """
        base = self._get_real_data_graph(max(5, n_nodes // 3))

        # Work on the largest component when the base graph is disconnected
        if not nx.is_connected(base):
            biggest = max(nx.connected_components(base), key=len)
            base = base.subgraph(biggest).copy()

        # Number of layers, clamped to 2..5
        layers = max(2, min(5, n_nodes // base.number_of_nodes()))

        cover = nx.Graph()
        # Nodes are added layer by layer
        cover.add_nodes_from(
            (node, layer) for layer in range(layers) for node in base.nodes()
        )

        # Each base edge is lifted once per layer, in both orientations
        for u, v in base.edges():
            for layer in range(layers):
                following = (layer + 1) % layers
                cover.add_edges_from([
                    ((u, layer), (v, following)),
                    ((v, layer), (u, following)),
                ])

        return nx.convert_node_labels_to_integers(cover)
    

    def generate_asymmetric_graph(self, n_nodes, method='unique_signatures'):
        """Generate a graph intended to be asymmetric, plus generation metadata.

        Args:
            n_nodes: Target number of nodes handed to the chosen generator.
            method: One of ``'perturbed_asymmetric'``, ``'unique_signatures'``
                or ``'cartesian_product_with_real_graph'``. Any other value
                falls back to the unique-signature generator.

        Returns:
            A ``(graph, info)`` tuple where ``info`` is a dict with the keys
            ``'type'`` (the requested method), ``'target_nodes'`` and
            ``'method_details'``.
        """
        if method == 'perturbed_asymmetric':
            graph = self.generate_perturbed_asymmetric(n_nodes)
            details = 'perturbed_symmetric_base'
        elif method == 'unique_signatures':
            graph = self.generate_unique_signature_graph(n_nodes)
            details = 'unique_node_signatures'
        elif method == 'cartesian_product_with_real_graph':
            graph = self.generate_cartesian_product(n_nodes, real_graph=True)
            details = 'cartesian_product_with_real_graph'
        else:
            # Unknown method name: use the default generator and say so
            graph = self.generate_unique_signature_graph(n_nodes)
            details = 'default_unique_signatures'

        info = {
            'type': method,
            'target_nodes': n_nodes,
            'method_details': details,
        }
        return graph, info

    
    def generate_perturbed_asymmetric(self, n_nodes, max_swap_attempts=20):
        """Perturb a symmetric graph with edge swaps until symmetry breaks.

        A symmetric graph is generated with a method drawn at random from
        ``self.sym_graph_types``. Connectivity-preserving edge swaps are then
        applied in place, one at a time, and ``verify_symmetry`` is consulted
        after each successful swap (it is treated as truthy while the graph
        is still symmetric).

        Args:
            n_nodes: Target node count passed to the symmetric generator.
            max_swap_attempts: Upper bound on the number of swaps to try.

        Returns:
            The perturbed graph. NOTE: if no feasible swap is found, or the
            attempt budget runs out, the returned graph may still be
            symmetric; callers that need a guarantee must re-check.
        """
        # 1. Start from a symmetric graph
        sym_method = random.choice(self.sym_graph_types)
        G, _ = self.generate_symmetric_graph(n_nodes, method=sym_method)

        # 2. Swap edges until symmetry is broken or we give up
        for _ in range(max_swap_attempts):
            if not self._perform_edge_swap(G):
                # No valid swap could be found; further attempts are pointless
                break
            if not self.verify_symmetry(G):
                # Symmetry broken: done
                break

        return G

    def _perform_edge_swap(self, G):
        """Try to apply one connectivity-preserving edge swap to ``G`` in place.

        Two edges ``(a, b)`` and ``(c, d)`` are sampled at random and replaced
        by ``(a, d)`` and ``(c, b)``, provided the four endpoints are distinct
        and neither new edge already exists. A swap that disconnects the graph
        is rolled back. Up to 50 random pairs are tried.

        Returns:
            True if a swap was applied, False if the graph has fewer than two
            edges or no feasible swap was found within the trial budget.
        """
        edge_pool = list(G.edges())
        if len(edge_pool) < 2:
            return False

        trial_budget = 50
        for _ in range(trial_budget):
            (a, b), (c, d) = random.sample(edge_pool, 2)

            # The swap needs four distinct endpoints ...
            if len({a, b, c, d}) != 4:
                continue
            # ... and must not create a parallel edge
            if G.has_edge(a, d) or G.has_edge(c, b):
                continue

            # Rewire: (a, b), (c, d)  ->  (a, d), (c, b)
            G.remove_edge(a, b)
            G.remove_edge(c, d)
            G.add_edge(a, d)
            G.add_edge(c, b)

            if nx.is_connected(G):
                return True

            # The swap split the graph: undo it and try another pair
            G.remove_edge(a, d)
            G.remove_edge(c, b)
            G.add_edge(a, b)
            G.add_edge(c, d)

        return False

    
    
    def _get_real_data_graph(self, n):
        """Get a suitably sized graph from loaded real data"""
        if not self.real_data_graphs:
            raise ValueError("No real data graphs available!")
        
        # Try to find suitable sized graph
        candidates = [g for g in self.real_data_graphs if abs(g.number_of_nodes() - n) <= n * 0.3]
        
        if candidates:
            return random.choice(candidates).copy()
        
        # If no suitably sized graph, simply use closest size
        closest_graph = min(self.real_data_graphs, key=lambda g: abs(g.number_of_nodes() - n))
        return closest_graph.copy()
    
    def _get_real_data_graph_with_index(self, n):
        """Get a suitably sized graph from loaded real data, return both graph and index"""
        if self.real_data_graphs:
            # Try to find suitable sized graph
            candidates = [(i, g) for i, g in enumerate(self.real_data_graphs) 
                        if abs(g.number_of_nodes() - n) <= n * 0.3]
            
            if candidates:
                idx, graph = random.choice(candidates)
                return graph.copy(), idx
            
            # If none suitable, randomly choose one
            idx = random.randrange(len(self.real_data_graphs))
            return self.real_data_graphs[idx].copy(), idx
        
    
    def _add_symmetric_extensions(self, G, base_size, remaining_nodes):
        """Attach the same kind of small structure to a group of equal-degree nodes.

        Nodes of ``G`` are bucketed by degree, one bucket is picked at random,
        and each of its nodes receives a path, a star or a triangle (the kind
        is drawn once for the whole bucket). ``G`` is modified in place and
        nothing is returned.

        Args:
            G: Graph to extend. New nodes are labelled with consecutive
                integers starting at ``base_size`` (assumes those labels are
                not already in use - TODO confirm against callers).
            base_size: Label of the first node to be added.
            remaining_nodes: Maximum number of nodes that may be added.

        Notes:
            * The per-node budget is ``remaining_nodes // len(bucket)``; when
              it is zero nothing is added.
            * A ``'path'`` draw with a per-node budget of 1 is handled by the
              triangle branch.
            * The triangle branch always adds two nodes per anchor regardless
              of the per-node budget, and falls back to a single pendant node
              when only one node of the overall budget is left.
        """
        # Bucket nodes by degree, keeping the graph's own node order per bucket
        by_degree = {}
        for node, deg in G.degree():
            by_degree.setdefault(deg, []).append(node)

        buckets = [by_degree[deg] for deg in sorted(by_degree)]
        if not buckets:
            buckets = [list(G.nodes())]

        anchors = random.choice(buckets)

        per_anchor = remaining_nodes // len(anchors)
        if per_anchor <= 0:
            return

        shape = random.choice(['path', 'star', 'triangle'])
        label_limit = base_size + remaining_nodes  # first label NOT allowed
        next_label = base_size

        for anchor in anchors:
            if next_label >= label_limit:
                break

            if shape == 'path' and per_anchor > 1:
                # Chain of new nodes hanging off the anchor
                tail = anchor
                for _ in range(per_anchor):
                    if next_label >= label_limit:
                        break
                    G.add_node(next_label)
                    G.add_edge(tail, next_label)
                    tail = next_label
                    next_label += 1

            elif shape == 'star':
                # New leaves all attached directly to the anchor
                for _ in range(per_anchor):
                    if next_label >= label_limit:
                        break
                    G.add_node(next_label)
                    G.add_edge(anchor, next_label)
                    next_label += 1

            else:
                # Triangle (also reached for a 'path' draw with budget 1)
                if next_label + 2 <= label_limit:
                    first, second = next_label, next_label + 1
                    G.add_node(first)
                    G.add_node(second)
                    G.add_edge(anchor, first)
                    G.add_edge(anchor, second)
                    G.add_edge(first, second)
                    next_label += 2
                elif next_label + 1 <= label_limit:
                    # Only one node left in the budget: degrade to one edge
                    G.add_node(next_label)
                    G.add_edge(anchor, next_label)
                    next_label += 1

def generate_symmetry_dataset(root_dir, seed=42):
    """Create (or load) the train, validation and three test splits.

    The test splits ``ID``, ``Near-OOD`` and ``Far-OOD`` differ in their node
    count ranges. A test split whose processed file already exists under
    ``<root_dir>/test_<difficulty>/processed/symmetry_test.pt`` is loaded
    without passing the size configuration; otherwise it is generated with
    it. Sample counts and per-class distributions are printed for every split.

    Args:
        root_dir: Directory that holds one sub-directory per split.
        seed: Seed forwarded to ``set_seed`` before anything is generated.

    Returns:
        ``(train_dataset, val_dataset, test_datasets)`` where
        ``test_datasets`` maps the difficulty name to its dataset.
    """
    set_seed(seed)
    os.makedirs(root_dir, exist_ok=True)

    # Node ranges grow with the distribution shift; sample count is constant
    difficulty_configs = {
        'ID': dict(min_nodes=30, max_nodes=60, samples_per_class=300),
        'Near-OOD': dict(min_nodes=50, max_nodes=100, samples_per_class=300),
        'Far-OOD': dict(min_nodes=70, max_nodes=150, samples_per_class=300),
    }

    train_dataset = SymmetryDataset(os.path.join(root_dir, 'train'), mode='train')
    val_dataset = SymmetryDataset(os.path.join(root_dir, 'val'), mode='val')

    test_datasets = {}
    for difficulty, config in difficulty_configs.items():
        test_dir = os.path.join(root_dir, f'test_{difficulty}')
        already_processed = os.path.exists(
            os.path.join(test_dir, 'processed', 'symmetry_test.pt')
        )

        dataset_kwargs = {'mode': 'test', 'use_degree_features': True}
        if already_processed:
            print(f"Found existing {difficulty} test set, loading directly...")
        else:
            print(f"Generating {difficulty} test set...")
            # Size settings are only passed when the split is built fresh
            dataset_kwargs.update(config)

        test_dataset = SymmetryDataset(test_dir, **dataset_kwargs)
        test_datasets[difficulty] = test_dataset

        print(f"{difficulty} test set: {len(test_dataset)} samples, node range: {config['min_nodes']}-{config['max_nodes']}")

    print(f"Training set: {len(train_dataset)} samples, node range: {train_dataset.min_nodes}-{train_dataset.max_nodes}")
    print(f"Validation set: {len(val_dataset)} samples, node range: {val_dataset.min_nodes}-{val_dataset.max_nodes}")

    # Per-class sample counts for every split
    named_datasets = [('Training set', train_dataset), ('Validation set', val_dataset)]
    named_datasets.extend((f'{d} test set', ds) for d, ds in test_datasets.items())

    for name, dataset in named_datasets:
        class_counts = Counter(data.y.item() for data in dataset)
        print(f"\n{name} class distribution:")
        for class_idx in sorted(class_counts):
            print(f"Class {dataset.class_names[class_idx]}: {class_counts[class_idx]}")

    return train_dataset, val_dataset, test_datasets

def generate_symmetry_image_dataset(root_dir, layout="spring", image_size=224, seed=42):
    """Render every graph of the symmetry dataset to a PNG and index them in a CSV.

    The train and validation splits plus every test split whose directory
    exists under ``root_dir`` are loaded, each graph is drawn on a black
    background, and one row per image is recorded in
    ``<root_dir>/dataset.csv`` with the columns ``image_path``, ``label``,
    ``class_name``, ``split`` and ``graph_idx``.

    A split is skipped when the CSV already lists exactly ``len(dataset)``
    rows for it and all referenced image files exist. Otherwise the split is
    rendered again and its previous CSV rows are dropped first, so the CSV
    never holds duplicate rows for a split (image files referenced by the
    dropped rows are left on disk).

    Args:
        root_dir: Dataset root containing the split directories.
        layout: Prefix of a networkx layout function, resolved as
            ``nx.<layout>_layout``.
        image_size: Used to derive the save DPI as ``image_size // 8`` for
            the 8x8 inch figure. NOTE(review): ``bbox_inches='tight'`` is
            passed to ``savefig``, so the saved PNG is presumably not exactly
            ``image_size`` pixels wide - confirm before relying on it.
        seed: Seed for ``set_seed`` and, offset by a running image counter,
            for the spring layout.

    Returns:
        ``(image_dir, dataset_csv)`` paths.
    """
    set_seed(seed)

    # Ensure directory exists
    image_dir = os.path.join(root_dir, "images")
    os.makedirs(image_dir, exist_ok=True)

    dataset_csv = os.path.join(root_dir, "dataset.csv")

    # Load generated PyG datasets
    train_dataset = SymmetryDataset(os.path.join(root_dir, 'train'), mode='train')
    val_dataset = SymmetryDataset(os.path.join(root_dir, 'val'), mode='val')

    # Load all difficulty test sets that are present on disk
    difficulties = ['ID', 'Near-OOD', 'Far-OOD']
    test_datasets = {}

    for difficulty in difficulties:
        test_dir = os.path.join(root_dir, f'test_{difficulty}')
        if os.path.exists(test_dir):
            test_datasets[f'test_{difficulty}'] = SymmetryDataset(test_dir, mode='test')
        else:
            print(f"Warning: {difficulty} test set not found, please ensure test sets are generated")

    # Merge all datasets, keyed by split name
    all_datasets = {
        'train': train_dataset,
        'val': val_dataset,
        **test_datasets
    }

    # Reuse the existing index if there is one
    if os.path.exists(dataset_csv):
        print("Found existing dataset information.")
        df = pd.read_csv(dataset_csv)
    else:
        print("Dataset information not found, will create new dataset information.")
        # Same schema as the rows appended below
        df = pd.DataFrame(columns=["image_path", "label", "class_name", "split", "graph_idx"])

    new_rows = []
    csv_dirty = False  # set when stale rows are dropped, so the CSV gets rewritten
    layout_func = getattr(nx, f"{layout}_layout")
    dpi = image_size // 8

    print("Generating images...")
    # Running counter used for file names and layout seeds; it continues
    # after the rows already present so names stay unique across runs
    global_idx = len(df) if not df.empty else 0

    for split, dataset in all_datasets.items():
        # Check if images for this split are already generated
        split_images = df[df['split'] == split]
        if not split_images.empty and len(split_images) == len(dataset):
            # Check if files exist
            all_exist = all(os.path.exists(os.path.join(image_dir, path)) for path in split_images['image_path'])
            if all_exist:
                print(f"All images for {split} split found, skipping generation.")
                continue

        # The split is about to be rendered from scratch: forget its old
        # rows, otherwise they would be duplicated and the completeness
        # check above could never succeed again.
        if not split_images.empty:
            df = df[df['split'] != split]
            csv_dirty = True

        print(f"Generating images for {split} split...")

        for idx in tqdm(range(len(dataset)), desc=f"Generating {split} images"):
            data = dataset[idx]
            # Rebuild a networkx graph; isolated nodes are kept via the
            # explicit node range
            G = nx.Graph()
            G.add_nodes_from(range(data.num_nodes))
            G.add_edges_from(data.edge_index.t().numpy())

            # Get layout (only the spring layout is seeded)
            if layout == "spring":
                pos = layout_func(G, seed=seed + global_idx, k=0.3)
            else:
                pos = layout_func(G)

            # Draw graph with black background
            plt.figure(figsize=(8, 8), facecolor='black')
            nx.draw_networkx_nodes(G, pos, node_size=50, node_color='skyblue',
                                   edgecolors='white', linewidths=0.8, alpha=0.9)
            nx.draw_networkx_edges(G, pos, width=1.5, alpha=0.8, edge_color='white')

            plt.axis('off')
            plt.tight_layout()

            # Generate unique filename
            img_path = f"{split}_graph_{global_idx:04d}.png"
            full_path = os.path.join(image_dir, img_path)
            plt.savefig(full_path, dpi=dpi, facecolor='black', bbox_inches='tight')
            plt.close()

            label = data.y.item()
            new_rows.append({
                "image_path": img_path,
                "label":      label,
                "class_name": dataset.class_names[label],
                "split":      split,
                "graph_idx":  idx          # local index in that split
            })

            global_idx += 1

    # Update dataset information
    if new_rows:
        new_df = pd.DataFrame(new_rows)
        if df.empty:
            df = new_df
        else:
            df = pd.concat([df, new_df], ignore_index=True)
        csv_dirty = True

    if csv_dirty:
        df.to_csv(dataset_csv, index=False)

    print(f"Dataset image generation complete! Images saved in {image_dir}")
    print(f"Dataset information saved in {dataset_csv}")

    # Print current split information in CSV
    split_counts = df['split'].value_counts()
    print("\nDataset split distribution:")
    for split, count in split_counts.items():
        print(f"{split}: {count} samples")

    return image_dir, dataset_csv

# Script entry point: build the graph datasets, then render them as images
if __name__ == "__main__":
    # Everything that gets generated is stored below this directory
    root_dir = "./symmetry_dataset"

    # Step 1: graph datasets (train, validation and the test difficulties)
    train_dataset, val_dataset, test_datasets = generate_symmetry_dataset(root_dir)

    # Step 2: one image per graph plus the CSV index describing them
    image_dir, dataset_csv = generate_symmetry_image_dataset(
        root_dir, layout="spring", image_size=224, seed=42
    )

    print("\nSymmetric graph detection benchmark dataset generation complete!")
